Unverified Commit 1779c098 authored by L.B.R.'s avatar L.B.R. Committed by GitHub
Browse files

[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)


Signed-off-by: default avatarL.B.R. <lbr@mmonad.com>
Co-authored-by: default avatarL.B.R. <lbr@mmonad.com>
parent 44eea10f
This diff is collapsed.
......@@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 3:
BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
......@@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel(
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
# Accumulation error in fp16 GEMM scales with sqrt(K)
atol = torch.finfo(dtype).eps * math.sqrt(k)
torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
import torch
from vllm.platforms import current_platform
if current_platform.is_cuda():
pytest.skip(
"ROCm skinny GEMM tests are not supported on CUDA.",
allow_module_level=True,
)
from vllm.model_executor.layers import utils
def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
x = torch.randn(1, 64, dtype=torch.float16)
weight = torch.randn(128, 64, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitk_mock.assert_called_once()
llmm1_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
x = torch.randn(5, 64, dtype=torch.float16)
weight = torch.randn(128, 64, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitk_mock.assert_not_called()
llmm1_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch):
x = torch.randn(16, 1024, dtype=torch.float16)
weight = torch.randn(256, 1024, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: True)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitkrc_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitkrc_mock.assert_called_once()
wvsplitk_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
......@@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950
from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950
n = x.numel() // x.size(-1)
m = weight.shape[0]
......@@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl(
use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()
and (on_gfx9() or on_gfx1x())
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
)
if use_skinny is not True:
if not use_skinny:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.reshape(-1, x.size(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment