Unverified Commit 7a103043 authored by Hashem Hashemi's avatar Hashem Hashemi Committed by GitHub
Browse files

Atomics Reduce Counting Optimization for SplitK Skinny GEMMs. (#29843)


Signed-off-by: default avatarHashem Hashemi <hashem.hashemi@amd.com>
parent 9fd918e5
...@@ -9,6 +9,10 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ...@@ -9,6 +9,10 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias, const std::optional<at::Tensor>& in_bias,
const int64_t CuCount); const int64_t CuCount);
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount);
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c, const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b, const at::Tensor& scale_a, const at::Tensor& scale_b,
......
This diff is collapsed.
...@@ -26,6 +26,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { ...@@ -26,6 +26,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
"Tensor"); "Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
"wvSplitKrc(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
"Tensor");
rocm_ops.impl("wvSplitKrc", torch::kCUDA, &wvSplitKrc);
// wvSplitK for fp8 // wvSplitK for fp8
rocm_ops.def( rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, " "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
......
...@@ -8,9 +8,11 @@ import torch ...@@ -8,9 +8,11 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import get_cu_count
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
BIAS_MODES = [0, 1, 2]
# Specific (N, K, M) combinations for targeted testing # Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1 = [ NKM_FACTORS_LLMM1 = [
# Small, medium, large cases # Small, medium, large cases
...@@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [ ...@@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [
(4, 256, 8), (4, 256, 8),
] ]
NKM_FACTORS_WVSPLITKRC = [
(16, 2880, 128),
(16, 2880, 640),
(17, 2880, 128),
(17, 2880, 640),
(25, 2880, 128),
(25, 2880, 640),
(31, 2880, 128),
(31, 2880, 640),
(32, 2880, 128),
(32, 2880, 640),
(40, 2880, 128),
(40, 2880, 640),
(60, 2880, 128),
(60, 2880, 640),
(64, 2880, 128),
(64, 2880, 640),
(81, 2880, 128),
(81, 2880, 640),
(98, 2880, 128),
(98, 2880, 640),
(128, 2880, 128),
(128, 2880, 640),
]
NKM_FACTORS_WVSPLITK_FP8 = [ NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0 # FP8-specific cases with K % 16 == 0
(1, 16, 16), (1, 16, 16),
...@@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [ ...@@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [
SEEDS = [0] SEEDS = [0]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed)
cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = None
if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
......
...@@ -2072,6 +2072,12 @@ def wvSplitK( ...@@ -2072,6 +2072,12 @@ def wvSplitK(
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
def wvSplitKrc(
a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)
def wvSplitKQ( def wvSplitKQ(
a: torch.Tensor, a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
......
...@@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype): ...@@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def rocm_unquantized_gemm_impl( def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9 from vllm.platforms.rocm import on_gfx9, on_gfx950
n = x.numel() / x.size(-1) n = x.numel() / x.size(-1)
m = weight.shape[0] m = weight.shape[0]
k = weight.shape[1] k = weight.shape[1]
import math
use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950()
and x.dtype in [torch.float16, torch.bfloat16]
and (
n >= 16
and n <= 128
and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
)
# k == 2880 and (m == 640 or m == 128))
)
if use_skinny_reduce_counting:
cu_count = get_cu_count()
x_view = x.reshape(-1, x.size(-1))
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
if use_aiter_triton_gemm(n, m, k, x.dtype): if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment