Unverified Commit ed17f54c authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Perf tuning and expansion of cases covered for wvSplitKrc (#33493)


Signed-off-by: default avatarHashem Hashemi <hashem.hashemi@amd.com>
parent 860981d8
This diff is collapsed.
...@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [ ...@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
(4, 256, 8), (4, 256, 8),
] ]
NKM_FACTORS_WVSPLITKRC = [ N_FACTORS_WVSPLITKRC = [
(16, 2880, 128), 13,
(16, 2880, 640), 16,
(17, 2880, 128), 17,
(17, 2880, 640), 25,
(25, 2880, 128), 29,
(25, 2880, 640), 31,
(31, 2880, 128), 32,
(31, 2880, 640), 41,
(32, 2880, 128), 51,
(32, 2880, 640), 64,
(40, 2880, 128), 71,
(40, 2880, 640), 81,
(60, 2880, 128), 91,
(60, 2880, 640), 103,
(64, 2880, 128), 117,
(64, 2880, 640), 128,
(81, 2880, 128),
(81, 2880, 640),
(98, 2880, 128),
(98, 2880, 640),
(128, 2880, 128),
(128, 2880, 640),
] ]
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]
NKM_FACTORS_WVSPLITK_FP8 = [ NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0 # FP8-specific cases with K % 16 == 0
(1, 16, 16), (1, 16, 16),
...@@ -113,30 +110,54 @@ def pad_fp8(weight): ...@@ -113,30 +110,54 @@ def pad_fp8(weight):
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES) @pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode): def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas # Next ^2 of n
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier N_p2 = 1 << (n - 1).bit_length()
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
if not fits_wvsplitkrc:
pytest.skip("Too large for wvSplitKrc")
xavier = (
math.sqrt(2 / k) if xnorm else 1
) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
BIAS = None BIAS = None
if bias_mode == 1: if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2: elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS) ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01) if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
......
...@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl( ...@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950 from vllm.platforms.rocm import on_gfx9, on_gfx950
n = x.numel() / x.size(-1) n = x.numel() // x.size(-1)
m = weight.shape[0] m = weight.shape[0]
k = weight.shape[1] k = weight.shape[1]
import math cu_count = get_cu_count()
if use_aiter_triton_gemm(n, m, k, x.dtype): if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias) return gemm_a16w16(x, weight, bias)
# Next ^2 of n
N_p2 = 1 << (n - 1).bit_length()
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
use_skinny_reduce_counting = ( use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950() and on_gfx950()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and ( and (
n >= 16 10 <= n <= 128
and n <= 128 and k % 8 == 0
and k > 512 and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() and m % 16 == 0
and fits_wvsplitkrc
and x.is_contiguous() and x.is_contiguous()
) )
# k == 2880 and (m == 640 or m == 128))
) )
if use_skinny_reduce_counting: if use_skinny_reduce_counting:
cu_count = get_cu_count()
x_view = x.reshape(-1, x.size(-1)) x_view = x.reshape(-1, x.size(-1))
out = ops.wvSplitKrc(weight, x_view, cu_count, bias) out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0]) return out.reshape(*x.shape[:-1], weight.shape[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment