Commit 5f45b0b7 authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by khluu
Browse files

[Bugfix][ROCm] Fixing the skinny gemm dispatch logic from #32831 (#33366)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
(cherry picked from commit 31aedfe7)
parent a2dba556
...@@ -87,6 +87,13 @@ NKM_FACTORS_WVSPLITK_FP8 = [ ...@@ -87,6 +87,13 @@ NKM_FACTORS_WVSPLITK_FP8 = [
SEEDS = [0] SEEDS = [0]
def pad_weights_fp8(weight):
num_pad = 256 // weight.element_size()
import torch.nn.functional as F
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
...@@ -191,11 +198,12 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): ...@@ -191,11 +198,12 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif( @pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()), not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8", reason="only test for rocm fp8",
) )
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed) torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda") - 0.5 A = torch.rand(n, k, device="cuda") - 0.5
...@@ -203,6 +211,8 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): ...@@ -203,6 +211,8 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)
ref_out = torch._scaled_mm( ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
...@@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): ...@@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif( @pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()), not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8", reason="only test for rocm fp8",
) )
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed) torch.manual_seed(seed)
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
...@@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): ...@@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)
ref_out = torch._scaled_mm( ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
......
...@@ -2034,35 +2034,20 @@ def selective_scan_fwd( ...@@ -2034,35 +2034,20 @@ def selective_scan_fwd(
) )
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# are unable to properly handle non-contiguous
# tensors. It might be a good TODO(rasmith) to augment these kernels
# to be able to handle non-contiguous kernels for better performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
a = a.contiguous() # no-op if already contiguous, else clone
b = b.contiguous() # no-op if already contiguous, else clone
return a, b
# ROCm skinny gemms # ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK( def wvSplitK(
a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor: ) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
def wvSplitKrc( def wvSplitKrc(
a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor: ) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)
...@@ -2075,7 +2060,6 @@ def wvSplitKQ( ...@@ -2075,7 +2060,6 @@ def wvSplitKQ(
cu_count: int, cu_count: int,
bias: torch.Tensor = None, bias: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
return out return out
......
...@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( ...@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A.shape[0] == 1 A.shape[0] == 1
and B.shape[1] % 16 == 0 and B.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype)) and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
): ):
output = ops.wvSplitKQ( output = ops.wvSplitKQ(
B.t(), B.t(),
......
...@@ -146,6 +146,7 @@ def rocm_unquantized_gemm_impl( ...@@ -146,6 +146,7 @@ def rocm_unquantized_gemm_impl(
and n <= 128 and n <= 128
and k > 512 and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
and x.is_contiguous()
) )
# k == 2880 and (m == 640 or m == 128)) # k == 2880 and (m == 640 or m == 128))
) )
...@@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl( ...@@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl(
and on_gfx9() and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0 and k % 8 == 0
and x.is_contiguous()
) )
if use_skinny is not True: if use_skinny is not True:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment