# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform

# Activation/weight dtypes exercised by every kernel test below.
DTYPES = [torch.bfloat16, torch.float16]

# Specific (N, K, M) combinations for targeted testing.
# Each tuple is (n, k, m): the tests build A with shape (n, k) and
# B with shape (m, k), and compare against A @ B.T.
NKM_FACTORS_LLMM1 = [
    # Small, medium, large cases
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Edge cases with specific K sizes
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Very large case
    (1, 4096, 8192),
]

NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 16, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 64, 64),
    (2, 512, 512),
    (3, 2048, 2048),
    (4, 4096, 4096),
    # K that is a multiple of 16 but not a power of two
    (4, 16400, 2048),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
]

# RNG seeds passed to torch.manual_seed in each test.
SEEDS = [0]


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Compare ``ops.LLMM1`` against ``torch.matmul`` on random inputs.

    ``A`` has shape (n, k) and ``B`` has shape (m, k); the reference is
    ``A @ B.T``. The kernel is called with the operands in (B, A) order
    together with ``rows_per_block``.
    """
    torch.manual_seed(seed)

    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #       Without that the numbers quickly saturate, and may
    #       be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
    """Compare ``ops.wvSplitK`` (no bias) against ``F.linear``.

    ``A`` has shape (n, k) and ``B`` has shape (m, k). Both are drawn
    from [-0.5, 0.5). The kernel receives the weight first, then the
    activations flattened to 2-D, then the device's compute-unit count.
    """
    torch.manual_seed(seed)
    cu_count = current_platform.get_cu_count()

    A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
    B = torch.rand(m, k, dtype=dtype, device="cuda") - .5

    ref_out = torch.nn.functional.linear(A, B)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
    """Check ``ops.wvSplitK`` with a 1-D bias of shape (m,).

    The reference output is ``torch.nn.functional.linear(act, weight,
    bias)`` with ``act`` of shape (n, k) and ``weight`` of shape (m, k).
    """
    torch.manual_seed(seed)
    num_cus = current_platform.get_cu_count()

    # Shrink both operands by sqrt(2 / k) to avoid large output-bias deltas.
    scale = math.sqrt(2 / k)
    act = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * scale
    weight = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * scale
    bias = torch.rand(m, dtype=dtype, device="cuda") - .5

    expected = torch.nn.functional.linear(act, weight, bias)
    flat_act = act.view(-1, act.size(-1))
    actual = ops.wvSplitK(weight, flat_act, num_cus, bias)

    assert torch.allclose(actual, expected, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
    """Check ``ops.wvSplitK`` with a full 2-D bias of shape (n, m).

    The reference output is ``torch.nn.functional.linear(A, B, BIAS)``
    with ``A`` of shape (n, k) and ``B`` of shape (m, k).
    """
    torch.manual_seed(seed)
    cu_count = current_platform.get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
    """Compare ``ops.wvSplitKQ`` against ``torch._scaled_mm`` (no bias).

    Both operands are drawn from [-0.5, 0.5) in torch's default float
    dtype. Each is then quantized with ``ref_dynamic_per_tensor_fp8_quant``,
    which returns the quantized tensor and its scale. The same scales are
    passed to the reference and to the kernel, and both produce ``dtype``
    output.
    """
    torch.manual_seed(seed)

    A = torch.rand(n, k, device="cuda") - 0.5
    B = torch.rand(m, k, device="cuda") - 0.5

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

    ref_out = torch._scaled_mm(A,
                               B.t(),
                               out_dtype=dtype,
                               scale_a=scale_a,
                               scale_b=scale_b)
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
                        current_platform.get_cu_count())

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8")
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
    """Compare ``ops.wvSplitKQ`` with a 1-D bias against ``torch._scaled_mm``.

    The operands are scaled by sqrt(2 / k) before FP8 quantization. The
    bias has shape (m,) and is created directly in the output ``dtype``.
    It is passed as ``bias=`` to the reference and as the trailing
    argument to the kernel.
    """
    torch.manual_seed(seed)

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, device="cuda") - .5) * xavier
    B = (torch.rand(m, k, device="cuda") - .5) * xavier
    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

    ref_out = torch._scaled_mm(A,
                               B.t(),
                               out_dtype=dtype,
                               scale_a=scale_a,
                               scale_b=scale_b,
                               bias=BIAS)
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
                        current_platform.get_cu_count(), BIAS)

    assert torch.allclose(out, ref_out, rtol=0.01)