"vscode:/vscode.git/clone" did not exist on "d910816c7356f4decd56eefb80e963b476cdf3e5"
test_marlin_gemm.py 16.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for the marlin kernel.

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES,
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.scalar_type import scalar_types

# Parametrization options shared by the Marlin tests in this module.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes along K and N. The tests scale them by the k/n factors
# from MNK_FACTORS: size_k = k_chunk * k_factor, size_n = n_chunk * n_factor.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

# (m_factor, n_factor, k_factor) triples. The tests use size_m = m_factor
# directly and multiply the n/k factors with the chunk sizes above.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]


def compute_max_diff(output, output_ref):
    """Return the mean relative absolute error of ``output`` vs ``output_ref``.

    Despite its name, this is not a maximum. It computes
    ``mean(|output - output_ref|) / mean(|output_ref|)``, which is a
    scale-independent error measure returned as a 0-d tensor.

    Args:
        output: Tensor produced by the kernel under test.
        output_ref: Reference tensor with the same shape as ``output``.
    """
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples.

    Args:
        shape: Tensor shape, forwarded to ``torch.randn``.
        dtype: Element dtype (default ``torch.float16``).
        device: Target device. Defaults to ``"cuda"`` because the kernels
            exercised in this module run on the GPU; callers may override it.
    """
    return torch.randn(shape, dtype=dtype, device=device)


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
    """Check ``ops.gptq_marlin_repack`` against the Python Marlin packing.

    A random weight is quantized and packed into GPTQ layout. The repack
    kernel's output is compared with ``marlin_weights`` applied to the same
    quantized weights, after sorting by group id when act_order is enabled.
    """
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # act_order needs several groups along K. Neither -1 (normalized to
    # size_k below) nor group_size == size_k provides that.
    if act_order and group_size in (-1, size_k):
        pytest.skip("act_order requires group_size < size_k")

    # Normalize group_size: -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided). Only the integer weights
    # and group indices are needed here.
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format (input of the kernel under test).
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing. Without act_order, sort_indices stays an empty tensor.
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reference: pack to Marlin format with the Python helper.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    """Check ``ops.awq_marlin_repack`` against the Python Marlin packing.

    A random weight is quantized with zero points and packed into AWQ
    layout. The repack kernel's output is compared with ``marlin_weights``
    applied to the same quantized weights.
    """
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size: -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize with zero points. Only the integer weights are needed here.
    _, q_w, _, _ = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format (input of the kernel under test).
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Reference: pack to Marlin format with the Python helper.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
    )

    opcheck(
        torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
    "group_size",
    # sorted() gives a deterministic parametrization order (a bare set does not
    # promise one).
    sorted(set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)),
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
    dtype,
):
    """Compare ``ops.gptq_marlin_gemm`` with a dense ``torch.matmul`` reference.

    The weight is quantized with the helper that matches ``quant_type``:
    FP4 (NVFP4/MXFP4-like), FP8, zero-point (AWQ-style), or GPTQ. The
    Marlin GEMM output is then compared with ``a_input @ w_ref`` via
    ``compute_max_diff``. Parameter combinations this test does not cover
    are reported as skipped rather than silently passing.
    """
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        # act_order needs several groups along K and is not combined with
        # zero points here.
        if group_size in (-1, size_k):
            pytest.skip("act_order requires group_size < size_k")
        if has_zp:
            pytest.skip("act_order is not tested with zero points")

    # Python's % makes group_size == -1 always pass this check.
    if size_k % group_size != 0:
        pytest.skip("size_k must be a multiple of group_size")

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            pytest.skip("FP4 is only tested with group_size 16/32, no act_order")
        if group_size == 32 and dtype == torch.float16:
            pytest.skip("FP4 with group_size 32 is only tested with bfloat16")

        if group_size == 16:
            # NVFP4-like weights. This helper also returns a second scale
            # tensor (marlin_s2).
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            pytest.skip("FP8 is only tested with group_size -1 or 128")
        if act_order:
            pytest.skip("FP8 is not tested with act_order")
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            pytest.skip("group_size 16 is only tested with FP4")
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            pytest.skip("group_size 16 is only tested with FP4")
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(
        torch.ops._C.gptq_marlin_gemm,
        (
            a_input,
            None,
            marlin_q_w,
            None,
            marlin_s,
            marlin_s2,
            marlin_zp,
            g_idx,
            sort_indices,
            workspace,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
            is_k_full,
            use_atomic_add,
            use_fp32_reduce,
            False,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    """Call ``ops.gptq_marlin_24_gemm`` through ``torch.compile``.

    With ``fullgraph=True``, compilation raises an error on any graph break.
    Calling this wrapper therefore also checks that the op can be captured
    in a single graph. All arguments are forwarded positionally and unchanged.
    """
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    """Compare ``gptq_marlin_24_gemm`` with a dense ``torch.matmul`` reference.

    The op is checked with ``opcheck`` and then executed through the
    ``torch.compile``-wrapped ``marlin_24_gemm_tester``. Its output is
    compared with ``a_input @ w_24_ref`` via ``compute_max_diff``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(
        torch.ops._C.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check the Marlin GEMM with 4-bit weights and float zero points.

    Random integer weights (values 0..9) plus random per-group scales and
    zero points are repacked into Marlin layout. The GEMM is run with
    ``is_zp_float=True`` and compared with a reference that dequantizes each
    group of ``group_size`` weights as ``(q - zero) * scale``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    # Pack to GPTQ layout (expects (size_k, size_n)), then repack to Marlin.
    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    # No act_order: sort indices and g_idx are empty.
    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
        dev
    )
    marlin_s = marlin_permute_scales(
        scale.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)
    marlin_zp = marlin_permute_scales(
        zero.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Reference dequantization: one (zero, scale) pair per group of
    # group_size consecutive weights along K.
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_gemm_subset_input():
    """Run the Marlin GEMM on an activation that is a slice of a larger tensor.

    ``a_input`` is the ``[8:size_m + 8, 8:size_k + 8]`` window of a
    ``(2 * size_m, 2 * size_k)`` tensor. It is therefore a strided,
    non-contiguous view, and the result must still match ``a_input @ w_ref``.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # No zero points: pass an empty tensor.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
592
    b_bias = rand_data((size_n,)) * 10
593
594
595
596

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
597
598
        b_weight, quant_type, group_size, False
    )
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04