"vllm/vscode:/vscode.git/clone" did not exist on "3a92c6f3b5f010453d81f871f642839c15402cda"
test_rocm_skinny_gemms.py 8.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import math

5
6
7
8
9
10
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
11
from vllm.platforms.rocm import on_gfx950
12
from vllm.utils.platform_utils import get_cu_count
13
14

# Activation/weight dtypes every kernel test is parametrized over.
DTYPES = [torch.bfloat16, torch.float16]
# Bias variants used by the wvSplitKrc test:
# 0 = no bias, 1 = 1-D bias of length m, 2 = 2-D bias of shape (n, m).
BIAS_MODES = [0, 1, 2]
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Hand-picked (n, k, m) shapes; each tuple is one parametrized test case.
NKM_FACTORS_LLMM1 = [
    # n == 1 throughout: shapes growing from small to large
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Particular K sizes treated as edge cases
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Largest M in this list
    (1, 4096, 8192),
]

NKM_FACTORS_WVSPLITK = [
    # Batch sizes 1..4 paired with square (k == m) shapes
    (1, 16, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    # Larger K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    # Smallest supported M (the constraint being validated is m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# n values for the wvSplitKrc test: powers of two from 16 to 128 mixed with
# non-power-of-two sizes in between.
N_FACTORS_WVSPLITKRC = [
    13,
    16,
    17,
    25,
    29,
    31,
    32,
    41,
    51,
    64,
    71,
    81,
    91,
    103,
    117,
    128,
]

# k and m values for the wvSplitKrc test; each base size is paired with a
# slightly larger variant (+8 for k, +16 for m).
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]

70
71
72
# (n, k, m) shapes for the FP8 wvSplitKQ test.
NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 32, 16 + 16),
    (1, 64, 64),
    (1, 64, 64 + 16),
    (1, 64 + 16, 64),
    (1, 64 + 16, 64 + 16),
    (4, 64, 64),
    (4, 64, 64 + 16),
    (4, 64 + 16, 64),
    (4, 64 + 16, 64 + 16),
    (2, 512, 512),
    (3, 512, 512),
    (3, 512, 512 + 16),
    (4, 512, 512),
    (3, 2048, 2048),
    (3, 2048, 2048 + 16),
    (4, 2048 + 16, 2048),
    (4, 2048 + 16, 2048 + 16),
    (4, 4096, 4096),
    (4, 16400, 2048),
    (4, 16400, 2048 + 16),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
    (4, 32768 * 2, 28672),
    (4, 32768 * 2, 28672 + 16),
    (4, 32768 * 2 + 16, 28672),
    (4, 32768 * 2 + 16, 28672 + 16),
]

# RNG seeds every test is parametrized over.
SEEDS = [0]


106
def pad_fp8(weight):
    """Return ``weight`` as a slice of a wider, zero-padded tensor.

    The last dimension is padded on the right with 256 bytes' worth of
    zero elements and the padding is then sliced off again. The result has
    the same shape and values as ``weight`` but is a view into the padded
    allocation rather than a tensor with the original memory layout.

    Args:
        weight: Tensor to pad; only its last dimension is affected.

    Returns:
        A view with ``weight``'s shape and values, backed by the padded
        tensor.
    """
    import torch.nn.functional as F

    # Number of elements that occupy 256 bytes for this dtype.
    num_pad = 256 // weight.element_size()
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    # Drop the padding columns again; the underlying storage stays padded.
    return padded[..., :-num_pad]


113
114
115
116
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
    """Check ``ops.wvSplitKrc`` against ``torch.nn.functional.linear``.

    Inputs are drawn uniformly from [-1, 1) and, when ``xnorm`` is set,
    scaled by ``sqrt(2 / k)``. ``bias_mode`` selects no bias (0), a 1-D bias
    of length ``m`` (1) or a 2-D bias of shape ``(n, m)`` (2). Shapes whose
    estimated CU requirement exceeds the device's CU count are skipped.
    """
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    # Next ^2 of n
    N_p2 = 1 << (n - 1).bit_length()
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
    # How many of 4 waves in a group can work on same 16 Ms at same time?
    # This reduces the Ms each group works on, i.e. increasing the number
    # of CUs needed.
    GrpsShrB = min(N_p2 // 16, 4)
    # Given the above, how many CUs would we need?
    CuNeeded = rndup_cus * GrpsShrB
    # candidate for atomic reduce count splitk?
    fits_wvsplitkrc = CuNeeded <= cu_count

    if not fits_wvsplitkrc:
        pytest.skip("Too large for wvSplitKrc")

    xavier = (
        math.sqrt(2 / k) if xnorm else 1
    )  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    # Normalized inputs are compared on absolute error only (rtol is
    # negligible); unnormalized inputs also get a 1% relative tolerance.
    if xnorm:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-2)
161
162


163
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Check ``ops.LLMM1`` against ``torch.matmul(A, B.t())``.

    Runs under ``torch.inference_mode()`` for every ``rows_per_block`` value
    in the parametrization, using non-centered inputs in [0, 1).
    """
    torch.manual_seed(seed)
    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #      Without that the numbers quickly saturate, and may
    #      be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
181
182


183
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
    """Check bias-free ``ops.wvSplitK`` against ``torch.nn.functional.linear``.

    Inputs are drawn uniformly from [-0.5, 0.5).
    """
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
    B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
198
199
200
201
202


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
    """Check ``ops.wvSplitK`` with a 1-D bias of length ``m``.

    The reference is ``torch.nn.functional.linear(A, B, BIAS)``. ``A`` and
    ``B`` are zero-centered and scaled by ``sqrt(2 / k)``; the bias is drawn
    from [-0.5, 0.5) without scaling.
    """
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
217
218
219
220
221


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
    """Check ``ops.wvSplitK`` with a 2-D bias of shape ``(n, m)``.

    The reference is ``torch.nn.functional.linear(A, B, BIAS)``. ``A`` and
    ``B`` are zero-centered and scaled by ``sqrt(2 / k)``; the bias is drawn
    from [-0.5, 0.5) without scaling.
    """
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
236
237


238
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(
    xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
    """Check ``ops.wvSplitKQ`` against ``torch._scaled_mm`` on FP8 inputs.

    ``A`` and ``B`` are drawn uniformly from [-1, 1), optionally scaled by
    ``sqrt(2 / k)`` when ``xnorm`` is set, then quantized with
    ``ref_dynamic_per_tensor_fp8_quant``. ``padded_a`` / ``padded_b`` route
    the quantized tensors through ``pad_fp8``; ``biased`` adds a 1-D bias of
    length ``m`` in the output ``dtype``.
    """
    torch.manual_seed(seed)

    xavier = math.sqrt(2 / k) if xnorm else 1  # normalize to avoid large deltas
    A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
    if padded_b:
        B = pad_fp8(B)
    if padded_a:
        A = pad_fp8(A)

    BIAS = (
        (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1) if biased else None
    )

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
    )
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)

    if xnorm:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
    elif k >= 32 * 1024:
        # wider PyTorch threshold for large K without xnorm
        torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)