# test_marlin_gemm.py (17.5 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for the marlin kernel.

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""
7

8
9
import itertools

10
11
12
import pytest
import torch

13
from tests.kernels.utils import opcheck
14
from tests.quantization.utils import is_quant_method_supported
15
from vllm import _custom_ops as ops
16
17
18
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
19
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
20
21
22
23
24
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    query_marlin_supported_quant_types,
)
25
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
26
27
28
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
29
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
30
31
    marlin_quant_fp8_torch,
)
32
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
33
34
35
36
37
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
38
from vllm.model_executor.layers.quantization.utils.quant_utils import (
39
40
41
42
43
44
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
45
from vllm.platforms import current_platform
46
from vllm.scalar_type import scalar_types
47

48
49
# Marlin is not supported on ROCm, so skip this whole module there.
if current_platform.is_rocm():
    pytest.skip(
        "These tests require marlin, which is not supported on ROCm.",
        allow_module_level=True,
    )

# Kernel-option sweeps used when enumerating test cases.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes; the tested problem sizes are integer multiples of these.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

# (n_factor, k_factor) for the repack tests:
#   size_n = n_chunk * n_factor, size_k = k_chunk * k_factor.
MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

# (size_m, n_factor, k_factor) for the GEMM tests:
#   size_n = n_factor * n_chunk, size_k = k_factor * k_chunk.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]

# One entry per weight format exercised by test_marlin_gemm. Keys:
#   b_type            -- quantized weight type (required).
#   group_blocks      -- candidate group sizes; values > 0 are multiplied by
#                        16 to get group_size, -1 is passed through as-is.
#   a_type            -- optional list of activation types
#                        (default: float16 and bfloat16).
#   c_type            -- optional list of output types (same default).
#   support_act_order -- optional; act_order cases are generated only if True.
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    dict(b_type=scalar_types.uint4, group_blocks=[-1, 2, 4, 8]),
    # GPTQ-INT4
    dict(
        b_type=scalar_types.uint4b8,
        support_act_order=True,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT8
    dict(
        b_type=scalar_types.uint8b128,
        support_act_order=True,
        group_blocks=[-1, 2, 4, 8],
    ),
    # FP8
    dict(b_type=scalar_types.float8_e4m3fn, group_blocks=[-1, 8]),
    # NVFP4
    dict(b_type=scalar_types.float4_e2m1f, group_blocks=[1]),
    # MXFP4
    dict(
        a_type=[scalar_types.bfloat16],
        b_type=scalar_types.float4_e2m1f,
        group_blocks=[2],
    ),
    # AWQ-INT4 with INT8 activation
    dict(
        a_type=[scalar_types.int8],
        b_type=scalar_types.uint4,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with INT8 activation
    dict(
        a_type=[scalar_types.int8],
        b_type=scalar_types.uint4b8,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.uint4b8,
        group_blocks=[-1, 2, 4, 8],
    ),
    # AWQ-INT4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.uint4,
        group_blocks=[-1, 2, 4, 8],
    ),
    # MXFP4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.float4_e2m1f,
        c_type=[scalar_types.bfloat16],
        group_blocks=[2],
    ),
]

135

136
137
def compute_max_diff(output, output_ref):
    """Return the relative mean absolute error of ``output`` vs ``output_ref``.

    Despite its name this is not a maximum: it computes
    ``mean(|output - output_ref|) / mean(|output_ref|)`` and returns it as a
    0-d tensor.
    """
    abs_err = (output - output_ref).abs()
    ref_magnitude = output_ref.abs()
    return abs_err.mean() / ref_magnitude.mean()
140
141


142
143
def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples from ``torch.randn``.

    Args:
        shape: Output size, in any form ``torch.randn`` accepts.
        dtype: Element dtype; defaults to ``torch.float16``.
        device: Target device; defaults to ``"cuda"`` (the Marlin tests in
            this file run on GPU). Overridable for CPU-side use.
    """
    return torch.randn(shape, dtype=dtype, device=device)
144
145


146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    """Compare ``marlin_int4_fp8_preprocess`` (no zero points) to torch.

    Random 4-bit values are packed two per byte and reinterpreted as int32,
    then fed to the op. The expected result is the same packing applied to
    the reference remap ``v - 8`` for ``v >= 8`` and ``15 - v`` otherwise.
    """

    def pack_nibbles(vals):
        # Even columns go to the high nibble (x16), odd columns to the low
        # nibble; the bytes are narrowed to int8 and viewed as int32.
        high, low = vals[:, 0::2], vals[:, 1::2]
        return (high * 16 + low).to(torch.int8).view(torch.int32)

    nibbles = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )

    actual = ops.marlin_int4_fp8_preprocess(pack_nibbles(nibbles))

    remapped = torch.where(nibbles >= 8, nibbles - 8, 15 - nibbles)
    expected = pack_nibbles(remapped)

    assert (actual == expected).all()


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    """Compare ``marlin_int4_fp8_preprocess`` with zero points to torch.

    Weights and per-group zero points (one zero row per ``group_size``
    weight rows) are random 4-bit values packed two per byte. The reference
    subtracts the expanded zero points and, where that goes negative,
    substitutes ``15 - weight`` before packing the same way.
    """
    group_size = 128
    num_rows = num_cols = 2048

    def pack_nibbles(vals):
        # Even columns go to the high nibble (x16), odd columns to the low
        # nibble; the bytes are narrowed to int8 and viewed as int32.
        high, low = vals[:, 0::2], vals[:, 1::2]
        return (high * 16 + low).to(torch.int8).view(torch.int32)

    weights = torch.randint(
        0, 16, size=(num_rows, num_cols), dtype=torch.int32, device="cuda"
    )
    zeros = torch.randint(
        0,
        16,
        size=(num_rows // group_size, num_cols),
        dtype=torch.int32,
        device="cuda",
    )

    actual = ops.marlin_int4_fp8_preprocess(
        pack_nibbles(weights), pack_nibbles(zeros)
    )

    # Repeat each zero-point row group_size times to align with the weights.
    diff = weights - zeros.repeat_interleave(group_size, 0)
    negative = diff < 0
    diff[negative] = 15 - weights[negative]
    expected = pack_nibbles(diff)

    assert (actual == expected).all()


198
199
200
201
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    """Check the GPTQ->Marlin repack kernel against the reference packer.

    A random weight is quantized (optionally with act_order), packed in GPTQ
    layout, and repacked by ``ops.gptq_marlin_repack``. The result must equal
    ``marlin_weights`` applied to the quantized values (sorted by group for
    act_order). Unsupported combinations are reported as skipped instead of
    returning early and being counted as passed.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # act_order is not exercised with group_size -1 / a single group, nor
    # together with 8-bit activations.
    if act_order:
        if group_size in (-1, size_k):
            pytest.skip("act_order requires 0 < group_size < size_k")
        if is_a_8bit:
            pytest.skip("act_order is not tested with 8-bit activations")

    # Normalize group_size: -1 becomes one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; the kernel receives the matching sort_indices.
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reference: pack the (sorted) quantized values directly into Marlin format.
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    repack_args = (
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
        is_a_8bit,
    )
    opcheck(torch.ops._C.gptq_marlin_repack, repack_args)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(*repack_args)
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
266
267


268
269
270
271
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    """Check the AWQ->Marlin repack kernel against the reference packer.

    A random weight is quantized with zero points and packed in AWQ layout;
    ``ops.awq_marlin_repack`` must reproduce what ``marlin_weights`` builds
    from the unpacked quantized values.
    """
    n_factor, k_factor = nk_factors
    size_k, size_n = k_chunk * k_factor, n_chunk * n_factor
    group_size = 128
    num_bits = quant_type.size_bits

    b_weight = rand_data((size_k, size_n))
    # Only the quantized values are needed; ref weight, scales, zps unused.
    _, q_w, _, _ = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

    # Reference: permute the unpacked values straight into Marlin layout.
    expected = marlin_weights(
        q_w,
        size_k,
        size_n,
        num_bits,
        get_weight_perm(num_bits, is_a_8bit),
        is_a_8bit,
    )

    repack_args = (q_w_awq, size_k, size_n, num_bits, is_a_8bit)
    opcheck(torch.ops._C.awq_marlin_repack, repack_args)

    actual = ops.awq_marlin_repack(*repack_args)
    torch.cuda.synchronize()

    torch.testing.assert_close(expected, actual)
314
315


316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def marlin_generate_valid_test_cases():
    """Enumerate the parameter tuples used by ``test_marlin_gemm``.

    Crosses every entry of ``DENSE_MARLIN_QUANT_TEST_CONFIGS`` with the shape
    factors and kernel-option sweeps, expands each config's ``a_type`` /
    ``c_type`` / ``group_blocks`` choices, and keeps only the combinations
    accepted by ``_is_valid``.

    Returns:
        A list of tuples ``(a_type, b_type, c_type, group_blocks, size_m,
        size_n, size_k, act_order, is_k_full, use_atomic_add,
        use_fp32_reduce)``.
    """
    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    # FP8 activations are only generated on SM 8.9 / SM 12.0 devices.
    # get_device_capability() returns a (major, minor) capability (or None),
    # not an int, so convert it with to_int() before comparing to 89 / 120;
    # comparing the capability object itself never matches.
    capability = current_platform.get_device_capability()
    fp8_activation_supported = (
        capability is not None and capability.to_int() in (89, 120)
    )

    def _is_valid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        """Return True if this parameter combination should be tested."""
        if use_atomic_add:
            # Atomic add is never combined with fp32 reduce.
            if use_fp32_reduce:
                return False
            # bf16 output with atomic add is only tested on SM 9.x and newer.
            # (Queried lazily so collection does not touch CUDA otherwise.)
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        # Grouped quantization needs K to be a whole number of groups.
        if group_size > 0 and size_k % group_size != 0:
            return False

        # act_order is not tested with group_size -1 or a single group.
        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        # is_k_full=True is only tested together with act_order.
        if not act_order and is_k_full:
            return False

        # 16-bit activations must match the output type; 8-bit ones may mix.
        return a_type.size_bits < 16 or a_type is c_type

    f16_types = [scalar_types.float16, scalar_types.bfloat16]
    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and not fp8_activation_supported
            ):
                continue
            # Append (act_order, is_k_full, use_atomic_add, use_fp32_reduce).
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if _is_valid(*args):
                cases.append(args)
    return cases


393
394
395
396
def _quantize_b_for_marlin(b_weight, b_type, group_size, act_order, a_dtype):
    """Quantize ``b_weight`` into the Marlin layout for weight type ``b_type``.

    Returns:
        ``(w_ref, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
        sort_indices)``; entries that a format does not produce are ``None``.
    """
    g_idx = sort_indices = marlin_zp = marlin_s2 = None

    if b_type == scalar_types.float4_e2m1f:
        # group_size 16 takes the NVFP4 path (which also returns a second
        # scale tensor); any other group size takes the MXFP4 path.
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
    elif b_type in (scalar_types.uint4, scalar_types.uint8):
        # Formats with zero points go through the AWQ quantizer.
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

    return w_ref, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, sort_indices


def _prepare_marlin_activation(a_input, a_type, dtype, group_size, marlin_s):
    """Quantize activations for ``a_type`` and build the matmul reference.

    Returns:
        ``(a_input, a_input_ref, a_scales, marlin_s)``: the tensor to feed the
        kernel, its dequantized ``dtype`` reference, per-token scales (``None``
        for 16-bit activations), and the possibly rescaled weight scales.
    """
    if a_type == scalar_types.int8:
        a_q, a_scales = per_token_quant_int8(a_input)
        a_ref = (a_q.to(a_scales.dtype) * a_scales.view(-1, 1)).to(dtype)
        if group_size != -1:
            # Fold the largest weight scale into the activation scales and
            # store weight scales as 4096-normalized int16 bits viewed as dtype.
            s_max = marlin_s.max()
            a_scales = (a_scales / 4096 * s_max).float()
            marlin_s = (marlin_s / s_max * 4096).round().to(torch.int16).view(dtype)
        return a_q, a_ref, a_scales, marlin_s

    if a_type == scalar_types.float8_e4m3fn:
        a_q, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_ref = (a_q.to(a_scales.dtype) * a_scales.view(-1, 1)).to(dtype)
        return a_q, a_ref, a_scales, marlin_s

    assert a_type.size_bits == 16
    return a_input, a_input, None, marlin_s


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
def test_marlin_gemm(
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``ops.marlin_gemm`` with a dense torch matmul reference.

    For one generated configuration, quantizes a random weight for
    ``b_type``, prepares activations for ``a_type``, runs the kernel, and
    requires the relative mean absolute error against
    ``a_input_ref @ w_ref`` to stay below 0.04.
    """
    group_size = group_blocks * 16 if group_blocks > 0 else group_blocks

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    # Element type the weight quantizers should target for the activations.
    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    (
        w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
    ) = _quantize_b_for_marlin(b_weight, b_type, group_size, act_order, a_dtype)

    workspace = marlin_make_workspace_new(w_ref.device)

    a_input, a_input_ref, a_scales, marlin_s = _prepare_marlin_activation(
        a_input, a_type, dtype, group_size, marlin_s
    )

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
    output = ops.marlin_gemm(
        a_input,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input_ref, w_ref)

    assert compute_max_diff(output, output_ref) < 0.04


526
527
528
529
530
531
532
533
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_gemm_subset_input():
    """Check ``marlin_gemm`` on a non-contiguous activation view.

    The activation is a ``(size_m, size_k)`` window sliced at offset (8, 8)
    out of a ``(2 * size_m, 2 * size_k)`` tensor, so it is not contiguous.
    The kernel output must still match a dense torch matmul against the
    reference weight. Skipped, like the other Marlin GPU tests, when Marlin
    is unsupported on the current GPU.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # An empty tensor is passed in the zero-point slot.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


574
575
576
577
578
579
580
581
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check that ``marlin_gemm`` adds a (Marlin-permuted) bias correctly.

    The bias is permuted with ``marlin_permute_bias`` for the kernel, while
    the reference adds the unpermuted bias to a dense torch matmul. Skipped,
    like the other Marlin GPU tests, when Marlin is unsupported on the
    current GPU.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    b_bias = rand_data((size_n,)) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # An empty tensor is passed in the zero-point slot.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04