# "vllm/vscode:/vscode.git/clone" did not exist on "1ec978c209391286d4cee968426900e9a4d256a5"
# test_rocm_skinny_gemms.py 8.07 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import math

5
6
7
8
9
10
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
11
from vllm.platforms.rocm import on_gfx950
12
from vllm.utils.platform_utils import get_cu_count
13
14

# Activation/weight dtypes every non-FP8 kernel test is parametrized over
# (the FP8 tests use these as the output dtype instead).
DTYPES = [torch.bfloat16, torch.float16]
# Bias variants for test_rocm_wvsplitkrc_kernel:
# 0 = no bias, 1 = 1-D bias of length m, 2 = 2-D bias of shape (n, m).
BIAS_MODES = [0, 1, 2]

# Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1 = [
    # Small, medium, large cases
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Edge cases with specific K sizes
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Very large case
    (1, 4096, 8192),
]

NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 16, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

# (n, k, m) cases for wvSplitKrc: each batch size n is paired with both
# m = 128 and m = 640 at a fixed K of 2880. The nested comprehension yields
# the pairs in n-major order: (16, 2880, 128), (16, 2880, 640), (17, ...), ...
NKM_FACTORS_WVSPLITKRC = [
    (n, 2880, m)
    for n in (16, 17, 25, 31, 32, 40, 60, 64, 81, 98, 128)
    for m in (128, 640)
]

NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 64, 64),
    (2, 512, 512),
    (3, 2048, 2048),
    (4, 4096, 4096),
    # K % 16 == 0 but K % 32 != 0 (16400 = 16 * 1025).
    # NOTE(review): presumably targets a non-aligned-K path; confirm intent.
    (4, 16400, 2048),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
]

# Seeds passed to torch.manual_seed at the start of every test.
SEEDS = [0]


def pad_weights_fp8(weight):
    """Return ``weight`` as a view into a buffer with padded rows.

    The last dimension is zero-padded by ``256 // weight.element_size()``
    elements (256 bytes for element sizes that divide 256) and then sliced
    back to its original width. Values and shape are unchanged, but the
    result's row stride is larger than its width, so it is non-contiguous.
    This lets the tests exercise kernels on weights that are not densely packed.
    """
    import torch.nn.functional as F

    num_pad = 256 // weight.element_size()
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    # Slice back to the original width; the padding stays in the storage.
    return padded[..., : weight.size(-1)]


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
    """Compare ops.wvSplitKrc against torch.nn.functional.linear.

    ``bias_mode`` selects the bias: 0 -> none, 1 -> shape (m,),
    2 -> shape (n, m). Only runs on ROCm gfx950 devices.
    """
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier

    BIAS = None
    if bias_mode == 1:
        BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
    elif bias_mode == 2:
        BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    """Compare ops.LLMM1 against torch.matmul(A, B.t()) for each rows_per_block."""
    torch.manual_seed(seed)
    # TODO: Zero-centering the inputs causes errors for LLMM1!
    #      Without that the numbers quickly saturate, and may
    #      be giving false matches.
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
    """Compare bias-free ops.wvSplitK against torch.nn.functional.linear."""
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    # Zero-centered inputs in [-0.5, 0.5).
    A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
    B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
    """Compare ops.wvSplitK with a 1-D bias of shape (m,) against F.linear."""
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
    """Compare ops.wvSplitK with a 2-D bias of shape (n, m) against F.linear."""
    torch.manual_seed(seed)
    cu_count = get_cu_count()

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5

    ref_out = torch.nn.functional.linear(A, B, BIAS)
    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
    """Compare ops.wvSplitKQ against torch._scaled_mm on FP8-quantized inputs.

    A (n, k) and B (m, k) are drawn with torch's default dtype, zero-centered,
    and quantized per-tensor with ref_dynamic_per_tensor_fp8_quant. When
    ``padded`` is True, B is re-strided via pad_weights_fp8 before both GEMMs.
    ``dtype`` is the output dtype of both computations.
    """
    torch.manual_seed(seed)

    A = torch.rand(n, k, device="cuda") - 0.5
    B = torch.rand(m, k, device="cuda") - 0.5

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
    if padded:
        B = pad_weights_fp8(B)

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
    )

    cu_count = get_cu_count()
    out = ops.wvSplitKQ(
        B,
        A,
        dtype,
        scale_a,
        scale_b,
        cu_count,
    )

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
    not (current_platform.is_rocm() and current_platform.supports_fp8()),
    reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
    """Compare ops.wvSplitKQ with a 1-D bias against torch._scaled_mm.

    Inputs are scaled by sqrt(2 / k) before per-tensor FP8 quantization.
    The bias has shape (m,) and the output dtype. When ``padded`` is True,
    B is re-strided via pad_weights_fp8 before both GEMMs.
    """
    torch.manual_seed(seed)

    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
    A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
    B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
    if padded:
        B = pad_weights_fp8(B)

    ref_out = torch._scaled_mm(
        A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
    )
    out = ops.wvSplitKQ(
        B,
        A,
        dtype,
        scale_a,
        scale_b,
        get_cu_count(),
        BIAS,
    )

    assert torch.allclose(out, ref_out, rtol=0.01)