test_rocm_skinny_gemms.py 8.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import math

5
6
7
8
9
10
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
11
from vllm.platforms.rocm import on_gfx950
12
from vllm.utils.platform_utils import num_compute_units
13
14

# Activation/weight dtypes exercised by the unquantized kernels.
DTYPES = [torch.bfloat16, torch.float16]
# Bias variants used by the wvSplitK / wvSplitKrc tests:
#   0 -> no bias, 1 -> 1-D bias of shape (m,), 2 -> 2-D bias of shape (n, m).
BIAS_MODES = [0, 1, 2]

# Specific (N, K, M) combinations for targeted testing.
# A is (N, K), B is (M, K), and the reference output is A @ B.T of shape (N, M).
NKM_FACTORS_LLMM1 = [
    # Small, medium, large cases
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Edge cases with specific K sizes
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Very large case
    (1, 4096, 8192),
]

# (N, K, M) shapes for the unquantized wvSplitK kernel.
# A is (N, K), B is (M, K); the "+ 1" / "+ 16" variants step just past the
# base size in K and/or M.
NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 32, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    (4, 4096, 4096 + 1),
    (4, 4096 + 16, 4096),
    (4, 4096 + 16, 4096 + 1),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    (4, 16384 * 2, 8192),
    (4, 16384 * 2, 8192 + 1),
    (4, 16384 * 2 + 16, 8192),
    (4, 16384 * 2 + 16, 8192 + 1),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

# Sweep axes for wvSplitKrc. Each list is parametrized independently, so the
# test runs over their full Cartesian product. Shapes the kernel cannot
# handle are skipped inside the test.
N_FACTORS_WVSPLITKRC = [
    13,
    16,
    17,
    25,
    29,
    31,
    32,
    41,
    51,
    64,
    71,
    81,
    91,
    103,
    117,
    128,
]
# The "+ 8" / "+ 16" variants step just past each base size.
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]

# (N, K, M) shapes for the fp8 wvSplitKQ kernel.
# Every K and M here is a multiple of 16.
NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 32, 16 + 16),
    (1, 64, 64),
    (1, 64, 64 + 16),
    (1, 64 + 16, 64),
    (1, 64 + 16, 64 + 16),
    (4, 64, 64),
    (4, 64, 64 + 16),
    (4, 64 + 16, 64),
    (4, 64 + 16, 64 + 16),
    (2, 512, 512),
    (3, 512, 512),
    (3, 512, 512 + 16),
    (4, 512, 512),
    (3, 2048, 2048),
    (3, 2048, 2048 + 16),
    (4, 2048 + 16, 2048),
    (4, 2048 + 16, 2048 + 16),
    (4, 4096, 4096),
    (4, 16400, 2048),
    (4, 16400, 2048 + 16),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
    (4, 32768 * 2, 28672),
    (4, 32768 * 2, 28672 + 16),
    (4, 32768 * 2 + 16, 28672),
    (4, 32768 * 2 + 16, 28672 + 16),
]

# RNG seeds passed to torch.manual_seed in every test.
SEEDS = [0]


def pad_fp8(weight):
    """Return ``weight`` as a strided view whose rows carry 256 bytes of padding.

    The returned tensor has the same shape and values as ``weight``. It is
    sliced out of a zero-padded copy whose last dimension is
    ``256 // weight.element_size()`` elements wider. The row stride therefore
    exceeds the row length, which exercises the kernels' handling of
    non-contiguous inputs. Despite the name it works for any dtype; the tests
    in this file also apply it to fp16/bf16 tensors.

    Args:
        weight: tensor to pad along its last dimension.

    Returns:
        A view of the padded copy with the padding columns sliced off.
    """
    num_pad = 256 // weight.element_size()
    padded = torch.nn.functional.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
# skipif conditions are evaluated at collection time on every platform, so
# only query the GPU arch once we know we are on ROCm.
@pytest.mark.skipif(
    not (current_platform.is_rocm() and on_gfx950()),
    reason="only meant for gfx950",
)
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode):
    """Compare ``ops.wvSplitKrc`` against ``torch.nn.functional.linear``.

    ``A`` is (n, k) and ``B`` is (m, k), so the output is (n, m). Shapes that
    the sizing estimate below says exceed the kernel's capacity are skipped
    instead of failed.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Next power of two >= n.
    N_p2 = 1 << (n - 1).bit_length()
    # Number of 512-wide shards K is split into.
    k_shards = (k + 512 - 1) // 512
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * k_shards
    # How many of 4 waves in a group can work on same 16 Ms at same time?
    # This reduces the Ms each group works on, i.e. increases the number
    # of CUs needed.
    GrpsShrB = min(N_p2 // 16, 4)
    # Given the above, how many CUs would we need?
    CuNeeded = rndup_cus * GrpsShrB
    fits_wvsplitkrc = (N_p2 * m * k_shards) <= 128 * 1024 * 12 and (
        CuNeeded <= cu_count
    )
    if not fits_wvsplitkrc:
        pytest.skip("Too large for wvSplitKrc")

    # Normalize to avoid large output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    if padded_a:
        A = pad_fp8(A)

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    # Unlike wvSplitK, wvSplitKrc takes A first and B second.
    out = ops.wvSplitKrc(A, B, cu_count, BIAS)

    rtol = 1e-8 if xnorm else 1e-2
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=rtol)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Compare ``ops.LLMM1`` against ``A @ B.T``.

    ``A`` is (n, k) and ``B`` is (m, k), so the output is (n, m). LLMM1
    takes ``B`` first and ``A`` second.
    """
    torch.manual_seed(seed)
    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #       Without that the numbers quickly saturate, and may
    #       be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(
    xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
):
    """Compare ``ops.wvSplitK`` against ``torch.nn.functional.linear``.

    ``A`` is (n, k) and ``B`` is (m, k), so the output is (n, m). wvSplitK
    takes ``B`` first and ``A`` second. ``padded_a`` / ``padded_b`` feed
    strided (row-padded) views of the inputs.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    # Normalize to avoid large output-bias deltas.
    xavier = math.sqrt(2 / k) if xnorm else 1
    A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1

    if padded_a:
        A = pad_fp8(A)
    if padded_b:
        B = pad_fp8(B)

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    # assert_close (as in the sibling tests) reports the mismatching elements
    # instead of a bare AssertionError.
    rtol = 1e-8 if xnorm else 1e-2
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=rtol)


@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(
    xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
    """Compare ``ops.wvSplitKQ`` against ``torch._scaled_mm``.

    Both inputs are quantized per-tensor via
    ``ref_dynamic_per_tensor_fp8_quant``. The output dtype is ``dtype``, and
    an optional 1-D bias of shape (m,) is added.
    """
    torch.manual_seed(seed)
    cu_count = num_compute_units()

    xavier = math.sqrt(2 / k) if xnorm else 1  # normalize to avoid large deltas
    A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
    B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
    if padded_b:
        B = pad_fp8(B)
    if padded_a:
        A = pad_fp8(A)

    BIAS = (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1) if biased else None

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
    )
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, cu_count, BIAS)

    if xnorm:
        atol, rtol = 1e-3, 1e-8
    elif k >= 32 * 1024:
        # Wider PyTorch tolerance for large K without xnorm.
        atol, rtol = 0.07, 5e-2
    else:
        atol, rtol = 1e-2, 1e-2
    torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)