Commit 28726eaf authored by wenjh's avatar wenjh
Browse files

Close Env NVTE_FORCE_ROCM_GEMM after tested gemm

parent b8fe26e7
......@@ -1800,6 +1800,8 @@ def test_grouped_linear_accuracy(
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
# Shoule be bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
......@@ -2508,6 +2510,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
layout=layout,
single_output=single_output,
)
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
# should be bit-wise match
for o, o_ref in zip(out, out_ref):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment