"vscode:/vscode.git/clone" did not exist on "2ad10292339c045db0cb3998a76240a459cc83a7"
test_marlin_gemm.py 20.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

8
9
import itertools

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
57

58
59
60
61
# Skip the whole module on ROCm: every test below exercises at least one of
# the ops named in the message.
if current_platform.is_rocm():
    pytest.skip(
        # Each fragment ends with a space so the implicitly concatenated
        # message reads as a normal comma-separated list.
        "These tests require gptq_marlin_repack, "
        "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm, "
        "or marlin_gemm which are not supported on ROCm.",
        allow_module_level=True,
    )

66
67
# Boolean sweeps combined by marlin_generate_valid_test_cases() and the
# repack tests.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base tile sizes; the tests multiply them by the factors below to obtain the
# actual size_k / size_n of each problem.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

# (n_factor, k_factor) multipliers used by the repack tests.
MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

# (m, n_factor, k_factor): m is used as size_m directly, the other two scale
# the N / K chunk sizes.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
91

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Quantization schemes swept by test_marlin_gemm. Each entry is expanded by
# marlin_generate_valid_test_cases(), which reads the keys as follows:
#   "b_type"            - weight scalar type (required).
#   "a_type"            - list of activation scalar types; when absent,
#                         [float16, bfloat16] is used.
#   "c_type"            - list of output scalar types; when absent,
#                         [float16, bfloat16] is used.
#   "group_blocks"      - list of group sizes in units of 16 elements
#                         (required); values <= 0 are passed through
#                         unscaled as the group size.
#   "support_act_order" - when absent or False, act_order=True combinations
#                         are dropped for this entry.
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
    # GPTQ-INT4
    {
        "b_type": scalar_types.uint4b8,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": scalar_types.uint8b128,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # FP8
    {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
    # NVFP4
    {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
    # MXFP4
    {
        "a_type": [scalar_types.bfloat16],
        "b_type": scalar_types.float4_e2m1f,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.float4_e2m1f,
        "c_type": [scalar_types.bfloat16],
        "group_blocks": [2],
    },
]

150

151
152
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name, this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``, returned as a
    0-dim tensor. The result is inf/nan when ``output_ref`` is all zeros.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
155
156


157
158
def rand_data(shape, dtype=torch.float16):
    """Draw a standard-normal tensor of ``shape`` and ``dtype`` on the CUDA device."""
    sample = torch.randn(shape, dtype=dtype, device="cuda")
    return sample
159
160


161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    """Check marlin_int4_fp8_preprocess (no zero points) against a torch model.

    The torch reference maps each 4-bit value ``v`` to ``v - 8`` when
    ``v >= 8`` and to ``15 - v`` otherwise, then packs it the same way as the
    kernel input.
    """

    def pack_nibble_pairs(values):
        # Even columns land in the high nibble and odd columns in the low
        # nibble of one byte; the int8 bytes are then reinterpreted as int32.
        merged = values[:, ::2] * 16 + values[:, 1::2]
        return merged.to(torch.int8).view(torch.int32)

    nibbles = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )

    kernel_out = ops.marlin_int4_fp8_preprocess(pack_nibble_pairs(nibbles))

    remapped = torch.where(nibbles >= 8, nibbles - 8, 15 - nibbles)
    expected = pack_nibble_pairs(remapped)

    assert (kernel_out == expected).all()


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    """Check marlin_int4_fp8_preprocess with zero points against a torch model.

    The torch reference subtracts the per-group zero point from each 4-bit
    weight; where that difference is negative it uses ``15 - weight``
    instead. The result is packed the same way as the kernel inputs.
    """
    group_size = 128

    def pack_nibble_pairs(values):
        # Even columns land in the high nibble and odd columns in the low
        # nibble of one byte; the int8 bytes are then reinterpreted as int32.
        merged = values[:, ::2] * 16 + values[:, 1::2]
        return merged.to(torch.int8).view(torch.int32)

    weights = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    zero_points = torch.randint(
        0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
    )

    kernel_out = ops.marlin_int4_fp8_preprocess(
        pack_nibble_pairs(weights), pack_nibble_pairs(zero_points)
    )

    # Expand each zero-point row so it lines up with its group_size weight rows.
    shifted = weights - zero_points.repeat_interleave(group_size, 0)
    expected_nibbles = torch.where(shifted < 0, 15 - weights, shifted)
    expected = pack_nibble_pairs(expected_nibbles)

    assert (kernel_out == expected).all()


213
214
215
216
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    """Compare the gptq_marlin_repack kernel with the Python Marlin packer.

    Random weights are GPTQ-quantized, packed to GPTQ format and repacked by
    the kernel; the result must match ``marlin_weights`` applied to the same
    quantized weights.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # Filter act_order combinations that are not exercised. The group_size
    # == -1 check cannot trigger while group_size is fixed at 128 above.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if is_a_8bit:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided); only the quantized weights
    # and the group index are needed here.
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
281
282


283
284
285
286
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    """Compare the awq_marlin_repack kernel with the Python Marlin packer.

    Random weights are quantized with zero points, packed to AWQ format and
    repacked by the kernel; the result must match ``marlin_weights`` applied
    to the same quantized weights.
    """
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    group_size = 128

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize; only the quantized weights are needed here.
    _, q_w, _, _ = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.awq_marlin_repack,
        (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
329
330


331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
def marlin_generate_valid_test_cases():
    """Build the argument tuples that parametrize ``test_marlin_gemm``.

    Every entry of DENSE_MARLIN_QUANT_TEST_CONFIGS is crossed with the shape
    factors, chunk sizes and boolean option sweeps; combinations rejected by
    the filter below are dropped.

    Returns:
        A list of tuples ``(a_type, b_type, c_type, group_blocks, size_m,
        size_n, size_k, act_order, is_k_full, use_atomic_add,
        use_fp32_reduce)``.
    """
    half_types = [scalar_types.float16, scalar_types.bfloat16]

    def _should_run(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        # True means the combination is kept as a test case.
        if use_atomic_add and use_fp32_reduce:
            return False
        if (
            use_atomic_add
            and c_type == scalar_types.bfloat16
            and torch.cuda.get_device_capability()[0] < 9
        ):
            return False

        # group_blocks counts 16-element blocks; non-positive values are
        # passed through unchanged.
        group_size = group_blocks * 16 if group_blocks > 0 else group_blocks
        if group_size > 0 and size_k % group_size:
            return False

        if group_size == size_k or (act_order and group_size == -1):
            return False
        if is_k_full and not act_order:
            return False

        # Sub-16-bit activations may pair with any output type; 16-bit
        # activations must be the same type as the output.
        return a_type.size_bits < 16 or a_type is c_type

    outer_grid = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    selected = []
    for combo in outer_grid:
        config, factors, n_chunk, k_chunk, act_order = combo[:5]
        shape = (factors[0], factors[1] * n_chunk, factors[2] * k_chunk)

        if act_order and not config.get("support_act_order", False):
            continue

        type_grid = itertools.product(
            config.get("a_type", half_types),
            [config["b_type"]],
            config.get("c_type", half_types),
            config["group_blocks"],
        )

        for variant in type_grid:
            # NOTE(review): this assumes get_device_capability() compares
            # equal to plain ints such as 89; if it returns a (major, minor)
            # style object the membership test is always False and every FP8
            # activation case is skipped - confirm against the platform API.
            if (
                variant[0] == scalar_types.float8_e4m3fn
                and current_platform.get_device_capability() not in [89, 120]
            ):
                continue
            candidate = variant + shape + combo[4:]
            if _should_run(*candidate):
                selected.append(candidate)
    return selected


408
409
410
411
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
def test_marlin_gemm(
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Check ``ops.marlin_gemm`` against ``torch.matmul`` on reference weights.

    Weights are quantized according to ``b_type``, activations are optionally
    quantized according to ``a_type``, and the kernel output must stay within
    a 0.04 mean relative error of the reference product.

    Raises:
        RuntimeError: if ``c_type`` is neither float16 nor bfloat16.
    """
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    # group_blocks counts 16-element blocks; non-positive values pass through.
    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # Quantize the weights with the helper matching b_type; inputs the chosen
    # scheme does not produce are passed to the kernel as None.
    if b_type == scalar_types.float4_e2m1f:
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    # Quantize the activations when a_type asks for it, and keep a
    # dequantized copy so the reference matmul sees the same rounding.
    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            # Normalize the weight scales by their maximum, scale to 4096 and
            # store them as int16 bit patterns viewed as ``dtype``; the
            # removed factor is folded into the activation scales.
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.marlin_gemm(
        a_input,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04


541
542
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    """Call ``ops.gptq_marlin_24_gemm`` from inside a ``torch.compile`` graph.

    All arguments are forwarded unchanged and the op's result is returned.
    """
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
565
566


567
568
569
570
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    """Check ``gptq_marlin_24_gemm`` against ``torch.matmul`` on reference weights.

    The op is first run through ``opcheck`` and then through the compiled
    ``marlin_24_gemm_tester`` wrapper; the output must stay within a 0.04
    mean relative error of the reference product.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    # The raw torch op takes the quant type's id, whereas the ops wrapper
    # used by marlin_24_gemm_tester takes the quant type object itself.
    opcheck(
        torch.ops._C.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
629
630


631
632
633
634
635
636
637
638
def test_marlin_gemm_subset_input():
    """Check ``ops.marlin_gemm`` when the activations are a sliced window.

    ``a_input`` is a ``(size_m, size_k)`` slice taken at offset 8 in both
    dimensions from a ``(2 * size_m, 2 * size_k)`` tensor. The output must
    stay within a 0.04 mean relative error of the reference product.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


679
680
681
682
683
684
685
686
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check ``ops.marlin_gemm`` with a bias vector.

    The bias is scaled by 10, permuted with ``marlin_permute_bias`` for the
    kernel, and added unpermuted to the reference product. The output must
    stay within a 0.04 mean relative error of that reference.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    b_bias = rand_data((size_n,)) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04