Unverified Commit 2b84ac66 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[CI][AMD][BugFix] Use torch.testing.assert_close instead of assert...


[CI][AMD][BugFix] Use torch.testing.assert_close instead of assert torch.allclose in test_rocm_skinny_gemms.py (#34181)
Signed-off-by: default avatarRandall Smith <Randall.Smith@amd.com>
parent 11d3976b
...@@ -155,9 +155,9 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): ...@@ -155,9 +155,9 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
if xnorm: if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8) torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
else: else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2) torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
...@@ -177,7 +177,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): ...@@ -177,7 +177,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
ref_out = torch.matmul(A, B.t()) ref_out = torch.matmul(A, B.t())
out = ops.LLMM1(B, A, rows_per_block) out = ops.LLMM1(B, A, rows_per_block)
assert torch.allclose(out, ref_out, rtol=0.01) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
...@@ -194,7 +194,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): ...@@ -194,7 +194,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
ref_out = torch.nn.functional.linear(A, B) ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
assert torch.allclose(out, ref_out, rtol=0.01) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
...@@ -213,7 +213,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): ...@@ -213,7 +213,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
...@@ -232,7 +232,7 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): ...@@ -232,7 +232,7 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True]) @pytest.mark.parametrize("xnorm", [False, True])
...@@ -275,4 +275,4 @@ def test_rocm_wvsplitk_fp8_kernel( ...@@ -275,4 +275,4 @@ def test_rocm_wvsplitk_fp8_kernel(
# wider pytrch thresh for large-K & no xnorm # wider pytrch thresh for large-K & no xnorm
torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2) torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2)
else: else:
torch.testing.assert_close(out, ref_out, atol=0.01, rtol=0.01) torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment