# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

10
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
11
from tests.quantization.utils import is_quant_method_supported
12
from vllm import _custom_ops as ops
13
14
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
15
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
16
17
18
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
19
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
20
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
21
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
22
    marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
23
    query_marlin_supported_quant_types)
24
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
25
26
    FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like)
27
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
28
    marlin_quant_fp8_torch)
29
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
30
31
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
32
33
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
34
35
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
36
from vllm.model_executor.layers.quantization.utils.quant_utils import (
37
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
38
from vllm.scalar_type import scalar_types
39
40
41

# Boolean sweeps fed to ``pytest.mark.parametrize`` by the tests below
# (``act_order``, ``is_k_full``, ``use_atomic_add``, ``use_fp32_reduce``).
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

# Base chunk sizes: each test multiplies these by the k/n factors taken from
# MNK_FACTORS to obtain ``size_k`` and ``size_n``.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

# Chunk sizes used by the Marlin 2:4 GEMM test.
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

# Group sizes exercised by the HQQ GEMM test.
HQQ_SUPPORTED_GROUP_SIZES = [64]

# (m_factor, n_factor, k_factor) triples. ``m_factor`` is used directly as
# ``size_m``; the other two scale the n/k chunk sizes.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

# Activation/weight dtypes swept by ``test_gptq_marlin_gemm``.
DTYPES = [torch.float16, torch.bfloat16]
65

66

67
68
69
70
71
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name, this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.

    Args:
        output: Tensor produced by the kernel under test.
        output_ref: Reference tensor to compare against.

    Returns:
        A 0-dim tensor holding the normalized mean absolute difference.
    """
    mean_abs_err = torch.abs(output - output_ref).mean()
    mean_abs_ref = torch.abs(output_ref).mean()
    return mean_abs_err / mean_abs_ref


72
73
def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Create a tensor of standard-normal random values.

    Args:
        shape: Shape of the tensor to create.
        dtype: Element dtype; defaults to ``torch.float16``.
        device: Device to allocate on. Defaults to ``"cuda"`` so existing
            callers are unaffected; pass another device to reuse the helper
            where no GPU is available.

    Returns:
        A tensor of the requested shape, dtype and device filled by
        ``torch.randn``.
    """
    return torch.randn(shape, dtype=dtype, device=device)
74
75


76
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    """Compare ``ops.gptq_marlin_repack`` with the ``marlin_weights`` packing.

    Random weights are quantized with ``gptq_quantize_weights``, packed to
    GPTQ format, and repacked by the op; the result must match what
    ``marlin_weights`` produces from the same quantized weights.
    """
    # Only the n/k factors matter for a repack test; m is unused.
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter act_order: these group sizes are not exercised together with
    # act_order, so the case is a no-op.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size: -1 means a single group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    _, q_w, _, g_idx, _ = gptq_quantize_weights(b_weight, quant_type,
                                                group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.gptq_marlin_repack,
            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
139
140


141
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Compare ``ops.awq_marlin_repack`` with the ``marlin_weights`` packing.

    Random weights are quantized with zero points, packed to AWQ format, and
    repacked by the op; the result must match what ``marlin_weights``
    produces from the same quantized weights.
    """
    # Only the n/k factors matter for a repack test; m is unused.
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size: -1 means a single group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize
    _, q_w, _, _ = quantize_weights(b_weight,
                                    quant_type,
                                    group_size,
                                    zero_points=True)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.awq_marlin_repack,
            (q_w_awq, size_k, size_n, quant_type.size_bits))

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
191
192
193
194
195
196


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
    "group_size",
    set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
                          mnk_factors, act_order, is_k_full, use_atomic_add,
                          use_fp32_reduce, dtype):
    """Check ``ops.gptq_marlin_gemm`` against ``torch.matmul``.

    The weight is quantized by a helper chosen from ``quant_type`` (FP4,
    FP8, zero-point integer, or the default ``marlin_quantize`` path). The
    kernel output is compared with ``a_input @ w_ref`` using
    ``compute_max_diff`` and must stay below 0.04.

    Parameter combinations that a branch does not handle return early, so
    those cases pass without running the kernel.
    """
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # act_order is not exercised with a single group or with zero points.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    # K must be a whole number of groups (always true for group_size == -1).
    if size_k % group_size != 0:
        return

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            return
        if group_size == 32 and dtype == torch.float16:
            return

        # Group size 16 uses the nvfp4-like helper, which also returns a
        # second scale tensor; group size 32 uses the mxfp4-like helper.
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = \
                rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
        else:
            w_ref, marlin_q_w, marlin_s = \
                rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            return
        if act_order:
            return
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size)
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order)
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    # The raw op receives ``quant_type.id``; the ``ops`` wrapper below is
    # given the scalar type object itself.
    opcheck(torch.ops._C.gptq_marlin_gemm,
            (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
             g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
             use_fp32_reduce, False),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


313
314
315
316
317
318
319
320
321
322
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Invoke ``ops.gptq_marlin_24_gemm`` from inside a compiled function.

    All arguments are forwarded positionally and unchanged; the op's result
    is returned as-is. Wrapping the call in ``torch.compile(fullgraph=True)``
    makes the test run the op through the compiler.
    """
    result = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
    return result


323
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    """Check the Marlin 2:4 GEMM op against ``torch.matmul``.

    Weights are quantized with ``marlin_24_quantize``. The op is run through
    ``marlin_24_gemm_tester`` and compared with ``a_input @ w_24_ref`` using
    ``compute_max_diff``, which must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)

    output_ref = torch.matmul(a_input, w_24_ref)

    # The raw op receives ``quant_type.id``; the tester below is given the
    # scalar type object itself.
    opcheck(torch.ops._C.gptq_marlin_24_gemm,
            (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
             workspace_24.scratch, quant_type.id, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
372
373


374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check ``ops.gptq_marlin_gemm`` with float zero points.

    Random uint8 weights, scales and zero points are generated directly,
    repacked for Marlin, and run with ``is_zp_float=True``. The reference is
    ``a_input @ ((w - zero) * scale).T`` with per-group scale and zero;
    ``compute_max_diff`` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Weights are laid out (size_n, size_k) with values in [0, 10).
    b_weight = torch.randint(0,
                             10, (size_n, size_k),
                             dtype=torch.uint8,
                             device=dev)
    # One scale and one zero point per (output row, group).
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
                                     group_size).to(dev)
    # Zero points go through the same permutation helper as the scales.
    marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
                                      group_size).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    # Reference dequantization: one row per group so that the per-group
    # zero point and scale broadcast across the group's elements.
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input,
                              dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Check ``ops.marlin_qqq_gemm`` against ``torch.matmul``.

    Activations are quantized to int8 with a per-row scale and weights with
    ``marlin_qqq_quantize``. The reference multiplies the dequantized
    activations by ``w_ref``; ``compute_max_diff`` must stay below 0.04.
    """
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: per-row scale = max|a| / int8 max, then round
    # and clamp into the int8 range.
    s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
        torch.float)
    q_a = (a_input / s_a).round().clamp(int8_traits.min,
                                        int8_traits.max).to(torch.int8)

    # Quantize weights
    (w_ref, marlin_qqq_q_w, marlin_qqq_s_group,
     marlin_qqq_s_channel) = marlin_qqq_quantize(b_weight, num_bits,
                                                 group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    opcheck(torch.ops._C.marlin_qqq_gemm,
            (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
             marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
             b_weight.shape[1], a_input.shape[1]))

    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    # Reference uses the dequantized activations, so activation quantization
    # error is excluded from the comparison.
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
516
517


518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
def test_marlin_gemm_subset_input():
    """Check ``ops.gptq_marlin_gemm`` when the activation is a tensor slice.

    ``a_input`` is a ``(size_m, size_k)`` window cut out of a tensor twice as
    large in both dimensions, offset by 8 rows and 8 columns. The kernel
    output is compared with ``a_input @ w_ref``; ``compute_max_diff`` must
    stay below 0.04.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    # Slice of a larger allocation rather than a freshly created tensor.
    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    # No zero points for this quant type: reuse the empty-tensor helper.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check ``ops.gptq_marlin_gemm`` when a bias tensor is supplied.

    The bias is permuted with ``marlin_permute_bias`` and passed as the
    fourth positional argument. The reference is ``a @ w_ref + bias``, and
    ``compute_max_diff`` must stay below 0.04.
    """
    size_k, size_n = 1024, 2048
    group_size = 128
    quant_type = scalar_types.uint4b8

    # Inputs are drawn in the same order as before: activations, weights,
    # then bias (scaled by 10).
    activations = rand_data((size_m, size_k))
    weights = rand_data((size_k, size_n))
    bias = rand_data((size_n, )) * 10

    permuted_bias = marlin_permute_bias(bias)

    quantized = marlin_quantize(weights, quant_type, group_size, False)
    w_ref, packed_w, scales, g_idx, sort_indices, _ = quantized

    empty_zp = marlin_make_empty_g_idx(scales.device)
    workspace = marlin_make_workspace_new(activations.device)

    rows, depth = activations.shape[0], activations.shape[1]
    cols = weights.shape[1]

    output = ops.gptq_marlin_gemm(
        activations,
        None,
        packed_w,
        permuted_bias,
        scales,
        None,
        empty_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        rows,
        cols,
        depth,
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    expected = torch.matmul(activations, w_ref) + bias.view(1, -1)

    torch.cuda.synchronize()

    assert compute_max_diff(output, expected) < 0.04


611
612
613
614
615
616
617
618
619
620
621
622
623
def test_marlin_gemm_opcheck():
    """Run ``torch.ops._C.marlin_gemm`` twice and through ``opcheck``.

    The two direct calls use identical inputs and must produce matching
    outputs; ``opcheck`` then validates the op registration.
    """
    size_m = 2048
    size_n = 4096
    size_k = 4096
    # Activations are (size_m, size_k), matching every other GEMM test in
    # this file. The previous (size_m, size_n) shape only worked because
    # size_n == size_k here.
    a = torch.rand((size_m, size_k), device='cuda', dtype=torch.float16)
    # NOTE(review): packed weight shape is kept as a literal — presumably
    # (size_k // 16, size_n * 16 // 8) for 4-bit weights; confirm.
    w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
    # Scales are indexed (group, size_n), as in the HQQ test's permuted
    # scales; 32 rows presumably corresponds to size_k / 128 — TODO confirm.
    s = torch.full((32, size_n), 0.125, device='cuda', dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))