"vllm/vscode:/vscode.git/clone" did not exist on "3a92c6f3b5f010453d81f871f642839c15402cda"
test_cutlass.py 20.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Tests for cutlass kernels

Run `pytest tests/kernels/test_cutlass.py`.
"""
zhuwenwen's avatar
zhuwenwen committed
6
from typing import Type, Optional
7
8
9
10

import pytest
import torch

11
from tests.kernels.utils import opcheck
12
from vllm import _custom_ops as ops
13
from vllm.platforms import current_platform
14
from vllm.utils import cdiv
15

16
from .utils import baseline_scaled_mm, to_fp8, to_int8
17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# (m, n, k) problem sizes fed to the "m,n,k" parametrization of the GEMM tests
# below. The helpers build A with shape (m, k) and B with shape (k, n).
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 496),
    (16, 256, 496),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 496),
    (64, 16384, 1024),
    (100, 8192, 496),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]
39
40

# Device strings used by the per-device tests: only "cuda:0" when exactly one
# CUDA device is reported, otherwise "cuda:0" and "cuda:1".
_num_test_devices = 1 if torch.cuda.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{index}" for index in range(_num_test_devices)]

44
45
46
47
48
# Scale group shapes, given as the (rows, cols) extent of the matrix that
# shares a single scale value.
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)  # one scale for the whole matrix
PER_TOKEN_GROUP_SHAPE = (1, -1)  # one scale per row
PER_OUT_CH_GROUP_SHAPE = (-1, 1)  # one scale per column

49
50
# Pack the device capability reported by the current platform into a single
# integer: first component * 10 + second component (presumably
# (major, minor), so e.g. (9, 0) becomes 90 -- TODO confirm).
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
51

52

53
54
55
56
def rand_int8(shape: tuple, device: str = "cuda"):
    """Return a random tensor of the given shape converted with `to_int8`.

    Samples are drawn with `torch.rand`, stretched by 255 and shifted by -128
    before the conversion.
    """
    uniform = torch.rand(shape, device=device)
    return to_int8(uniform * 255 - 128)


57
58
59
60
61
62
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference implementation of a scaled matrix multiply.

    Computes ``scale_a * (scale_b.T * (a @ b))`` in float32, casts the result
    to ``out_dtype`` and, if ``bias`` is given, adds it afterwards.

    Note: this local definition shadows the ``baseline_scaled_mm`` imported
    from ``.utils`` at the top of the file.

    Args:
        a: Left operand; converted to float32 before the matmul.
        b: Right operand; converted to float32 before the matmul.
        scale_a: Scale multiplied onto the product (broadcast by torch).
        scale_b: Scale that is transposed and multiplied onto the product.
        out_dtype: dtype the scaled product is cast to.
        bias: Optional tensor added to the cast result.

    Returns:
        The scaled (and optionally biased) product.
    """
    # NOTE(review): scale_b is transposed before broadcasting, so a
    # per-output-channel scale presumably has to arrive as (n, 1) -- confirm
    # against the callers that build it via scale_shape().
    product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scale_a * (scale_b.T * product)).to(out_dtype)
    if bias is not None:
        output = output + bias

    # The result must be returned; without this every caller compares the
    # kernel output against None.
    return output

68
69
70
71
72
73
74
75
76
def group_scale_helper(shape, group_shape):
    """Resolve negative entries of ``group_shape`` against ``shape``.

    Every negative group extent is replaced by the matching dimension of
    ``shape``; non-negative extents are kept. Returns a list.
    """
    resolved = []
    for axis, extent in enumerate(group_shape):
        if extent < 0:
            resolved.append(shape[axis])
        else:
            resolved.append(extent)
    return resolved


def scale_shape(shape, group_shape):
    """Return the shape of the scale tensor for a given group shape.

    Each dimension is the number of groups needed to cover ``shape`` along
    that axis, i.e. ``cdiv(shape[i], group_extent[i])`` after negative group
    extents have been resolved by ``group_scale_helper``.
    """
    assert len(shape) == len(group_shape)
    resolved = group_scale_helper(shape, group_shape)
    counts = []
    for axis, group_extent in enumerate(resolved):
        counts.append(cdiv(shape[axis], group_extent))
    return tuple(counts)
77
78


79
80
81
def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            a_scale_group_shape: tuple,
                            b_scale_group_shape: tuple,
                            use_bias: bool,
                            out_dtype: Type[torch.dtype] = torch.bfloat16,
                            device: str = "cuda"):
    """Check `ops.cutlass_scaled_mm` on FP8 inputs against the baseline.

    Builds random FP8 operands A (m, k) and B (k, n), random float32 scales
    whose shapes follow the requested group shapes, an optional bias, and
    compares the kernel result with `baseline_scaled_mm`. Finally runs
    `opcheck` on the underlying custom op.
    """
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    scale_a = torch.randn(scale_shape(a.shape, a_scale_group_shape),
                          device=device,
                          dtype=torch.float32)
    scale_b = torch.randn(scale_shape(b.shape, b_scale_group_shape),
                          device=device,
                          dtype=torch.float32)

    # Blockwise quantization wants M-major scales for A and K-major scales
    # for B; the round trip through transpose leaves 1D scales unaffected.
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    bias = None
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)

    op_args = (out, a, b, scale_a, scale_b, bias)
    opcheck(torch.ops._C.cutlass_scaled_mm, op_args)

116
117
118
119

def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             a_scale_group_shape: tuple,
                             b_scale_group_shape: tuple,
                             use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
    """Check `ops.cutlass_scaled_mm` on int8 inputs against the baseline.

    Builds random int8 operands A (m, k) and B (k, n), random float32 scales
    whose shapes follow the requested group shapes, an optional bias, and
    compares the kernel result with `baseline_scaled_mm`. Finally runs
    `opcheck` on the underlying custom op.
    """
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    scale_a = torch.randn(scale_shape(a.shape, a_scale_group_shape),
                          device=device,
                          dtype=torch.float32)
    scale_b = torch.randn(scale_shape(b.shape, b_scale_group_shape),
                          device=device,
                          dtype=torch.float32)

    bias = None
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    op_args = (out, a, b, scale_a, scale_b, bias)
    opcheck(torch.ops._C.cutlass_scaled_mm, op_args)

149

zhuwenwen's avatar
zhuwenwen committed
150
# @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
zhuwenwen's avatar
zhuwenwen committed
151
152
153
154
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
155
156
157
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
#                     reason="FP8 is not supported on this GPU type.")
zhuwenwen's avatar
zhuwenwen committed
158
159
160
161
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
#                           b_scale_group_shape, use_bias: bool):
#     cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
#                             use_bias)
162
163


164
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
                         [((1, 128), (128, 128))])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
                                          a_scale_group_shape,
                                          b_scale_group_shape, use_bias: bool):
    """Run the FP8 GEMM helper with blockwise scales.

    Problem sizes that are not an exact multiple of the scale group shapes
    are skipped (rather than silently reported as passed), so the test report
    shows which (m, n, k) combinations were actually exercised.
    """
    # B is (k, n): its group shape is (k-extent, n-extent).
    if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
        pytest.skip(f"(k, n)=({k}, {n}) is not divisible by the B scale "
                    f"group shape {b_scale_group_shape}")
    # A is (m, k): its group shape is (m-extent, k-extent).
    if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
        pytest.skip(f"(m, k)=({m}, {k}) is not divisible by the A scale "
                    f"group shape {a_scale_group_shape}")
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                            use_bias)
179
180


181
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
                           b_scale_group_shape, use_bias: bool):
    """Run the int8 GEMM helper over every (m, n, k) in MNK_FACTORS."""
    cutlass_int8_gemm_helper(
        m=m,
        n=n,
        k=k,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
    )
191
192


193
194
195
196
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
                                        b_scale_group_shape,
                                        out_dtype: Type[torch.dtype],
                                        use_bias: bool):
    """Run the int8 GEMM helper on a fixed 512^3 problem per output dtype."""
    m = n = k = 512
    cutlass_int8_gemm_helper(m,
                             n,
                             k,
                             a_scale_group_shape,
                             b_scale_group_shape,
                             use_bias,
                             out_dtype=out_dtype)
210
211


zhuwenwen's avatar
zhuwenwen committed
212
213
214
215
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
216
217
218
219
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
#                     reason="FP8 is not supported on this GPU type.")
zhuwenwen's avatar
zhuwenwen committed
220
221
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
#                                        b_scale_group_shape,
gaoqiong's avatar
gaoqiong committed
222
223
224
225
226
#                                        out_dtype: Type[torch.dtype],
#                                        use_bias: bool):
#     cutlass_fp8_gemm_helper(512,
#                             512,
#                             512,
zhuwenwen's avatar
zhuwenwen committed
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#                             a_scale_group_shape,
#                             b_scale_group_shape,
#                             use_bias,
#                             out_dtype=out_dtype)


# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
#                          [((1, 128), (128, 128))])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [False])
# @pytest.mark.skipif(not current_platform.has_device_capability(90),
#                     reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
#                                                 b_scale_group_shape,
#                                                 out_dtype: Type[torch.dtype],
#                                                 use_bias: bool):
#     cutlass_fp8_gemm_helper(512,
#                             512,
#                             512,
#                             a_scale_group_shape,
#                             b_scale_group_shape,
gaoqiong's avatar
gaoqiong committed
248
249
250
251
#                             use_bias,
#                             out_dtype=out_dtype)


zhuwenwen's avatar
zhuwenwen committed
252
253
254
255
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
256
257
258
259
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
#                     reason="FP8 is not supported on this GPU type.")
zhuwenwen's avatar
zhuwenwen committed
260
# def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
gaoqiong's avatar
gaoqiong committed
261
#                                   use_bias: bool, device: str):
zhuwenwen's avatar
zhuwenwen committed
262
263
264
#     cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
#                             b_scale_group_shape, use_bias, torch.bfloat16,
#                             device)
gaoqiong's avatar
gaoqiong committed
265
266


zhuwenwen's avatar
zhuwenwen committed
267
268
269
270
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
271
272
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
zhuwenwen's avatar
zhuwenwen committed
273
# NOTE: this test is disabled. Its parametrize decorators and `def` line were
# already commented out, but the rest of the signature and the body were left
# as live, indented code, which raises IndentationError when the module is
# imported. The whole definition is commented out here so the file stays
# importable.
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
#                                    use_bias: bool, device: str):
#     cutlass_int8_gemm_helper(512,
#                              512,
#                              512,
#                              a_scale_group_shape,
#                              b_scale_group_shape,
#                              use_bias,
#                              out_dtype=torch.bfloat16,
#                              device=device)
283
284
285
286
287
288
289
290



# For the following two tests:
# N and K correspond to the size of the weight matrix and are likely to be
# multiples of a large power of two. In any case, the kernel will have a naive
# fallback when N and K are not divisible by 16. But M is the number of tokens,
# and the kernel must handle any M thrown at it.
zhuwenwen's avatar
zhuwenwen committed
291
292
293
294
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
295
296
297
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
#                     reason="FP8 is not supported on this GPU type.")
zhuwenwen's avatar
zhuwenwen committed
298
# def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
gaoqiong's avatar
gaoqiong committed
299
300
301
#                                   use_bias: bool):
#     for nk in range(32, 128, 32):
#         for m in range(1, 128):
zhuwenwen's avatar
zhuwenwen committed
302
303
#             cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
#                                     b_scale_group_shape, use_bias)
gaoqiong's avatar
gaoqiong committed
304
305


zhuwenwen's avatar
zhuwenwen committed
306
307
308
309
# @pytest.mark.parametrize("a_scale_group_shape",
#                          [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
#                          [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
gaoqiong's avatar
gaoqiong committed
310
# @pytest.mark.parametrize("use_bias", [True, False])
zhuwenwen's avatar
zhuwenwen committed
311
# NOTE: this test is disabled. Its parametrize decorators and `def` line were
# already commented out, but the rest of the signature and the body were left
# as live, indented code, which raises IndentationError when the module is
# imported. The whole definition is commented out here so the file stays
# importable.
# def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
#                                    use_bias: bool):
#     for nk in range(32, 128, 32):
#         for m in range(1, 128):
#             cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
#                                      b_scale_group_shape, use_bias)
gaoqiong's avatar
gaoqiong committed
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373


# @pytest.mark.parametrize("m", [32, 64, 128])
# @pytest.mark.parametrize("n", [16, 32, 64])
# @pytest.mark.parametrize("k", [64, 128, 256])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.skip
# def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
#                                     out_dtype: torch.dtype):
#     # Currently, the test is failing because folding azp into
#     # 16-bit bias loses too much precision
#     scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
#     scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

#     aq_i8 = rand_int8((m, k))
#     bq_i8 = rand_int8((n, k)).t()

#     aq_i32 = aq_i8.to(dtype=torch.int32)
#     bq_i32 = bq_i8.to(dtype=torch.int32)

#     aq_f32 = aq_i8.to(dtype=torch.float32)
#     bq_f32 = bq_i8.to(dtype=torch.float32)

#     b_dq = scale_b * bq_f32

#     azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
#     azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
#     azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

#     a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
#     torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

#     baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

#     J = torch.ones((1, k), device="cuda", dtype=torch.float32)
#     azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
#     assert azp_bias.shape == (1, n)
#     assert azp_bias[0, :].shape == (n, )

#     baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
#         (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
#             dtype=out_dtype, device='cuda')

#     out = ops.cutlass_scaled_mm(aq_i8,
#                                 bq_i8,
#                                 scale_a,
#                                 scale_b,
#                                 out_dtype=out_dtype,
#                                 bias=azp_bias[0, :])
#     torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
#     torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)


# @pytest.mark.parametrize("m", [32, 64, 128])
# @pytest.mark.parametrize("n", [16, 32, 64])
# @pytest.mark.parametrize("k", [64, 128, 256])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
zhuwenwen's avatar
zhuwenwen committed
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("azp_per_token", [True, False])
# def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
    #                       use_bias: bool, azp_per_token: bool):
    # m_azp = m if azp_per_token else 1
    # scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    # scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    # aq_i8 = rand_int8((m, k))
    # aq_i32 = aq_i8.to(dtype=torch.int32)
    # aq_f32 = aq_i8.to(dtype=torch.float32)

    # bq_i8 = rand_int8((n, k)).t()
    # bq_i32 = bq_i8.to(dtype=torch.int32)
    # bq_f32 = bq_i8.to(dtype=torch.float32)
    # b_dq = scale_b * bq_f32

    # azp_a = torch.rand(
    #     (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    # azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    # azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    # a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    # torch.testing.assert_close(a_dq,
    #                            scale_a * aq_f32 - azp_a,
    #                            rtol=1e-4,
    #                            atol=1e-3)

    # if use_bias:
    #     bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    # else:
    #     bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
zhuwenwen's avatar
zhuwenwen committed
406

407

zhuwenwen's avatar
zhuwenwen committed
408
409
410
411
412
413
414
415
416
417
418
419
#     baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
#         (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
#             dtype=out_dtype, device='cuda')

#     out = ops.cutlass_scaled_mm(aq_i8,
#                                 bq_i8,
#                                 scale_a,
#                                 scale_b,
#                                 out_dtype=out_dtype,
#                                 bias=azp_bias[0, :])
#     torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
#     torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
420

421

422
423
424
425
426
427
428
429
430
# Test working with a subset of A and B
def test_cutlass_subset():
    """Run the int8 scaled-mm kernel on the top-left corner of larger A and B.

    A 512x512 block is taken from each 1024x1024 int8 matrix, and the kernel
    output is compared with `baseline_scaled_mm` using tensor-wide scales.
    """
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)

    # Copy the slices into contiguous memory: computing directly on a matrix
    # sub-block is not supported at the moment (lda would have to be
    # recomputed).
    a = whole_a[0:m, 0:k].contiguous().reshape(m, -1)
    b = whole_b[0:k, 0:n].contiguous().reshape(k, -1)

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
451
452
453
454
455
456
457
458
459
460
461
462
463


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
    """Minimal module whose forward pass is one `ops.cutlass_scaled_mm` call.

    The right-hand operand, both scales and the output dtype are fixed at
    construction time; only the left-hand operand is supplied per call.
    """

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        """Return the scaled product of `a` with the stored operand."""
        result = ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                       self.out_dtype)
        return result
466

gaoqiong's avatar
gaoqiong committed
467
468
469
# Currently only per-act-token + per-out-ch (fp16) is supported.
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_ch", [True])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    """Check that the CUTLASS scaled-mm kernel works under CUDA graph capture.

    A single-layer model is captured into a CUDA graph, its output buffer is
    zeroed, and the graph is replayed. The replayed output and a plain eager
    kernel call are both compared against a float32 reference matmul.
    """
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())
    # Materialize the transposed view as a contiguous (k, n) tensor.
    b = b.contiguous().reshape(k, -1)

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (n_b_scales, 1), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.float16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            graph_out = model(a)
    # Clear the captured output so that only a successful replay can fill it.
    graph_out.zero_()
    g.replay()

    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b.T * b.to(dtype=torch.float32)).to(torch.float16)

    # Verify the replayed graph output itself. Previously `out` was
    # overwritten by an eager kernel call before the assertion, so the graph
    # result was never checked.
    torch.testing.assert_close(graph_out, baseline, rtol=1e-1, atol=1e0)

    # Also keep the eager (non-graph) check that the test performed before.
    eager_out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16)
    torch.testing.assert_close(eager_out, baseline, rtol=1e-1, atol=1e0)