Unverified commit 7600642e, authored by Hashem Hashemi, committed by GitHub
Browse files

Add padding support to wvSplitK solution for skinny GEMMs (#33762)


Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
parent 1e69c048
This diff is collapsed.
...@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [ ...@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [
NKM_FACTORS_WVSPLITK = [ NKM_FACTORS_WVSPLITK = [
# Different batch sizes with key dimensions # Different batch sizes with key dimensions
(1, 16, 16), (1, 32, 16),
(1, 64, 64), (1, 64, 64),
(2, 256, 256), (2, 256, 256),
(3, 1024, 1024), (3, 1024, 1024),
(4, 4096, 4096), (4, 4096, 4096),
(4, 4096, 4096 + 1),
(4, 4096 + 16, 4096),
(4, 4096 + 16, 4096 + 1),
# Extended K values # Extended K values
(1, 9216, 512), (1, 9216, 512),
(2, 10240, 1024), (2, 10240, 1024),
(4, 16384, 8192), (4, 16384, 8192),
(4, 16384 * 2, 8192),
(4, 16384 * 2, 8192 + 1),
(4, 16384 * 2 + 16, 8192),
(4, 16384 * 2 + 16, 8192 + 1),
# Minimum M constraint validation (m >= 8) # Minimum M constraint validation (m >= 8)
(1, 64, 8), (1, 64, 8),
(2, 128, 8), (2, 128, 8),
...@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): ...@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("bias_mode", BIAS_MODES)
torch.manual_seed(seed) @pytest.mark.parametrize("padded_a", [False, True])
cu_count = num_compute_units() @pytest.mark.parametrize("padded_b", [False, True])
def test_rocm_wvsplitk_kernel(
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 ):
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = num_compute_units() cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas xavier = (
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier math.sqrt(2 / k) if xnorm else 1
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier ) # normalize to avoid large output-bias deltas
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) BIAS = None
@pytest.mark.parametrize("dtype", DTYPES) if bias_mode == 1:
@pytest.mark.parametrize("seed", SEEDS) BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") elif bias_mode == 2:
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
torch.manual_seed(seed)
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas if padded_a:
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier A = pad_fp8(A)
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier if padded_b:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 B = pad_fp8(B)
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True]) @pytest.mark.parametrize("xnorm", [False, True])
......
...@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl( ...@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl(
and on_gfx9() and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0 and k % 8 == 0
and x.is_contiguous()
) )
if use_skinny is not True: if use_skinny is not True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment