test_cutlass_scaled_mm.py 22.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for cutlass kernels

5
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
6
"""
7

8
import random
9
10
11
12

import pytest
import torch

13
from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
14
from vllm import _custom_ops as ops
15
from vllm.platforms import current_platform
16
from vllm.utils import cdiv
17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# (M, N, K) GEMM problem sizes shared by the parametrized shape tests.
# M is the token dimension; several entries deliberately use odd M values
# (33, 100) or K values that are not powers of two (496).
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 496),
    (16, 256, 496),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 496),
    (64, 16384, 1024),
    (100, 8192, 496),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]

# Device-placement tests run on cuda:0, and additionally on cuda:1 unless
# exactly one GPU is visible.
# NOTE(review): with zero visible GPUs this still lists two devices.
CUDA_DEVICES = (
    ["cuda:0"] if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"]
)

# Scale group shapes; -1 means "the full extent of that dimension".
TENSORWISE_GROUP_SHAPE, PER_TOKEN_GROUP_SHAPE, PER_OUT_CH_GROUP_SHAPE = (
    (-1, -1),
    (1, -1),
    (-1, 1),
)

# Device compute capability encoded as major * 10 + minor (e.g. 9.0 -> 90).
_raw_capability = current_platform.get_device_capability()
capability = _raw_capability[0] * 10 + _raw_capability[1]
del _raw_capability

50

51
52
53
54
def rand_int8(shape: tuple, device: str = "cuda"):
    """Return a random int8 tensor of ``shape`` on ``device``.

    Uniform samples in [0, 1) are stretched to [-128, 127) and passed through
    ``to_int8``.
    """
    uniform = torch.rand(shape, device=device)
    return to_int8(-128 + 255 * uniform)


55
56
57
58
59
60
61
def group_scale_helper(shape, group_shape):
    """Resolve ``group_shape`` against the tensor ``shape``.

    A negative entry in ``group_shape`` stands for "the whole dimension" and
    is replaced by the matching entry of ``shape``; non-negative entries are
    kept as-is. The result is returned as a list.
    """
    resolved = []
    for dim in range(len(group_shape)):
        group = group_shape[dim]
        resolved.append(group if group >= 0 else shape[dim])
    return resolved


def scale_shape(shape, group_shape):
    """Return how many scale groups tile each dimension of ``shape``.

    ``group_shape`` is first resolved with ``group_scale_helper`` (negative
    entries span the whole dimension). Each dimension is then divided by its
    group size with ceiling division (``cdiv``), so a partial trailing group
    still gets its own scale.
    """
    assert len(shape) == len(group_shape)
    resolved = group_scale_helper(shape, group_shape)
    return tuple(cdiv(extent, size) for extent, size in zip(shape, resolved))


def cutlass_fp8_gemm_helper(
    m: int,
    n: int,
    k: int,
    a_scale_group_shape: tuple,
    b_scale_group_shape: tuple,
    use_bias: bool,
    out_dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
):
    """Check FP8 ``ops.cutlass_scaled_mm`` against ``baseline_scaled_mm``.

    A is an (m, k) FP8 matrix. B is a (k, n) FP8 matrix built from the
    transpose of an (n, k) tensor. The shapes of the random float32 scale
    tensors are derived from the given group shapes via ``scale_shape``. The
    same helper therefore covers tensor-wise, per-token / per-output-channel
    and blockwise scaling. The op is also run through ``opcheck``.

    Args:
        m, n, k: GEMM problem size.
        a_scale_group_shape: Group shape for A's scales (-1 = full extent).
        b_scale_group_shape: Group shape for B's scales (-1 = full extent).
        use_bias: Whether to add a random length-``n`` bias.
        out_dtype: Output dtype passed to both the kernel and the baseline.
        device: Device on which all tensors are created.
    """
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
    scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

    # make scales M-major for blockwise quant, doesn't affect 1D scales
    scale_a = scale_a.t().contiguous().t()
    # make scales K-major for blockwise quant, doesn't affect 1D scales
    scale_b = scale_b.t().contiguous().t()

    if use_bias:
        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)

    opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
102

103

104
105
106
107
108
109
110
111
112
113
def cutlass_int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    a_scale_group_shape: tuple,
    b_scale_group_shape: tuple,
    use_bias: bool,
    out_dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
):
    """Check int8 ``ops.cutlass_scaled_mm`` against ``baseline_scaled_mm``.

    A is an (m, k) int8 matrix. B is a (k, n) int8 matrix built from the
    transpose of an (n, k) tensor. Both come from standard-normal samples
    scaled by 5 before ``to_int8``. The shapes of the random float32 scale
    tensors are derived from the given group shapes via ``scale_shape``. The
    op is also run through ``opcheck``.

    Args:
        m, n, k: GEMM problem size.
        a_scale_group_shape: Group shape for A's scales (-1 = full extent).
        b_scale_group_shape: Group shape for B's scales (-1 = full extent).
        use_bias: Whether to add a random length-``n`` bias.
        out_dtype: Output dtype passed to both the kernel and the baseline.
        device: Device on which all tensors are created.
    """
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
    scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

    if use_bias:
        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
136

137

138
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    """FP8 scaled mm over MNK_FACTORS with tensor-wise, per-token and
    per-output-channel scale layouts, with and without bias."""
    cutlass_fp8_gemm_helper(
        m,
        n,
        k,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
    )
154
155
156


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize(
    "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    """FP8 scaled mm with blockwise scales over MNK_FACTORS.

    Problem sizes that do not tile evenly into the scale groups, and (on
    SM100+) sizes whose M is not a multiple of 4, are reported as skipped
    rather than silently passing.
    """
    if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
        pytest.skip(
            f"k={k}, n={n} not divisible by B scale group shape "
            f"{b_scale_group_shape}"
        )
    if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
        pytest.skip(
            f"m={m}, k={k} not divisible by A scale group shape "
            f"{a_scale_group_shape}"
        )
    if m % 4 != 0 and current_platform.has_device_capability(100):
        pytest.skip(f"m={m} is not a multiple of 4; skipped on SM100+")
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
175
176


177
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    """Int8 scaled mm over MNK_FACTORS with tensor-wise, per-token and
    per-output-channel scale layouts, with and without bias."""
    cutlass_int8_gemm_helper(
        m,
        n,
        k,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
    )


@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    """Int8 scaled mm on a fixed 512 x 512 x 512 problem for each output dtype."""
    problem_size = 512
    cutlass_int8_gemm_helper(
        problem_size,
        problem_size,
        problem_size,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_output_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    """FP8 scaled mm on a fixed 512 x 512 x 512 problem for each output dtype."""
    problem_size = 512
    cutlass_fp8_gemm_helper(
        problem_size,
        problem_size,
        problem_size,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize(
    ("a_scale_group_shape", "b_scale_group_shape"),
    [((1, 128), (128, 128))],
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    """Blockwise-scaled FP8 mm on a 512 x 512 x 512 problem per output dtype."""
    problem_size = 512
    cutlass_fp8_gemm_helper(
        problem_size,
        problem_size,
        problem_size,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_devices(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
    """FP8 scaled mm with all tensors placed on each device in CUDA_DEVICES."""
    problem_size = 512
    cutlass_fp8_gemm_helper(
        problem_size,
        problem_size,
        problem_size,
        a_scale_group_shape=a_scale_group_shape,
        b_scale_group_shape=b_scale_group_shape,
        use_bias=use_bias,
        out_dtype=torch.bfloat16,
        device=device,
    )


@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
    """Int8 scaled mm with all tensors placed on each device in CUDA_DEVICES."""
    problem_size = 512
    cutlass_int8_gemm_helper(
        problem_size,
        problem_size,
        problem_size,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        torch.bfloat16,
        device,
    )
321
322
323
324
325
326
327


# For the following two tests:
# N and K correspond to the dimensions of the weight matrix and are likely to
# be multiples of a large power of two. In any case, the kernel will have a
# naive fallback when N and K are not divisible by 16. M, however, is the
# number of tokens, and the kernel must handle any M thrown at it.
328
329
330
331
332
333
@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_m_sweep(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    """FP8 scaled mm for every M in 1..127 on square N = K in {32, 64, 96}."""
    square_sizes = range(32, 128, 32)
    token_counts = range(1, 128)
    for size in square_sizes:
        for num_tokens in token_counts:
            cutlass_fp8_gemm_helper(
                num_tokens,
                size,
                size,
                a_scale_group_shape,
                b_scale_group_shape,
                use_bias,
            )
347
348


349
350
351
352
353
354
@pytest.mark.parametrize(
    "a_scale_group_shape",
    [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize(
    "b_scale_group_shape",
    [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE],
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    """Int8 scaled mm for every M in 1..127 on square N = K in {32, 64, 96}."""
    square_sizes = range(32, 128, 32)
    token_counts = range(1, 128)
    for size in square_sizes:
        for num_tokens in token_counts:
            cutlass_int8_gemm_helper(
                num_tokens,
                size,
                size,
                a_scale_group_shape,
                b_scale_group_shape,
                use_bias,
            )
364
365


366
367
368
369
370
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip(
    reason="Folding azp into a 16-bit bias loses too much precision."
)
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
    """Fold a per-tensor activation zero point into the bias of a plain
    ``cutlass_scaled_mm`` call and compare against float and int references.

    Disabled: the folded bias is stored in ``out_dtype`` (16-bit), which loses
    too much precision for the comparison to pass.
    """
    # Currently, the test is failing because folding azp into
    # 16-bit bias loses too much precision
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    bq_i8 = rand_int8((n, k)).t()

    aq_i32 = aq_i8.to(dtype=torch.int32)
    bq_i32 = bq_i8.to(dtype=torch.int32)

    aq_f32 = aq_i8.to(dtype=torch.float32)
    bq_f32 = bq_i8.to(dtype=torch.float32)

    b_dq = scale_b * bq_f32

    azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

    baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

    # azp_a * scale_b * (column sums of B), expressed as a ones-row matmul.
    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    assert azp_bias[0, :].shape == (n,)

    # Integer reference computed on CPU.
    baseline_q = (
        scale_a.to(device="cpu")
        * scale_b.to(device="cpu")
        * ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
    ).to(dtype=out_dtype, device="cuda")

    out = ops.cutlass_scaled_mm(
        aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
    )
    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
413
414
415
416
417
418
419
420


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(
    m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
):
    """Check ``ops.cutlass_scaled_mm_azp`` (int8 GEMM with an activation zero
    point) against a dequantized float reference and an integer reference.

    With ``azp_per_token`` the kernel receives a per-row zero point together
    with the column sums of B. Otherwise it receives their precomputed product
    and no separate zero point.
    """
    azp_rows = m if azp_per_token else 1
    scale_a = torch.randn((azp_rows, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    a_i8 = rand_int8((m, k))
    a_i32 = a_i8.to(dtype=torch.int32)
    a_f32 = a_i8.to(dtype=torch.float32)

    b_i8 = rand_int8((n, k)).t()
    b_i32 = b_i8.to(dtype=torch.int32)
    b_f32 = b_i8.to(dtype=torch.float32)
    b_dq = scale_b * b_f32

    # Draw a float zero point, quantize it, then rebuild the float value from
    # the quantized one so every reference sees the same rounded zero point.
    azp_float = (
        torch.rand((azp_rows, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    )
    azp_i8 = (azp_float / scale_a).to(dtype=torch.int8)
    azp_float = azp_i8.to(dtype=torch.float32) * scale_a

    # A with the zero point removed, still in int32.
    a_shifted_i32 = a_i32 - azp_i8
    a_dq = scale_a * a_shifted_i32.to(dtype=torch.float32)
    torch.testing.assert_close(
        a_dq, scale_a * a_f32 - azp_float, rtol=1e-4, atol=1e-3
    )

    bias_shape = (1, n)
    if use_bias:
        bias = torch.rand(bias_shape, device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros(bias_shape, device="cuda", dtype=out_dtype)

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA, so the integer product is done on CPU.
    int_product = (a_shifted_i32.to(device="cpu") @ b_i32.to(device="cpu")).to(
        device="cuda"
    )
    baseline_q = (scale_a * scale_b * int_product + bias).to(dtype=out_dtype)

    # Zero point times an all-ones row of A reduces to the column sums of B.
    azp_adj_i32 = b_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_i8.to(dtype=torch.int32)
    kernel_bias = None
    if use_bias:
        kernel_bias = bias

    if azp_per_token:
        azp_adj_arg, azp_arg = azp_adj_i32, azp_i32
    else:
        # Per-tensor zero point: fold it into the adjustment term up front.
        azp_adj_arg, azp_arg = azp_i32 * azp_adj_i32, None

    out = ops.cutlass_scaled_mm_azp(
        a_i8, b_i8, scale_a, scale_b, out_dtype, azp_adj_arg, azp_arg, kernel_bias
    )

    if out_dtype == torch.bfloat16:
        # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
        rtol = 1e-2
    else:
        # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
        rtol = 1e-3
    atol = 1e-3
    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)

    opcheck(
        torch.ops._C.cutlass_scaled_mm_azp,
        (out, a_i8, b_i8, scale_a, scale_b, azp_adj_arg, azp_arg, kernel_bias),
    )
488

489

490
491
492
493
494
495
496
497
498
499
500
501
502
# Test working with a subset of A and B
def test_cutlass_subset():
    """Run int8 ``cutlass_scaled_mm`` on 512 x 512 slices (views) taken from
    larger 1024 x 1024 operands and compare with ``baseline_scaled_mm``."""
    full_m = full_n = full_k = 1024
    m = n = k = 512

    whole_a = to_int8(torch.randn((full_m, full_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((full_n, full_k), device="cuda").t() * 5)
    a = whole_a[:m, :k]
    b = whole_b[:k, :n]

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    result_dtype = torch.bfloat16
    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=result_dtype)
    expected = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=result_dtype)

    torch.testing.assert_close(out, expected, rtol=1e-1, atol=1e0)
507
508
509
510
511
512
513
514
515
516
517
518


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
    """Minimal module wrapping one ``ops.cutlass_scaled_mm`` call.

    Used by the CUDA-graph test to capture the CUTLASS kernel inside a model.
    """

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        # Right-hand operand and scales, fixed for every forward call.
        self.b, self.scale_a, self.scale_b = b, scale_a, scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        """Return ``cutlass_scaled_mm(a, self.b, ...)`` in ``self.out_dtype``."""
        operands = (a, self.b, self.scale_a, self.scale_b)
        return ops.cutlass_scaled_mm(*operands, self.out_dtype)
522
523


524
525
526
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    """Capture ``cutlass_scaled_mm`` in a CUDA graph, replay it, and compare
    the replayed output with a float32 reference matmul."""
    m = n = k = 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    # One scale per row of A / column of B when enabled, else a single scale.
    scale_a_rows = m if per_act_token else 1
    scale_b_cols = n if per_out_ch else 1
    scale_a = torch.randn((scale_a_rows, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, scale_b_cols), device="cuda", dtype=torch.float32) / 10

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    layer = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out = layer(a)
    # Clear the captured output so the check below only passes if the replay
    # actually writes it again.
    out.zero_()
    graph.replay()

    a_ref = scale_a * a.to(dtype=torch.float32)
    b_ref = scale_b * b.to(dtype=torch.float32)
    expected = torch.mm(a_ref, b_ref).to(torch.bfloat16)
    torch.testing.assert_close(out, expected, rtol=1e-1, atol=1e0)
554
555
556


def test_cutlass_support_opcheck():
    """Run ``opcheck`` on the FP8-support query op with this device's
    encoded compute capability."""
    op_args = (capability,)
    opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, op_args)
558
559
560
561
562
563
564
565


@pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_fp8_group_gemm(
    num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
    """Check grouped FP8 ``ops.cutlass_moe_mm`` against per-group baselines.

    Each of ``num_experts`` groups gets its own random M (a multiple of 16),
    while N and K are shared by all groups. The A operands, A scales and
    outputs are stacked along M. The B operands and B scales are stacked per
    expert. Each output slice is compared with ``baseline_scaled_mm`` for that
    group. ``use_bias`` is only parametrized with False and is not passed to
    the kernel.
    """
    # Device and dtype setup
    device = "cuda"
    out_dtype = torch.half

    # Create separate A, B, C tensors for each group
    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []

    # Row offset of each group inside the stacked A / output tensors. Kept
    # int32 like problem_sizes.
    # NOTE(review): int32 is assumed to be the dtype cutlass_moe_mm reads
    # expert offsets as (an int64 buffer would be misread) -- confirm
    # against the kernel.
    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)

    # One (m, n, k) row per group.
    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)

    if not per_act_token:
        # A single scalar A scale shared by every group.
        one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)

    alignment = 16  # 128 // 8
    # N and K are shared by all groups; only M varies per group.
    n_g = alignment * random.randint(1, 64)
    k_g = alignment * random.randint(1, 64)
    n_b_scales = n_g if per_out_ch else 1
    for g in range(num_experts):
        m_g = alignment * random.randint(1, 64)

        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][0] = m_g
        problem_sizes[g][1] = n_g
        problem_sizes[g][2] = k_g

        m_a_scales = m_g if per_act_token else 1

        # Create group-specific A (m_g, k_g) and B (k_g, n_g) in FP8
        a_g = to_fp8(torch.randn((m_g, k_g), device=device))
        b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)

        # Set up A/B scales
        scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
        b_scales_tensors.append(scale_b)

        if per_act_token:
            scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
            a_scales_tensors.append(scale_a)
        else:
            scale_a = one_scale_a

        # Compute baseline result for this group
        baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
        baseline_tensors.append(baseline_g)

    a_tensors_stacked = torch.empty(
        (expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_tensors_stacked = torch.empty(
        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )

    for g in range(num_experts):
        a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
        b_tensors_stacked[g] = b_tensors[g].t()
    # View each expert's B as (k_g, n_g), the same shape as b_tensors[g].
    b_tensors_stacked = b_tensors_stacked.transpose(1, 2)

    if per_act_token:
        a_scales_tensors_stacked = torch.empty(
            (expert_offsets[num_experts], 1), device=device, dtype=torch.float32
        )
        for g in range(num_experts):
            a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
                a_scales_tensors[g]
            )
    else:
        a_scales_tensors_stacked = one_scale_a

    b_scales_tensors_stacked = torch.empty(
        (num_experts, n_b_scales), device=device, dtype=torch.float32
    )
    for g in range(num_experts):
        b_scales_tensors_stacked[g] = b_scales_tensors[g]

    out_tensors_stacked = torch.zeros(
        (expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
    )

    # ab_strides (A's row stride) is passed for both the A and B stride
    # arguments below; c_strides is the output's row stride.
    ab_strides = torch.full(
        (num_experts,), a_tensors_stacked.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_experts,), out_tensors_stacked.stride(0), device=device, dtype=torch.int64
    )

    ops.cutlass_moe_mm(
        out_tensors_stacked,
        a_tensors_stacked,
        b_tensors_stacked,
        a_scales_tensors_stacked,
        b_scales_tensors_stacked,
        expert_offsets[:-1],
        problem_sizes,
        ab_strides,
        ab_strides,
        c_strides,
        per_act_token,
        per_out_ch,
    )

    # Validate each group's result against the baseline
    for g in range(num_experts):
        baseline = baseline_tensors[g]
        c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
        torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)