test_marlin_gemm.py 16.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for the marlin kernel.

5
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
6
7
8
9
"""
import pytest
import torch

10
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
11
from tests.quantization.utils import is_quant_method_supported
12
from vllm import _custom_ops as ops
13
14
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
15
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
16
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
17
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
18
    marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
19
    query_marlin_supported_quant_types)
20
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
21
22
    FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like)
23
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
24
    marlin_quant_fp8_torch)
25
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
26
27
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
28
29
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
30
from vllm.model_executor.layers.quantization.utils.quant_utils import (
31
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
32
from vllm.scalar_type import scalar_types
33
34
35

# Boolean / option grids that the kernel tests are parametrized over.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes along K and N. Each test multiplies them by the k / n
# entries of an MNK_FACTORS tuple to obtain size_k / size_n.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

# Chunk sizes used by the 2:4-sparse (marlin_24) GEMM test.
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

# (m, n, k) multipliers: the gemm tests use m directly as the number of
# activation rows, while n and k scale the N / K chunk sizes above.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

# Activation / weight dtypes exercised by the gemm test.
DTYPES = [torch.float16, torch.bfloat16]
55

56

57
58
59
60
61
def compute_max_diff(output, output_ref):
    """Return the error of ``output`` relative to ``output_ref``.

    Despite its name this is not a maximum: it computes
    ``mean(|output - output_ref|) / mean(|output_ref|)`` and returns the
    result as a 0-d tensor, which the tests compare against a tolerance.
    """
    abs_err = (output - output_ref).abs()
    ref_magnitude = output_ref.abs()
    return abs_err.mean() / ref_magnitude.mean()


62
63
def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples.

    Args:
        shape: Shape of the tensor to create.
        dtype: Element dtype; defaults to ``torch.float16``.
        device: Device to allocate on. Defaults to ``"cuda"`` (the previous
            hard-coded value), so existing callers are unaffected; exposing it
            lets the helper also build CPU-side tensors.

    Returns:
        A freshly sampled ``torch.randn`` tensor.
    """
    return torch.randn(shape, dtype=dtype, device=device)
64
65


66
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    """Check that gptq_marlin_repack matches the reference Marlin packing.

    Weights are GPTQ-quantized (optionally with act_order), packed into the
    GPTQ layout and repacked on the GPU. The result must equal
    ``marlin_weights`` applied to the (sorted, for act_order) quantized values.
    """
    _, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # act_order is skipped when there is only one group along K
    # (group_size -1 or equal to size_k).
    if act_order and group_size in (-1, size_k):
        return

    # -1 stands for a single group covering all of K.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    num_bits = quant_type.size_bits
    b_weight = rand_data((size_k, size_n))

    # Only the quantized values and group indices are needed here.
    _, q_w, _, g_idx, _ = gptq_quantize_weights(b_weight, quant_type,
                                                group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # With act_order, reorder rows so group ids are increasing; otherwise
    # the kernel receives an empty permutation.
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)

    expected = marlin_weights(q_w, size_k, size_n, num_bits,
                              get_weight_perm(num_bits))

    repack_args = (q_w_gptq, sort_indices, size_k, size_n, num_bits)
    opcheck(torch.ops._C.gptq_marlin_repack, repack_args)

    actual = ops.gptq_marlin_repack(*repack_args)
    torch.cuda.synchronize()

    torch.testing.assert_close(expected, actual)
129
130


131
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Check that awq_marlin_repack matches the reference Marlin packing.

    Weights are quantized with zero points, packed into the AWQ layout and
    repacked on the GPU. The result must equal ``marlin_weights`` applied to
    the quantized values.
    """
    _, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # -1 stands for a single group covering all of K.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    num_bits = quant_type.size_bits
    b_weight = rand_data((size_k, size_n))

    # Only the quantized values are needed; reference weights, scales and
    # zero points are discarded.
    _, q_w, _, _ = quantize_weights(b_weight,
                                    quant_type,
                                    group_size,
                                    zero_points=True)
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

    expected = marlin_weights(q_w, size_k, size_n, num_bits,
                              get_weight_perm(num_bits))

    repack_args = (q_w_awq, size_k, size_n, num_bits)
    opcheck(torch.ops._C.awq_marlin_repack, repack_args)

    actual = ops.awq_marlin_repack(*repack_args)
    torch.cuda.synchronize()

    torch.testing.assert_close(expected, actual)
181
182
183
184
185
186


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
# sorted() yields a value-ordered, reproducible parameter list; iterating a
# bare set follows hash-table layout rather than the values themselves.
@pytest.mark.parametrize(
    "group_size",
    sorted(
        set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
                          mnk_factors, act_order, is_k_full, use_atomic_add,
                          use_fp32_reduce, dtype):
    """Compare gptq_marlin_gemm against a dense matmul on reference weights.

    The weight is quantized through the path that matches ``quant_type``
    (FP4, FP8, zero-point or GPTQ) and the kernel output is checked with
    ``compute_max_diff`` against ``a_input @ w_ref``. Unsupported parameter
    combinations are filtered out with an early ``return``.
    """
    m_factor, n_factor, k_factor = mnk_factors
    # These types go through the zero-point path (awq_marlin_quantize) below.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        # act_order is skipped with a single group along K and with
        # zero-point types.
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    # group_size == -1 passes this check because size_k % -1 == 0.
    if size_k % group_size != 0:
        return

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            return
        # The group_size 32 (MXFP4-like) path is not run with float16.
        if group_size == 32 and dtype == torch.float16:
            return

        if group_size == 16:
            # NVFP4-like weights carry a second scale tensor.
            w_ref, marlin_q_w, marlin_s, marlin_s2 = \
                rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
        else:
            w_ref, marlin_q_w, marlin_s = \
                rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            return
        if act_order:
            return
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size)
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order)
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(torch.ops._C.gptq_marlin_gemm,
            (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
             g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
             use_fp32_reduce, False),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


303
304
305
306
307
308
309
310
311
312
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Invoke ``ops.gptq_marlin_24_gemm`` from inside a compiled function.

    Compiling with ``fullgraph=True`` makes compilation fail on a graph
    break, so calling this exercises tracing of the op. All arguments are
    forwarded unchanged and the op's result is returned.
    """
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )


313
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    """Compare the 2:4-sparse Marlin GEMM against a dense reference matmul.

    The kernel is checked with ``opcheck`` and then run through the
    ``torch.compile``-wrapped ``marlin_24_gemm_tester``.
    """
    size_m, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_24_ref, q_w_comp, meta, scales = marlin_24_quantize(
        b_weight, quant_type, group_size)

    scratch = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                              GPTQ_MARLIN_24_MAX_PARALLEL).scratch

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(torch.ops._C.gptq_marlin_24_gemm,
            (a_input, q_w_comp, meta, scales, scratch, quant_type.id, size_m,
             size_n, size_k),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = marlin_24_gemm_tester(a_input, q_w_comp, meta, scales, scratch,
                                   quant_type, size_m, size_n, size_k)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04
362
363


364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check gptq_marlin_gemm with float (HQQ-style) zero points.

    Random 4-bit codes in [0, 10), plus per-group float scales and zero
    points, are packed into the Marlin layout. The kernel output is compared
    with ``a @ ((w - zero) * scale).T`` computed in plain PyTorch.
    """
    size_m, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Integer codes laid out as (N, K); one scale / zero per group of K.
    b_weight = torch.randint(0,
                             10, (size_n, size_k),
                             dtype=torch.uint8,
                             device=dev)
    groups_per_row = size_k // group_size
    scale = rand_data((size_n, groups_per_row))
    zero = rand_data((size_n, groups_per_row))

    # The packing helpers take (K, N) operands, hence the transposes.
    gptq_w_q = gptq_pack(b_weight.t(), 4, size_k, size_n)

    no_perm = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, no_perm, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.t(), size_k, size_n,
                                     group_size).to(dev)
    marlin_zp = marlin_permute_scales(zero.t(), size_k, size_n,
                                      group_size).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(dev)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m,
        size_n,
        size_k,
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Dequantize group by group: every row of the (-1, group_size) view
    # pairs with the matching element of the flattened zero / scale tensors.
    grouped_w = b_weight.reshape(-1, group_size)
    dequant = (grouped_w - zero.reshape(-1, 1)) * scale.reshape(-1, 1)
    w_ref = dequant.reshape(size_n, size_k).t()

    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04


446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
def test_marlin_gemm_subset_input():
    """Run gptq_marlin_gemm on a strided, offset view of a larger activation.

    The activation is sliced out of a (2*M, 2*K) tensor starting at row 8 and
    column 8, so it is a non-contiguous view. The kernel output must still
    match a dense matmul against the reference weights.

    The skipif guard matches every other kernel test in this module. Without
    it, this test fails instead of skipping on GPUs where Marlin is
    unsupported.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    # An empty tensor stands in for the zero points (none for this type).
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check that gptq_marlin_gemm adds a (permuted) bias to its output.

    The reference is ``a_input @ w_ref + bias``, with the bias broadcast over
    rows. The skipif guard matches the other Marlin kernel tests in this
    module, so unsupported GPUs skip rather than fail.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    b_bias = rand_data((size_n, )) * 10

    # The kernel consumes the bias in Marlin's permuted order; the
    # reference below uses the original, unpermuted bias.
    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04