# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
    marlin_make_workspace_new, marlin_permute_scales,
    query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

# Boolean switches swept by the parametrized tests in this module.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

# Base K / N sizes; each test multiplies them by the matching entry of
# MNK_FACTORS to obtain size_k / size_n.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

# (m_factor, n_factor, k_factor) triples. Tests use m_factor directly as
# size_m and scale the K / N chunk sizes by k_factor / n_factor.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
64

65

66
67
68
69
70
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name this is not a maximum: the result is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_err = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_err / mean_abs_ref


71
72
def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of the given shape filled by ``torch.randn``.

    Args:
        shape: Shape of the tensor to create.
        dtype: Element type; defaults to ``torch.float16``.
        device: Device to allocate on. Defaults to ``"cuda"``, which is what
            every existing caller relies on; pass another device to reuse the
            helper off-GPU.
    """
    return torch.randn(shape, dtype=dtype, device=device)
73
74


75
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    """Compare ``ops.gptq_marlin_repack`` with the ``marlin_weights`` helper.

    Random weights are quantized with ``gptq_quantize_weights``, packed to
    GPTQ format and repacked by the op; the result must match what
    ``marlin_weights`` produces from the same quantized weights.
    """
    # m_factor is irrelevant here: repacking only involves the weight matrix.
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter act_order. Skip explicitly instead of returning, so excluded
    # combinations are not reported as passing.
    if act_order and group_size in (-1, size_k):
        pytest.skip("act_order is not tested when one group spans all of K")

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _, q_w, _, g_idx, _ = gptq_quantize_weights(b_weight, quant_type,
                                                group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, _, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.gptq_marlin_repack,
            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
138
139


140
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Compare ``ops.awq_marlin_repack`` with the ``marlin_weights`` helper.

    Random weights are quantized with zero points, packed to AWQ format and
    repacked by the op; the result must match what ``marlin_weights``
    produces from the same quantized weights.
    """
    # m_factor is irrelevant here: repacking only involves the weight matrix.
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize; only the quantized weights are needed for the repack check.
    _, q_w, _, _ = quantize_weights(b_weight,
                                    quant_type,
                                    group_size,
                                    zero_points=True)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.awq_marlin_repack,
            (q_w_awq, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
190
191
192
193
194
195


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
# sorted() gives the de-duplicated group sizes a fixed order, so parameter
# order does not depend on set iteration order.
@pytest.mark.parametrize(
    "group_size",
    sorted(
        set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` with a dense ``torch.matmul``.

    The weight is quantized by the helper matching ``quant_type`` (FP4, FP8,
    zero-point or plain), the op is run through ``opcheck`` and then called
    directly, and the mean relative error against ``a_input @ w_ref`` must
    stay below 0.04. Combinations the test does not cover are skipped
    explicitly rather than returned from, so they are not reported as passed.
    """
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        if group_size in (-1, size_k):
            pytest.skip(
                "act_order is not tested when one group spans all of K")
        if has_zp:
            pytest.skip("act_order is not tested with zero-point types")

    if size_k % group_size != 0:
        pytest.skip("size_k is not a multiple of group_size")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    if quant_type == scalar_types.float4_e2m1f:
        if group_size != 16 or act_order:
            pytest.skip(
                "FP4 is only tested with group_size 16 and no act_order")
        w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            pytest.skip("FP8 is only tested with group_size -1 or 128")
        if act_order:
            pytest.skip("FP8 is not tested with act_order")
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            pytest.skip("group_size 16 is only tested with FP4")
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size)
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            pytest.skip("group_size 16 is only tested with FP4")
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order)
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    # The raw op is given quant_type.id; the ops wrapper below is given the
    # scalar type object itself.
    opcheck(torch.ops._C.gptq_marlin_gemm,
            (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
             sort_indices, workspace, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
             use_fp32_reduce, False),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


309
310
311
312
313
314
315
316
317
318
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Call ``ops.gptq_marlin_24_gemm`` from a ``torch.compile``d function.

    Every argument is forwarded unchanged and in the same order; the op's
    result is returned as is.
    """
    result = ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp,
                                     marlin_24_meta, marlin_24_s, scratch,
                                     quant_type, size_m, size_n, size_k)
    return result


319
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    """Compare ``gptq_marlin_24_gemm`` with a dense ``torch.matmul``.

    The weight is quantized with ``marlin_24_quantize``; the op is checked
    with ``opcheck`` and then run through the compiled
    ``marlin_24_gemm_tester`` wrapper. The mean relative error against
    ``a_input @ w_24_ref`` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)

    output_ref = torch.matmul(a_input, w_24_ref)

    # The raw op is given quant_type.id; the compiled wrapper below is given
    # the scalar type object itself.
    opcheck(torch.ops._C.gptq_marlin_24_gemm,
            (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
             workspace_24.scratch, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
368
369


370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` with ``is_zp_float=True`` to a reference.

    Random uint8 weights in [0, 10) plus random per-group scales and zero
    points are packed for Marlin. The reference dequantizes them as
    ``(w - zero) * scale`` and multiplies with ``torch.matmul``; the mean
    relative error must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Weights, scales and zero points are created as (size_n, ...) and
    # transposed before packing.
    b_weight = torch.randint(0,
                             10, (size_n, size_k),
                             dtype=torch.uint8,
                             device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
                                     group_size).to(dev)
    # The float zero points go through the same permutation as the scales.
    marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
                                      group_size).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Reference dequantization: one row per group, so each group's zero point
    # and scale broadcast over its group_size weights.
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input,
                              dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare ``ops.marlin_qqq_gemm`` with a dense ``torch.matmul``.

    Activations are quantized to int8 with one scale per row and weights
    with ``marlin_qqq_quantize``. The reference multiplies the re-scaled
    int8 activations by ``w_ref``; the mean relative error must stay below
    0.04.
    """
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: per-row scale = max|a| / int8 max.
    s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
        torch.float)
    q_a = (a_input / s_a).round().clamp(int8_traits.min,
                                        int8_traits.max).to(torch.int8)

    # Quantize weights
    (w_ref, marlin_qqq_q_w, marlin_qqq_s_group,
     marlin_qqq_s_channel) = marlin_qqq_quantize(b_weight, num_bits,
                                                 group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    opcheck(torch.ops._C.marlin_qqq_gemm,
            (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
             marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]))

    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    # The reference uses the quantized activations, so activation rounding
    # error is excluded from the comparison.
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
511
512


513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
def test_marlin_gemm_subset_input():
    """Run ``ops.gptq_marlin_gemm`` on an input sliced out of a larger tensor.

    ``a_input`` is a ``[8:size_m + 8, 8:size_k + 8]`` window of a tensor
    twice as large in both dimensions. The result is compared with
    ``torch.matmul`` on the same slice; the mean relative error must stay
    below 0.04.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


558
559
560
561
562
563
564
565
566
567
568
569
570
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
def test_marlin_gemm_opcheck():
    """Run ``opcheck`` on ``torch.ops._C.marlin_gemm``.

    The op is also called twice with identical inputs and the two results
    must match.
    """
    size_m = 2048
    size_n = 4096
    size_k = 4096
    # The activation is (size_m, size_k). The original shaped it with size_n,
    # which only worked because size_n == size_k.
    a = torch.rand((size_m, size_k), device='cuda', dtype=torch.float16)
    # NOTE(review): (256, 8192) presumably is the packed Marlin layout for a
    # 4096 x 4096 weight -- confirm against the kernel before changing sizes.
    w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
    # NOTE(review): scales presumably have one column per output channel
    # (size_n) and 32 rows for the K groups -- confirm against the kernel.
    s = torch.full((32, size_n), 0.125, device='cuda', dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))