"""Tests for cutlass kernels

Run `pytest tests/kernels/test_cutlass.py`.
"""
5
from typing import Optional, Type
6
7
8
9

import pytest
import torch

10
from tests.kernels.utils import opcheck
11
from vllm import _custom_ops as ops
12
from vllm.platforms import current_platform
13
14
15
16
17
18

# Devices exercised by the multi-device tests: at most the first two CUDA
# devices, and never more than are actually visible. With no visible device
# the list is empty, so the device-parametrized tests get an empty parameter
# set instead of targeting a non-existent "cuda:0"/"cuda:1".
CUDA_DEVICES = [f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))]


19
def to_fp8(tensor: torch.Tensor):
20
21
22
23
24
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


25
def to_int8(tensor: torch.Tensor):
    """Quantize ``tensor`` to ``torch.int8``.

    Values are clamped to [-128, 127], rounded with ``torch.round`` and then
    cast to int8.
    """
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


29
30
31
32
def rand_int8(shape: tuple, device: str = "cuda"):
    """Return a random int8 tensor of the given ``shape`` on ``device``.

    Uniform samples from ``torch.rand`` are scaled by 255 and shifted by
    -128 before being quantized with ``to_int8``.
    """
    uniform = torch.rand(shape, device=device)
    shifted = uniform * 255 - 128
    return to_int8(shifted)


33
34
35
36
37
38
39
40
41
42
43
44
45
46
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference scaled matmul computed with plain PyTorch ops.

    Both operands are upcast to float32 and multiplied with ``torch.mm``;
    the product is scaled by ``scale_b`` and then ``scale_a`` (both applied
    by broadcasting) and cast to ``out_dtype``. When ``bias`` is given it is
    added after the cast.
    """
    lhs = a.to(dtype=torch.float32)
    rhs = b.to(dtype=torch.float32)
    product = torch.mm(lhs, rhs)
    scaled = (scale_a * (scale_b * product)).to(out_dtype)

    if bias is None:
        return scaled
    return scaled + bias


47
48
49
50
51
def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            use_bias: bool,
                            out_dtype: Type[torch.dtype] = torch.bfloat16,
                            device: str = "cuda"):
    """Compare ``ops.cutlass_scaled_mm`` on FP8 inputs against the baseline.

    Builds a random (m, k) activation and a (k, n) weight (the transpose of
    an (n, k) tensor), both quantized with ``to_fp8``.

    Args:
        m, n, k: GEMM problem sizes.
        per_token_act_quant: if True, ``scale_a`` has shape (m, 1);
            otherwise (1, 1).
        per_out_channel_weight_quant: if True, ``scale_b`` has shape (1, n);
            otherwise (1, 1).
        use_bias: if True, a random bias of shape (n,) in ``out_dtype`` is
            passed to both the kernel and the baseline.
        out_dtype: output dtype requested from the kernel and the baseline.
        device: device on which all tensors are created.

    Raises:
        AssertionError: if the kernel output is not close to the result of
            ``baseline_scaled_mm``.
    """
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    # A single row/column of scales means per-tensor quantization.
    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
76
77
78
79
80
81
82


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
                             use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
    """Compare ``ops.cutlass_scaled_mm`` on int8 inputs against the baseline.

    Builds a random (m, k) activation and a (k, n) weight (the transpose of
    an (n, k) tensor), both scaled by 5 and quantized with ``to_int8``.

    Args:
        m, n, k: GEMM problem sizes.
        per_token_act_quant: if True, ``scale_a`` has shape (m, 1);
            otherwise (1, 1).
        per_out_channel_weight_quant: if True, ``scale_b`` has shape (1, n);
            otherwise (1, 1).
        use_bias: if True, a random bias of shape (n,) in ``out_dtype`` is
            passed to both the kernel and the baseline.
        out_dtype: output dtype requested from the kernel and the baseline.
        device: device on which all tensors are created.

    Raises:
        AssertionError: if the kernel output is not close to the result of
            ``baseline_scaled_mm``.
    """
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    # A single row/column of scales means per-tensor quantization.
    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    # Also run the shared opcheck helper on the registered custom op.
    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))

112

113
114
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, use_bias: bool):
    """Check the FP8 scaled-mm kernel over a grid of problem shapes.

    Every combination of (m, n, k), scale granularity and bias usage is
    delegated to ``cutlass_fp8_gemm_helper`` with its default output dtype
    and device.
    """
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
124
125


126
127
# NOTE: the original "m" list ended with a second `1`, which only generated
# duplicate test cases; the duplicate has been dropped.
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33])
@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
    """Check the int8 scaled-mm kernel over a grid of problem shapes.

    Every combination of (m, n, k), scale granularity and bias usage is
    delegated to ``cutlass_int8_gemm_helper`` with its default output dtype
    and device.
    """
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
135
136
137
138
139


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
                                        use_bias: bool):
    """Check the int8 kernel on a 512x512x512 problem for each output dtype."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=out_dtype)
151
152
153
154
155


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                       out_dtype: Type[torch.dtype],
                                       use_bias: bool):
    """Check the FP8 kernel on a 512x512x512 problem for each output dtype."""
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            use_bias,
                            out_dtype=out_dtype)
169
170
171
172


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool, device: str):
    """Check the FP8 kernel on each device listed in ``CUDA_DEVICES``.

    Uses a fixed 512x512x512 problem with bfloat16 output.
    """
    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                            torch.bfloat16, device)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool, device: str):
    """Check the int8 kernel on each device listed in ``CUDA_DEVICES``.

    Uses a fixed 512x512x512 problem with bfloat16 output.
    """
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=torch.bfloat16,
                             device=device)
197
198
199
200
201
202
203
204
205


# For the following two tests:
# N and K correspond to the size of the weight matrix and are likely to be
# multiples of a large power of two. In any case, the kernel will have a naive
# fallback when N and K are not divisible by 16. But M is the number of tokens
# and the kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool):
    """Sweep every M in [1, 127] for the FP8 kernel.

    N and K are kept equal and take the values 32, 64 and 96.
    """
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                    use_bias)
215
216
217
218


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):
    """Sweep every M in [1, 127] for the int8 kernel.

    N and K are kept equal and take the values 32, 64 and 96.
    """
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     use_bias)
226
227


228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip(
    reason="Folding azp into a 16-bit bias loses too much precision.")
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                    out_dtype: torch.dtype):
    """Fold an activation zero point into the bias of ``cutlass_scaled_mm``.

    The zero-point contribution ``azp_a * scale_b * (1^T @ B)`` is cast to
    ``out_dtype`` and passed as the kernel's bias; the result is compared
    against a dequantized float baseline and an integer-matmul baseline.
    """
    # Currently, the test is failing because folding azp into
    # 16-bit bias loses too much precision
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    bq_i8 = rand_int8((n, k)).t()

    aq_i32 = aq_i8.to(dtype=torch.int32)
    bq_i32 = bq_i8.to(dtype=torch.int32)

    aq_f32 = aq_i8.to(dtype=torch.float32)
    bq_f32 = bq_i8.to(dtype=torch.float32)

    b_dq = scale_b * bq_f32

    azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

    # Multiplying by a row of ones sums each column of B.
    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    assert azp_bias[0, :].shape == (n, )

    # The integer matmul is carried out on the CPU and moved back to CUDA.
    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
        (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
            dtype=out_dtype, device='cuda')

    out = ops.cutlass_scaled_mm(aq_i8,
                                bq_i8,
                                scale_a,
                                scale_b,
                                out_dtype=out_dtype,
                                bias=azp_bias[0, :])
    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
    """Check ``ops.cutlass_scaled_mm_azp`` against two reference results.

    The kernel output is compared with a dequantized float32 matmul
    (``baseline_dq``) and with an int32 matmul that is scaled afterwards
    (``baseline_q``). With ``azp_per_token`` the activation scale and zero
    point have shape (m, 1); otherwise they have shape (1, 1).
    """
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    aq_i32 = aq_i8.to(dtype=torch.int32)
    aq_f32 = aq_i8.to(dtype=torch.float32)

    bq_i8 = rand_int8((n, k)).t()
    bq_i32 = bq_i8.to(dtype=torch.int32)
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    azp_a = torch.rand(
        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq,
                               scale_a * aq_f32 - azp_a,
                               rtol=1e-4,
                               atol=1e-3)

    # The baselines always add a bias tensor (zeros when use_bias is False);
    # the kernel itself receives None in that case (see func_bias below).
    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # Hadamard is just the sum of the cols
    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_aq_i8.to(dtype=torch.int32)
    func_bias = bias if use_bias else None

    if azp_per_token:
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_adj_i32, azp_i32,
                                        func_bias)
    else:
        # Per-tensor case: the zero point is pre-multiplied into the
        # adjustment term and no separate azp tensor is passed.
        azp_with_adj_i32 = azp_i32 * azp_adj_i32
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_with_adj_i32, None,
                                        func_bias)

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
    atol = 1e-3
    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)

    if azp_per_token:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
                 func_bias))
    else:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
                 func_bias))

354

355
356
357
358
359
360
361
362
363
364
365
366
367
# Test working with a subset of A and B
def test_cutlass_subset():
    """Run ``cutlass_scaled_mm`` on slices of larger int8 matrices.

    ``a`` and ``b`` are the top-left 512x512 slices of 1024x1024 tensors
    rather than freshly allocated tensors; the result is compared with
    ``baseline_scaled_mm`` on the same slices.
    """
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
380
381
382
383
384
385
386
387
388
389
390
391
392


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
    """Minimal module whose forward pass is one ``cutlass_scaled_mm`` call.

    The weight ``b``, both scale tensors and the output dtype are stored as
    plain attributes at construction time; only the activation ``a`` is
    supplied per call.
    """

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        """Return ``ops.cutlass_scaled_mm`` of ``a`` with the stored tensors."""
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)
395
396


397
398
399
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    """Capture the int8 kernel in a CUDA graph and verify the replay output.

    The output written during capture is zeroed before the graph is
    replayed, so the final comparison checks what the replay produced.
    """
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = model(a)
    # Clear the captured output so the check below depends on the replay.
    out.zero_()
    g.replay()

    # Reference: apply the scales to the float32 operands before multiplying.
    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)