test_marlin_gemm.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Tests for the marlin kernel.

5
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
6
"""
7

8
9
import itertools

10
11
12
import pytest
import torch

13
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
14
from tests.quantization.utils import is_quant_method_supported
15
from vllm import _custom_ops as ops
16
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
17
18
19
20
21
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
22
23
24
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
25
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
26
27
28
29
30
31
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
32
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
33
34
35
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
36
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
37
38
    marlin_quant_fp8_torch,
)
39
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
40
41
42
43
44
45
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
46
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
47
48
    marlin_24_quantize,
)
49
from vllm.model_executor.layers.quantization.utils.quant_utils import (
50
51
52
53
54
55
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
56
from vllm.platforms import current_platform
57
from vllm.scalar_type import scalar_types
58

59
60
61
62
63
64
65
66
# The Marlin kernels exercised below are CUDA-only, so the whole module is
# skipped on ROCm before any test is collected.
if current_platform.is_rocm():
    pytest.skip(
        "These tests require gptq_marlin_repack, "
        "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm, "
        "or gptq_marlin_gemm which are not supported on ROCm.",
        allow_module_level=True,
    )
# Kernel options swept by the dense Marlin GEMM test-case generator.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes; tests multiply them by the factors below to obtain
# size_k / size_n.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

# (n_factor, k_factor) pairs used by the repack tests.
MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

# (m_factor, n_factor, k_factor) triples; size_m is m_factor itself while
# size_n / size_k are the factors times the chunk sizes above.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
# Weight-quantization setups swept by test_gptq_marlin_gemm. Each entry names
# the weight type ("b_type") and the group_blocks values to try. Optional keys:
#   "a_type" / "c_type"  - restrict the activation / output scalar types
#                          (the case generator falls back to fp16 and bf16),
#   "support_act_order"  - allow act_order=True cases for this weight type.
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    dict(b_type=scalar_types.uint4, group_blocks=[-1, 2, 4, 8]),
    # GPTQ-INT4
    dict(
        b_type=scalar_types.uint4b8,
        support_act_order=True,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT8
    dict(
        b_type=scalar_types.uint8b128,
        support_act_order=True,
        group_blocks=[-1, 2, 4, 8],
    ),
    # FP8
    dict(b_type=scalar_types.float8_e4m3fn, group_blocks=[-1, 8]),
    # NVFP4
    dict(b_type=scalar_types.float4_e2m1f, group_blocks=[1]),
    # MXFP4
    dict(
        a_type=[scalar_types.bfloat16],
        b_type=scalar_types.float4_e2m1f,
        group_blocks=[2],
    ),
    # AWQ-INT4 with INT8 activation
    dict(
        a_type=[scalar_types.int8],
        b_type=scalar_types.uint4,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with INT8 activation
    dict(
        a_type=[scalar_types.int8],
        b_type=scalar_types.uint4b8,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.uint4b8,
        group_blocks=[-1, 2, 4, 8],
    ),
    # AWQ-INT4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.uint4,
        group_blocks=[-1, 2, 4, 8],
    ),
    # MXFP4 with FP8 activation
    dict(
        a_type=[scalar_types.float8_e4m3fn],
        b_type=scalar_types.float4_e2m1f,
        c_type=[scalar_types.bfloat16],
        group_blocks=[2],
    ),
]
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite its name this is not a maximum: it computes
    ``mean(|output - output_ref|) / mean(|output_ref|)`` and returns it as a
    scalar tensor, which the tests compare against a tolerance.
    """
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )
def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples.

    Args:
        shape: Shape of the returned tensor.
        dtype: Element type; float16 unless overridden.
        device: Device to allocate on. Defaults to ``"cuda"`` (the Marlin
            kernels under test run on the GPU); overriding it lets helpers be
            exercised elsewhere.
    """
    return torch.randn(shape, dtype=dtype, device=device)
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    """Check ``ops.marlin_int4_fp8_preprocess`` when no zero points are given.

    The torch reference maps every 4-bit value ``v`` to ``v - 8`` when
    ``v >= 8`` and to ``15 - v`` otherwise, then repacks the nibbles.
    """

    def pack_nibbles(vals):
        # Even columns form the high nibble, odd columns the low nibble; the
        # resulting bytes are reinterpreted as int32 words.
        pairs = vals[:, 0::2] * 16 + vals[:, 1::2]
        return pairs.to(torch.int8).view(torch.int32)

    unpacked = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )

    kernel_out = ops.marlin_int4_fp8_preprocess(pack_nibbles(unpacked))

    remapped = torch.where(unpacked >= 8, unpacked - 8, 15 - unpacked)
    expected = pack_nibbles(remapped)

    assert (kernel_out == expected).all()
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    """Check ``ops.marlin_int4_fp8_preprocess`` with per-group zero points.

    The reference subtracts the group's zero point from each 4-bit weight;
    entries that would go negative are replaced by ``15 - weight``.
    """
    group_size = 128
    rows = cols = 2048

    def pack_nibbles(vals):
        # Even columns -> high nibble, odd columns -> low nibble, bytes viewed
        # as int32 words.
        pairs = vals[:, 0::2] * 16 + vals[:, 1::2]
        return pairs.to(torch.int8).view(torch.int32)

    weights = torch.randint(
        0, 16, size=(rows, cols), dtype=torch.int32, device="cuda"
    )
    zeros = torch.randint(
        0, 16, size=(rows // group_size, cols), dtype=torch.int32, device="cuda"
    )

    kernel_out = ops.marlin_int4_fp8_preprocess(
        pack_nibbles(weights), pack_nibbles(zeros)
    )

    # One zero-point row covers group_size consecutive weight rows.
    shifted = weights - zeros.repeat_interleave(group_size, 0)
    expected = pack_nibbles(torch.where(shifted < 0, 15 - weights, shifted))

    assert (kernel_out == expected).all()
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    """Check ``ops.gptq_marlin_repack`` against the torch Marlin packing.

    Weights are quantized and packed in GPTQ format, then converted to the
    Marlin layout both by ``marlin_weights`` (reference) and by the
    ``gptq_marlin_repack`` kernel; the two must match exactly. ``opcheck``
    additionally validates the raw op registration.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # Filter act_order combinations that are not exercised.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if is_a_8bit:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format (reference path)
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    """Check ``ops.awq_marlin_repack`` against the torch Marlin packing.

    Weights are quantized with zero points and packed in AWQ format, then
    converted to the Marlin layout both by ``marlin_weights`` (reference) and
    by the ``awq_marlin_repack`` kernel; the two must match exactly.
    ``opcheck`` additionally validates the raw op registration.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    group_size = 128

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (only the quantized integer weights are needed here)
    _, q_w, _, _ = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format (reference path)
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.awq_marlin_repack,
        (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
def marlin_generate_valid_test_cases():
    """Build the parameter tuples for ``test_gptq_marlin_gemm``.

    Every entry of ``DENSE_MARLIN_QUANT_TEST_CONFIGS`` is expanded against the
    shape factors, chunk sizes and kernel options; only combinations accepted
    by the local ``is_valid`` predicate are kept.

    Returns:
        A list of tuples ``(a_type, b_type, c_type, group_blocks, size_m,
        size_n, size_k, act_order, is_k_full, use_atomic_add,
        use_fp32_reduce)``.
    """
    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    def is_valid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        """Return True when the combination should be tested."""
        if use_atomic_add:
            # Atomic add is not combined with fp32 reduce, and is skipped for
            # bf16 outputs on devices below compute capability 9.x.
            if use_fp32_reduce:
                return False
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        # group_blocks counts 16-element blocks; non-positive means per-channel.
        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and size_k % group_size != 0:
            return False

        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        if not act_order and is_k_full:
            return False

        # 16-bit activations must match the output type.
        return a_type.size_bits < 16 or a_type is c_type

    # FP8 activations are only tested on SM 8.9 / SM 12.0. In vLLM,
    # get_device_capability() returns a DeviceCapability (major, minor) tuple
    # or None, which never compares equal to a plain int such as 89, so the
    # previous `capability not in [89, 120]` check skipped FP8 cases
    # everywhere. Compare the integer form instead (computed once).
    capability = current_platform.get_device_capability()
    fp8_act_supported = capability is not None and capability.to_int() in (89, 120)

    f16_types = [scalar_types.float16, scalar_types.bfloat16]
    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            if sub_case[0] == scalar_types.float8_e4m3fn and not fp8_act_supported:
                continue
            # case[4:] = (act_order, is_k_full, use_atomic_add, use_fp32_reduce)
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if is_valid(*args):
                cases.append(args)
    return cases
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
def test_gptq_marlin_gemm(
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` with a dense ``torch.matmul`` reference.

    Each generated case fixes the activation type (``a_type``), weight type
    (``b_type``), output type (``c_type``), group size and kernel options.
    The weights are quantized into Marlin format and the kernel output must be
    within a 4% mean relative error of the dequantized reference product.
    """
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    # group_blocks counts 16-element blocks; non-positive means per-channel.
    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # Produce the reference weights (w_ref) and the kernel-format tensors.
    if b_type == scalar_types.float4_e2m1f:
        # group_size 16 uses the NVFP4 helper, which also returns marlin_s2;
        # other group sizes use the MXFP4 helper.
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    # Quantize activations when required; a_input_ref is the dequantized
    # activation used for the reference matmul.
    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            # Rescale the weight scales so their max maps to 4096, store them
            # as int16 bits viewed as `dtype`, and fold the removed factor into
            # the per-token activation scales (computed before marlin_s changes).
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    """Call ``ops.gptq_marlin_24_gemm`` through ``torch.compile(fullgraph=True)``.

    Wrapping the call in a compiled function makes the test also check that
    the op can be traced without graph breaks. Arguments are forwarded to the
    op unchanged.
    """
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    """Compare ``gptq_marlin_24_gemm`` with a dense ``torch.matmul`` reference.

    Weights are quantized with ``marlin_24_quantize``; the raw op is validated
    with ``opcheck`` and the compiled wrapper's output must be within a 4% mean
    relative error of the reference product.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(
        torch.ops._C.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check ``ops.gptq_marlin_gemm`` on HQQ-style uint4 weights.

    HQQ uses floating-point zero points (``is_zp_float=True``). Random 4-bit
    weights with random per-group scales and zero points are repacked into the
    Marlin layout, and the kernel output is compared with the reference
    ``(w - zero) * scale`` dequantized matmul.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Weights are laid out (size_n, size_k) with one scale / zero per group of
    # group_size consecutive k-elements.
    b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
        dev
    )
    marlin_s = marlin_permute_scales(
        scale.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)
    marlin_zp = marlin_permute_scales(
        zero.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Reference: dequantize each group as (w - zero) * scale, then matmul.
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_gemm_subset_input():
    """Check ``ops.gptq_marlin_gemm`` on a strided (non-contiguous) activation.

    The activation is a (size_m, size_k) window sliced out of a larger tensor
    at offset (8, 8), so its rows are not densely packed. The result must match
    a ``torch.matmul`` reference within a 4% mean relative error.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # uint4b8 has no zero points; an empty tensor is passed in their place.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check that ``ops.gptq_marlin_gemm`` applies a per-output-column bias.

    The bias is permuted into Marlin order for the kernel, while the reference
    adds the unpermuted bias to the ``torch.matmul`` result. The two must agree
    within a 4% mean relative error.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    # Scaled up so the bias contributes noticeably to the output.
    b_bias = rand_data((size_n,)) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # uint4b8 has no zero points; an empty tensor is passed in their place.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04