test_rocm_skinny_gemms.py 8.55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import math

5
6
7
8
9
10
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
11
from vllm.platforms.rocm import on_gfx950
12
from vllm.utils.platform_utils import num_compute_units
13
14

# Activation dtypes for the bf16/fp16 kernels; the FP8 test reuses them as the
# output (and bias) dtype.
DTYPES = [torch.bfloat16, torch.float16]

# How the bias argument is built in the bias-aware tests:
#   0 -> no bias (None)
#   1 -> 1-D bias of shape (m,)
#   2 -> 2-D bias of shape (n, m)
BIAS_MODES = [0, 1, 2]

# Targeted (N, K, M) shapes for LLMM1. N (rows of the activation) is always 1.
NKM_FACTORS_LLMM1 = [
    # K and M growing together, small through large
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # K larger than M
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Largest M in the sweep
    (1, 4096, 8192),
]

# Targeted (N, K, M) shapes for wvSplitK: N activation rows, K reduction dim,
# M output columns. The "+ 1" / "+ 16" entries are offset variants of the
# base shapes.
NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 32, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    (4, 4096, 4096 + 1),
    (4, 4096 + 16, 4096),
    (4, 4096 + 16, 4096 + 1),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    (4, 16384 * 2, 8192),
    (4, 16384 * 2, 8192 + 1),
    (4, 16384 * 2 + 16, 8192),
    (4, 16384 * 2 + 16, 8192 + 1),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

# N (activation rows) values swept for wvSplitKrc: powers of two mixed with
# irregular counts.
N_FACTORS_WVSPLITKRC = [
    13,
    16,
    17,
    25,
    29,
    31,
    32,
    41,
    51,
    64,
    71,
    81,
    91,
    103,
    117,
    128,
]

# K and M sweeps for wvSplitKrc, each base value paired with an offset variant.
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]

# (N, K, M) shapes for the FP8 wvSplitKQ test; every K and M is a multiple of 16.
NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 32, 16 + 16),
    (1, 64, 64),
    (1, 64, 64 + 16),
    (1, 64 + 16, 64),
    (1, 64 + 16, 64 + 16),
    (4, 64, 64),
    (4, 64, 64 + 16),
    (4, 64 + 16, 64),
    (4, 64 + 16, 64 + 16),
    (2, 512, 512),
    (3, 512, 512),
    (3, 512, 512 + 16),
    (4, 512, 512),
    (3, 2048, 2048),
    (3, 2048, 2048 + 16),
    (4, 2048 + 16, 2048),
    (4, 2048 + 16, 2048 + 16),
    (4, 4096, 4096),
    (4, 16400, 2048),
    (4, 16400, 2048 + 16),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
    (4, 32768 * 2, 28672),
    (4, 32768 * 2, 28672 + 16),
    (4, 32768 * 2 + 16, 28672),
    (4, 32768 * 2 + 16, 28672 + 16),
]

# RNG seeds passed to torch.manual_seed by every test.
SEEDS = [0]


def pad_fp8(weight):
    """Return a view of ``weight`` whose rows carry 256 bytes of trailing slack.

    The last dimension is padded with ``256 // weight.element_size()`` zeros and
    then sliced back to its original length. The result therefore has the same
    shape and values as ``weight``, but its row stride is larger than its width,
    so it is not contiguous. Despite the name, the tests in this file also apply
    it to bf16/fp16 tensors.
    """
    num_pad = 256 // weight.element_size()
    padded = torch.nn.functional.pad(weight, (0, num_pad), "constant", 0)
    # Slice the padding back off: the values are unchanged, but the storage keeps
    # the extra elements, which enlarges the stride of the last-but-one dim.
    return padded[..., :-num_pad]


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode):
    """Compare ``ops.wvSplitKrc`` against ``torch.nn.functional.linear``.

    A is (n, k) and B is (m, k); the reference is ``linear(A, B, BIAS)``.
    Shapes rejected by the CU / work-size heuristic below are skipped.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Smallest power of two >= n.
    N_p2 = 1 << (n - 1).bit_length()
    # Number of 512-wide shards of K (ceil(k / 512)).
    k_shards = (k + 512 - 1) // 512
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * k_shards
    # How many of 4 waves in a group can work on same 16 Ms at same time?
    # This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
    GrpsShrB = min(N_p2 // 16, 4)
    # Given the above, how many CUs would we need?
    CuNeeded = rndup_cus * GrpsShrB
    # Only run shapes inside the kernel's envelope: a work-size budget
    # (NOTE(review): the meaning of 128 * 1024 * 12 is not visible here) and
    # a CU requirement that must fit on this device.
    fits_wvsplitkrc = (N_p2 * m * k_shards) <= 128 * 1024 * 12
    fits_wvsplitkrc &= CuNeeded <= cu_count

    if not fits_wvsplitkrc:
        pytest.skip("Too large for wvSplitKrc")

    # Normalize to avoid large output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    if padded_a:
        A = pad_fp8(A)

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 3:
        # NOTE(review): BIAS_MODES is [0, 1, 2], so this broadcast (1, m) case is
        # never parametrized; add 3 to the bias_mode list to exercise it.
        BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitKrc(A, B, cu_count, BIAS)

    if xnorm:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-2)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Compare ``ops.LLMM1(B, A, rows_per_block)`` against ``A @ B.T``.

    A is (n, k) and B is (m, k); no bias is involved.
    """
    torch.manual_seed(seed)
    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #      Without that the numbers quickly saturate, and may
    #      be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
def test_rocm_wvsplitk_kernel(
    xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
):
    """Compare ``ops.wvSplitK`` against ``torch.nn.functional.linear``.

    A is (n, k) and B is (m, k). The bias is None, shape (m,) or shape (n, m)
    depending on ``bias_mode``. A and/or B may be swapped for row-padded,
    non-contiguous views produced by ``pad_fp8``.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Normalize to avoid large output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    # Bias is drawn before padding so the RNG sequence matches earlier runs.
    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    if padded_a:
        A = pad_fp8(A)
    if padded_b:
        B = pad_fp8(B)

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    # Accumulation error in fp16 GEMM scales with sqrt(K)
    atol = torch.finfo(dtype).eps * math.sqrt(k)
    torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(
    xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
    """Compare ``ops.wvSplitKQ`` against ``torch._scaled_mm``.

    A (n, k) and B (m, k) are drawn in torch's default dtype, quantized with
    ``ref_dynamic_per_tensor_fp8_quant`` (which also returns ``scale_a`` and
    ``scale_b``), and optionally replaced by row-padded views via ``pad_fp8``.
    An optional bias of shape (m,) is added. Outputs are produced in ``dtype``.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Normalize to avoid large deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

    if padded_b:
        B = pad_fp8(B)
    if padded_a:
        A = pad_fp8(A)

    BIAS = (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1) if biased else None

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
    )
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, cu_count, BIAS)

    if xnorm:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
    elif k >= 32 * 1024:
        # Wider PyTorch tolerances for large K without xnorm.
        torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)