Commit 6dfe66e9 authored by wenjh

[Workaround] Force NVTE_FORCE_ROCM_GEMM=1



The accuracy problem in test_grouped_linear_accuracy and test_grouped_gemm
arises because the test output and the reference output are computed with
different kernels. Setting NVTE_FORCE_ROCM_GEMM=1 forces these tests to call
the ROCm GEMM path, so both outputs use the same kernel.
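
The workaround amounts to gating the environment variable on the ROCm build flag before any GEMM is issued. A minimal sketch of that pattern follows; `IS_HIP_EXTENSION` is assumed `True` here purely for illustration (in the real test suite it is imported from the Transformer Engine build config):

```python
import os

# Hypothetical stand-in for the test suite's build flag; assumed True
# here so the gated branch is exercised in this sketch.
IS_HIP_EXTENSION = True

if IS_HIP_EXTENSION:
    # Force the ROCm (hipBLASLt) GEMM path so the test output and the
    # reference output are computed by the same kernel.
    os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
```

Because `os.environ` is process-wide, the variable must be set before the first GEMM call; later reads of the flag in the same process will see `"1"`.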
Signed-off-by: wenjh <wenjh@sugon.com>
parent 7a47930f
@@ -1573,6 +1573,10 @@ def test_grouped_linear_accuracy(
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
# Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
@@ -2270,6 +2274,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
grad = True
single_output = False
# Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
for i in range(z):
general_gemm(
A[i],