# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

8
9
import itertools

10
11
12
import pytest
import torch

13
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
14
from tests.quantization.utils import is_quant_method_supported
15
from vllm import _custom_ops as ops
16
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
17
18
19
20
21
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
22
23
24
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
25
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
26
27
28
29
30
31
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
32
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
33
34
35
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
36
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
37
38
    marlin_quant_fp8_torch,
)
39
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
40
41
42
43
44
45
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
46
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
47
48
    marlin_24_quantize,
)
49
from vllm.model_executor.layers.quantization.utils.quant_utils import (
50
51
52
53
54
55
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
56
from vllm.platforms import current_platform
57
from vllm.scalar_type import scalar_types
58
59
60

# Boolean option sweeps combined by marlin_generate_valid_test_cases().
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes; the tests multiply them by per-case factors to obtain
# size_k and size_n.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

# (n_factor, k_factor) multipliers used by the repack tests.
MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

# (m_factor, n_factor, k_factor) multipliers used by the GEMM tests.
# m_factor is used directly as size_m.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
86

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Quantization configurations swept by marlin_generate_valid_test_cases().
# Keys, as consumed there and in test_gptq_marlin_gemm:
#   "b_type"            - weight scalar type (required).
#   "a_type" / "c_type" - lists of activation / output scalar types; when a
#                         key is absent, float16 and bfloat16 are both used.
#   "group_blocks"      - list of group sizes in units of 16 elements;
#                         non-positive values (e.g. -1) are passed through
#                         unchanged as the group size.
#   "support_act_order" - when True, act_order=True cases are generated too
#                         (treated as False when absent).
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
    # GPTQ-INT4
    {
        "b_type": scalar_types.uint4b8,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": scalar_types.uint8b128,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # FP8
    {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
    # NVFP4
    {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
    # MXFP4
    {
        "a_type": [scalar_types.bfloat16],
        "b_type": scalar_types.float4_e2m1f,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.float4_e2m1f,
        "c_type": [scalar_types.bfloat16],
        "group_blocks": [2],
    },
]

145

146
147
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name, no maximum is taken: the result is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.

    Args:
        output: Tensor produced by the kernel under test.
        output_ref: Reference tensor that ``output`` is compared against.

    Returns:
        A 0-dim tensor holding the ratio described above.
    """
    mean_abs_err = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_err / mean_abs_ref
150
151


152
153
def rand_data(shape, dtype=torch.float16):
    """Draw a standard-normal random tensor of ``shape`` on the CUDA device.

    Args:
        shape: Shape forwarded to ``torch.randn``.
        dtype: Element type of the result (``torch.float16`` by default).
    """
    sample = torch.randn(shape, dtype=dtype, device="cuda")
    return sample
154
155


156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    """Check ``marlin_int4_fp8_preprocess`` (no zero points) against torch.

    The reference maps every 4-bit value ``v`` to ``v - 8`` when ``v >= 8``
    and to ``15 - v`` otherwise, then packs the result the same way as the
    kernel input.
    """

    def pack_nibble_pairs(unpacked):
        # Even columns form the high nibble and odd columns the low nibble of
        # each byte; the bytes are then reinterpreted as int32 words.
        high, low = unpacked[:, ::2], unpacked[:, 1::2]
        return (high * 16 + low).to(torch.int8).view(torch.int32)

    nibbles = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )

    kernel_out = ops.marlin_int4_fp8_preprocess(pack_nibble_pairs(nibbles))

    expected_nibbles = torch.where(nibbles >= 8, nibbles - 8, 15 - nibbles)
    expected = pack_nibble_pairs(expected_nibbles)

    assert (kernel_out == expected).all()


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    """Check ``marlin_int4_fp8_preprocess`` with zero points against torch.

    The reference subtracts the per-group zero point from every weight; where
    that difference is negative it stores ``15 - weight`` instead. The result
    is packed the same way as the kernel inputs.
    """
    group_size = 128

    def pack_nibble_pairs(unpacked):
        # Even columns form the high nibble and odd columns the low nibble of
        # each byte; the bytes are then reinterpreted as int32 words.
        high, low = unpacked[:, ::2], unpacked[:, 1::2]
        return (high * 16 + low).to(torch.int8).view(torch.int32)

    weights = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    zeros = torch.randint(
        0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
    )

    packed_weights = pack_nibble_pairs(weights)
    packed_zeros = pack_nibble_pairs(zeros)

    kernel_out = ops.marlin_int4_fp8_preprocess(packed_weights, packed_zeros)

    # One zero-point row covers ``group_size`` consecutive weight rows.
    zeros_per_row = zeros.repeat_interleave(group_size, 0)
    expected = weights - zeros_per_row
    negative = expected < 0
    expected[negative] = 15 - weights[negative]
    expected = pack_nibble_pairs(expected)

    assert (kernel_out == expected).all()


208
209
210
211
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    """Compare the ``gptq_marlin_repack`` kernel with the Python reference.

    Random weights are quantized, packed to GPTQ format and repacked by the
    kernel; the result must match ``marlin_weights`` applied to the same
    quantized weights.

    Args:
        k_chunk: Base K size, multiplied by the K factor.
        n_chunk: Base N size, multiplied by the N factor.
        quant_type: Weight scalar type under test.
        act_order: Whether activation reordering is applied when quantizing.
        is_a_8bit: Forwarded to the weight permutation, the reference packer
            and the kernel.
        nk_factors: ``(n_factor, k_factor)`` multipliers.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # Filter act_order: these combinations return early and so are reported
    # as passed without running the kernel.
    if act_order and (group_size in (-1, size_k) or is_a_8bit):
        return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
276
277


278
279
280
281
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    """Compare the ``awq_marlin_repack`` kernel with the Python reference.

    Random weights are quantized with zero points, packed to AWQ format and
    repacked by the kernel; the result must match ``marlin_weights`` applied
    to the same quantized weights.

    Args:
        k_chunk: Base K size, multiplied by the K factor.
        n_chunk: Base N size, multiplied by the N factor.
        quant_type: Weight scalar type under test.
        is_a_8bit: Forwarded to the weight permutation, the reference packer
            and the kernel.
        nk_factors: ``(n_factor, k_factor)`` multipliers.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    group_size = 128

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize
    _, q_w, _, _ = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.awq_marlin_repack,
        (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
324
325


326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def marlin_generate_valid_test_cases():
    """Build the parameter tuples for ``test_gptq_marlin_gemm``.

    Every combination of quantization config, MNK factors, N/K chunk and the
    boolean option sweeps is expanded over the config's activation/output
    types and group sizes, and only combinations accepted by ``is_valid`` are
    kept.

    Returns:
        A list of tuples ``(a_type, b_type, c_type, group_blocks, size_m,
        size_n, size_k, act_order, is_k_full, use_atomic_add,
        use_fp32_reduce)``.
    """
    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    def is_valid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        """Return True when this combination should be turned into a test."""
        if use_atomic_add:
            # Atomic add is never combined with fp32 reduce.
            if use_fp32_reduce:
                return False
            # bfloat16 output with atomic add is only kept on devices whose
            # major compute capability is at least 9.
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        # group_blocks counts 16-element blocks; non-positive values are
        # used as the group size unchanged.
        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and size_k % group_size != 0:
            return False

        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        if not act_order and is_k_full:
            return False

        # Keep sub-16-bit activations, or 16-bit activations whose type is
        # the output type.
        return a_type.size_bits < 16 or a_type is c_type

    # Default activation/output types for configs that do not specify them.
    f16_types = [scalar_types.float16, scalar_types.bfloat16]

    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            # NOTE(review): get_device_capability() is compared with the bare
            # ints 89 and 120. If it returns a (major, minor)-style object
            # rather than an int, this membership test never matches and
            # every FP8-activation case is skipped -- confirm against the
            # platform API before relying on FP8-activation coverage.
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and current_platform.get_device_capability() not in [89, 120]
            ):
                continue
            # case[4:] is (act_order, is_k_full, use_atomic_add,
            # use_fp32_reduce).
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if is_valid(*args):
                cases.append(args)
    return cases


403
404
405
406
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
def test_gptq_marlin_gemm(
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` with a torch matmul reference.

    Args:
        a_type: Activation scalar type; int8 and float8_e4m3fn activations
            are quantized per token, anything else must be 16-bit.
        b_type: Weight scalar type; selects the quantization helper.
        c_type: Output scalar type (float16 or bfloat16).
        group_blocks: Group size in units of 16 elements; non-positive values
            are used as the group size unchanged.
        size_m, size_n, size_k: GEMM problem size.
        act_order: Forwarded to ``marlin_quantize``.
        is_k_full, use_atomic_add, use_fp32_reduce: Forwarded to the kernel.

    Raises:
        RuntimeError: If ``c_type`` is neither float16 nor bfloat16.
    """
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # Optional kernel inputs; each branch below fills in only what its
    # quantization scheme produces.
    marlin_s2 = None
    marlin_zp = None
    g_idx = None
    sort_indices = None

    if b_type == scalar_types.float4_e2m1f:
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
    elif has_zp:
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

    workspace = marlin_make_workspace_new(w_ref.device)

    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        # The reference uses the dequantized activations so that activation
        # quantization error is not counted against the kernel.
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            # Rescale the weight scales to integers (largest becomes 4096),
            # fold the inverse factor into the activation scales, and
            # reinterpret the int16 weight scales as ``dtype``.
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04


536
537
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    """Call ``ops.gptq_marlin_24_gemm`` from inside a ``torch.compile`` graph.

    All arguments are forwarded unchanged, in order, and the op's result is
    returned.
    """
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
560
561


562
563
564
565
# NOTE(review): the guard checks "gptq_marlin" although this test exercises
# the gptq_marlin_24 kernel -- confirm this is the intended capability check.
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    """Compare ``gptq_marlin_24_gemm`` with a torch matmul reference.

    The op is validated with ``opcheck`` and then run through the compiled
    ``marlin_24_gemm_tester`` wrapper; the relative mean error against
    ``a_input @ w_24_ref`` must stay below 0.04.

    Args:
        k_chunk: Base K size, multiplied by the K factor.
        n_chunk: Base N size, multiplied by the N factor.
        quant_type: Weight scalar type under test.
        group_size: Quantization group size.
        mnk_factors: ``(m_factor, n_factor, k_factor)``; ``m_factor`` is used
            directly as ``size_m``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    # The raw op is given ``quant_type.id``, whereas the ``ops`` wrapper
    # called by marlin_24_gemm_tester is given the scalar type object itself.
    opcheck(
        torch.ops._C.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
624
625


626
627
628
629
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check ``gptq_marlin_gemm`` with floating-point zero points.

    Random integer weights in ``[0, 10)`` with random per-group scales and
    zero points are repacked for Marlin and multiplied with
    ``is_zp_float=True``; the reference dequantizes each group as
    ``(weight - zero) * scale`` and uses a plain matmul.

    Args:
        k_chunk: Base K size, multiplied by the K factor.
        n_chunk: Base N size, multiplied by the N factor.
        group_size: Number of consecutive K elements sharing a scale/zero.
        mnk_factors: ``(m_factor, n_factor, k_factor)``; ``m_factor`` is used
            directly as ``size_m``.
        use_fp32_reduce: Forwarded to the kernel.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Weights, scales and zeros are created as (N, K)-major tensors and
    # transposed before packing / permuting.
    b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
        dev
    )
    marlin_s = marlin_permute_scales(
        scale.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)
    # The zero points are permuted with the same helper as the scales.
    marlin_zp = marlin_permute_scales(
        zero.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Reference dequantization: one row per quantization group, so the
    # flattened scale/zero columns line up with the group rows.
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


710
711
712
713
714
715
716
717
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_gemm_subset_input():
    """Check ``gptq_marlin_gemm`` on an activation that is a slice of a larger tensor.

    The activation is taken as an offset sub-block of a tensor twice as large
    in both dimensions, so the kernel is not handed a freshly allocated
    tensor of exactly the GEMM shape.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # The empty-g_idx helper is reused to create the zero-point placeholder.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


758
759
760
761
762
763
764
765
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check ``gptq_marlin_gemm`` with a bias against ``a @ w_ref + bias``.

    Args:
        size_m: Number of activation rows.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    # Scale the bias up by 10 relative to the unit-variance random data.
    b_bias = rand_data((size_n,)) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    # The empty-g_idx helper is reused to create the zero-point placeholder.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    # The reference adds the unpermuted bias to every output row.
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04