Commit 300892fe authored by zhuwenwen's avatar zhuwenwen
Browse files

update test_cutlass.py

parent c56b26cd
...@@ -128,9 +128,10 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -128,9 +128,10 @@ def cutlass_int8_gemm_helper(m: int,
elif torch_version.startswith("2.4"): elif torch_version.startswith("2.4"):
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# opcheck(torch.ops._C.cutlass_scaled_mm,
opcheck(torch.ops._C.cutlass_scaled_mm, # (out, a, b, scale_a, scale_b, bias))
(out, a, b, scale_a, scale_b, bias)) else:
print(f"PyTorch version {torch_version} is not specifically handled.")
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33]) # @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment