# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

5
6
7
8
9
10
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
11
from vllm.platforms.rocm import on_gfx950
12
from vllm.utils.platform_utils import num_compute_units
13
14

# Activation/weight dtypes exercised by the bf16/fp16 kernel tests.
DTYPES = [torch.bfloat16, torch.float16]
# Bias variants used by the tests: 0 = no bias, 1 = bias of shape (m,),
# 2 = bias of shape (n, m).
BIAS_MODES = [0, 1, 2]

# Specific (N, K, M) combinations for targeted LLMM1 testing (N is always 1).
NKM_FACTORS_LLMM1 = [
    # Small, medium, large cases
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Edge cases with specific K sizes
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Very large case
    (1, 4096, 8192),
]

# (N, K, M) combinations for wvSplitK. The "+ 1" / "+ 16" variants use
# sizes that are not powers of two.
NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 32, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    (4, 4096, 4096 + 1),
    (4, 4096 + 16, 4096),
    (4, 4096 + 16, 4096 + 1),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    (4, 16384 * 2, 8192),
    (4, 16384 * 2, 8192 + 1),
    (4, 16384 * 2 + 16, 8192),
    (4, 16384 * 2 + 16, 8192 + 1),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

# wvSplitKrc is parametrized over the full cross product of these N, K and M
# values; shapes that do not fit on the device are skipped by the test.
N_FACTORS_WVSPLITKRC = [
    13,
    16,
    17,
    25,
    29,
    31,
    32,
    41,
    51,
    64,
    71,
    81,
    91,
    103,
    117,
    128,
]
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]

# (N, K, M) combinations for the FP8 wvSplitKQ kernel.
NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0 (M is also a multiple of 16)
    (1, 16, 16),
    (1, 32, 16 + 16),
    (1, 64, 64),
    (1, 64, 64 + 16),
    (1, 64 + 16, 64),
    (1, 64 + 16, 64 + 16),
    (4, 64, 64),
    (4, 64, 64 + 16),
    (4, 64 + 16, 64),
    (4, 64 + 16, 64 + 16),
    (2, 512, 512),
    (3, 512, 512),
    (3, 512, 512 + 16),
    (4, 512, 512),
    (3, 2048, 2048),
    (3, 2048, 2048 + 16),
    (4, 2048 + 16, 2048),
    (4, 2048 + 16, 2048 + 16),
    (4, 4096, 4096),
    (4, 16400, 2048),
    (4, 16400, 2048 + 16),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
    (4, 32768 * 2, 28672),
    (4, 32768 * 2, 28672 + 16),
    (4, 32768 * 2 + 16, 28672),
    (4, 32768 * 2 + 16, 28672 + 16),
]

SEEDS = [0]


113
def pad_fp8(weight, pad_bytes=256):
    """Return a view of ``weight`` whose rows are padded in memory.

    The last dimension is extended with ``pad_bytes`` bytes worth of zero
    elements and then sliced back to its original length. The result has the
    same shape and values as ``weight``, but every dimension except the last
    gets a larger stride, so kernels can be tested on non-contiguous
    (row-padded) inputs. Despite the name it works for any dtype; the tests
    use it for fp8 as well as bf16/fp16 tensors.

    Args:
        weight: Tensor to pad along its last dimension.
        pad_bytes: Bytes of padding appended to each row (default 256).

    Returns:
        A view with ``weight``'s shape and values, or ``weight`` itself when
        ``pad_bytes`` is smaller than one element.

    Raises:
        ValueError: If ``pad_bytes`` is negative.
    """
    if pad_bytes < 0:
        raise ValueError(f"pad_bytes must be non-negative, got {pad_bytes}")
    num_pad = pad_bytes // weight.element_size()
    if num_pad == 0:
        # Slicing with [..., :-0] would produce an empty tensor.
        return weight
    padded = torch.nn.functional.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]


120
121
122
123
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
    """Check ``ops.wvSplitKrc`` against ``torch.nn.functional.linear``.

    ``A`` is an (n, k) input and ``B`` an (m, k) weight, so the reference is
    ``A @ B.T`` plus an optional bias chosen by ``bias_mode`` (0: none,
    1: shape (m,), 2: shape (n, m)). Shapes whose estimated compute-unit
    demand exceeds the device's CU count are skipped.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Next power of two >= n.
    n_p2 = 1 << (n - 1).bit_length()
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
    # How many of 4 waves in a group can work on same 16 Ms at same time?
    # This reduces the Ms each group works on, i.e. increasing the number
    # of CUs needed.
    grps_shr_b = min(n_p2 // 16, 4)
    # Given the above, how many CUs would we need?
    cu_needed = rndup_cus * grps_shr_b
    # Only shapes that fit on the device are candidates for wvSplitKrc.
    if cu_needed > cu_count:
        pytest.skip("Too large for wvSplitKrc")

    # Scale inputs by sqrt(2/k) when xnorm is set, to avoid large
    # output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    # rtol=1e-8 makes the xnorm comparison effectively absolute-only.
    rtol = 1e-8 if xnorm else 1e-2
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=rtol)
168
169


170
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Check ``ops.LLMM1`` against ``torch.matmul(A, B.t())``.

    ``A`` is (n, k) and ``B`` is (m, k); ``rows_per_block`` is forwarded to
    the kernel unchanged.
    """
    torch.manual_seed(seed)
    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #       Without that the numbers quickly saturate, and may
    #       be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
188
189


190
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(
    xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
):
    """Check ``ops.wvSplitK`` against ``torch.nn.functional.linear``.

    ``A`` is an (n, k) input and ``B`` an (m, k) weight. The optional bias is
    chosen by ``bias_mode`` (0: none, 1: shape (m,), 2: shape (n, m)). The
    ``padded_a`` / ``padded_b`` flags replace the operand with a row-padded,
    non-contiguous view of the same values via ``pad_fp8``.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Scale inputs by sqrt(2/k) when xnorm is set, to avoid large
    # output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    if padded_a:
        A = pad_fp8(A)
    if padded_b:
        B = pad_fp8(B)

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    # assert_close (unlike a bare `assert torch.allclose`) reports the
    # mismatching elements and does not silently broadcast shapes.
    # rtol=1e-8 makes the xnorm comparison effectively absolute-only.
    rtol = 1e-8 if xnorm else 1e-2
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=rtol)
228
229


230
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(
    xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
    """Check ``ops.wvSplitKQ`` against ``torch._scaled_mm``.

    ``A`` (n, k) and ``B`` (m, k) are quantized per tensor with
    ``ref_dynamic_per_tensor_fp8_quant``. Both paths receive the same
    quantized operands, scales, output ``dtype`` and optional (m,) bias.
    ``padded_a`` / ``padded_b`` swap in row-padded, non-contiguous views of
    the quantized operands via ``pad_fp8``.
    """
    torch.manual_seed(seed)

    xavier = math.sqrt(2 / k) if xnorm else 1  # normalize to avoid large deltas
    A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

    if padded_b:
        B = pad_fp8(B)
    if padded_a:
        A = pad_fp8(A)

    BIAS = None
    if biased:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
    )
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)

    if xnorm:
        torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
    elif k >= 32 * 1024:
        # Wider PyTorch tolerance for large K without xnorm normalization.
        torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2)
    else:
        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)