# SPDX-License-Identifier: Apache-2.0
"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from tests.quantization.utils import is_quant_method_supported
11
from vllm import _custom_ops as ops
12
13
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
14
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
15
16
17
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
18
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
19
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
20
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
21
22
    marlin_make_workspace_new, marlin_permute_scales,
    query_marlin_supported_quant_types)
23
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
24
    marlin_quant_fp8_torch)
25
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
26
27
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
28
29
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
30
31
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
32
from vllm.model_executor.layers.quantization.utils.quant_utils import (
33
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
34
from vllm.scalar_type import scalar_types
35
36
37

# Boolean switches swept by the parametrized tests below.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

# Base K / N sizes for the standard Marlin tests; each test multiplies these
# by the k_factor / n_factor entries of MNK_FACTORS.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

# Base K / N sizes for the 2:4-sparse Marlin test.
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

# Group sizes exercised by the HQQ test.
HQQ_SUPPORTED_GROUP_SIZES = [64]

# (m_factor, n_factor, k_factor) triples. The tests use
# size_m = m_factor, size_n = n_chunk * n_factor, size_k = k_chunk * k_factor.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

# Candidate activation dtypes.
DTYPES = [torch.float16, torch.bfloat16]

def compute_max_diff(output, output_ref):
    """Return the error of ``output`` relative to ``output_ref``.

    Despite the name, this is not a maximum. It is the mean absolute
    difference normalized by the mean magnitude of the reference:
    ``mean(|output - output_ref|) / mean(|output_ref|)``, as a 0-dim tensor.
    """
    abs_err = (output - output_ref).abs().mean()
    ref_scale = output_ref.abs().mean()
    return abs_err / ref_scale


def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples.

    Args:
        shape: Shape of the tensor to create.
        dtype: Element dtype; float16 by default.
        device: Device to allocate on. Defaults to ``"cuda"``, the device the
            kernels in this file are run on; pass e.g. ``"cpu"`` to build
            host-side data.
    """
    return torch.randn(shape, dtype=dtype, device=device)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    """Check ``ops.gptq_marlin_repack`` against the host-side Marlin packer.

    A random weight is GPTQ-quantized and packed in GPTQ format. It is then
    converted to Marlin layout twice: on the host via ``marlin_weights``, and
    by the GPU repack kernel. The two results must match exactly.
    """
    _, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # act_order is only exercised when there is more than one group along K.
    # Report filtered combinations as skipped instead of silently passing.
    if act_order and group_size in (-1, size_k):
        pytest.skip("act_order is only tested with group_size < size_k")

    # Normalize group_size: -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if requested).
    _, q_w, _, g_idx, _ = gptq_quantize_weights(b_weight, quant_type,
                                                group_size, act_order)

    # Pack to GPTQ format; this is the input of the repack kernel.
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the weights and g_idx so that group ids are
    # increasing. The same sort_indices are handed to the kernel.
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Host-side reference Marlin layout.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    opcheck(torch.ops._C.gptq_marlin_repack,
            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))

    # GPU repack kernel.
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Check ``ops.awq_marlin_repack`` against the host-side Marlin packer.

    The weight is quantized with zero points and packed in AWQ layout for the
    kernel. The kernel's Marlin-layout output must equal the layout built on
    the host by ``marlin_weights``.
    """
    _, n_factor, k_factor = mnk_factors
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    num_bits = quant_type.size_bits

    # -1 stands for a single group covering the whole K dimension.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    weight = rand_data((size_k, size_n))
    _, quantized, _, _ = quantize_weights(weight,
                                          quant_type,
                                          group_size,
                                          zero_points=True)

    awq_packed = awq_pack(quantized, num_bits, size_k, size_n)

    host_layout = marlin_weights(quantized, size_k, size_n, num_bits,
                                 get_weight_perm(num_bits))

    opcheck(torch.ops._C.awq_marlin_repack,
            (awq_packed, size_k, size_n, num_bits))

    kernel_layout = ops.awq_marlin_repack(awq_packed, size_k, size_n,
                                          num_bits)
    torch.cuda.synchronize()

    torch.testing.assert_close(host_layout, kernel_layout)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` against a dense matmul reference.

    The weight is quantized with ``marlin_quantize``, or with
    ``marlin_quant_fp8_torch`` for float8_e4m3fn. The kernel output must be
    within a relative mean error of 0.04 of ``a_input @ w_ref``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter parameter combinations before any GPU allocation, and report
    # them as skipped rather than silently passing.
    if act_order and group_size in (-1, size_k):
        pytest.skip("act_order is only tested with group_size < size_k")
    if size_k % group_size != 0:
        pytest.skip("size_k is not a multiple of group_size")

    is_fp8 = quant_type == scalar_types.float8_e4m3fn
    if is_fp8 and group_size not in (-1, 128):
        pytest.skip("fp8 is only tested with group_size -1 or 128")
    if is_fp8 and act_order:
        pytest.skip("fp8 is not tested with act_order")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    if is_fp8:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size)
        g_idx = None
        sort_indices = None
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order)

    # This test uses no zero points; an empty placeholder tensor is passed.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(
        torch.ops._C.gptq_marlin_gemm,
        (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
         workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
         a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


# TODO: find better way to test this?
# Wrapping the call in torch.compile(fullgraph=True) makes the caller fail if
# the op cannot be captured in a single graph.
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Forward all arguments, in order, to ``ops.gptq_marlin_24_gemm``."""
    gemm_args = (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
                 scratch, quant_type, size_m, size_n, size_k)
    return ops.gptq_marlin_24_gemm(*gemm_args)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    """Compare the 2:4-sparse Marlin GEMM with a dense matmul reference.

    The kernel is invoked through the torch.compile'd
    ``marlin_24_gemm_tester`` wrapper and must be within a relative mean
    error of 0.04 of ``a @ w_24_ref``.
    """
    m_factor, n_factor, k_factor = mnk_factors
    m, k, n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    activations = rand_data((m, k))
    weight = rand_data((k, n))

    w_24_ref, q_w_comp, meta, scales = marlin_24_quantize(
        weight, quant_type, group_size)

    workspace = MarlinWorkspace(n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                GPTQ_MARLIN_24_MAX_PARALLEL)

    expected = torch.matmul(activations, w_24_ref)

    opcheck(torch.ops._C.gptq_marlin_24_gemm,
            (activations, q_w_comp, meta, scales, workspace.scratch,
             quant_type.id, activations.shape[0], weight.shape[1],
             activations.shape[1]),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

    result = marlin_24_gemm_tester(activations, q_w_comp, meta, scales,
                                   workspace.scratch, quant_type,
                                   activations.shape[0], weight.shape[1],
                                   activations.shape[1])

    torch.cuda.synchronize()

    assert compute_max_diff(result, expected) < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Compare ``ops.gptq_marlin_gemm`` on AWQ (zero-point) weights to matmul.

    The kernel output must be within a relative mean error of 0.04 of
    ``a_input @ w_ref``.
    """
    m_factor, n_factor, k_factor = mnk_factors
    size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, quant_type, group_size)

    # No activation reordering on this path: g_idx and sort_indices are empty.
    idx_opts = dict(dtype=torch.int, device=marlin_q_w.device)
    g_idx = torch.empty(0, **idx_opts)
    sort_indices = torch.empty(0, **idx_opts)

    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(a_input,
                                  None,
                                  marlin_q_w,
                                  marlin_s,
                                  marlin_zp,
                                  g_idx,
                                  sort_indices,
                                  workspace,
                                  quant_type,
                                  a_input.shape[0],
                                  b_weight.shape[1],
                                  a_input.shape[1],
                                  is_k_full=True,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check ``ops.gptq_marlin_gemm`` with float (HQQ-style) zero points.

    Random uint4 codes get per-group float scales and zero points. They are
    repacked into Marlin layout and run through the kernel with
    ``is_zp_float=True``. The result is compared against a matmul with the
    dequantized weight ``(q - zero) * scale``.
    """
    m_factor, n_factor, k_factor = mnk_factors
    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    num_groups = size_k // group_size

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    # Integer codes laid out as (size_n, size_k), with values in [0, 10).
    q_codes = torch.randint(0,
                            10, (size_n, size_k),
                            dtype=torch.uint8,
                            device=dev)
    scale = rand_data((size_n, num_groups))
    zero = rand_data((size_n, num_groups))

    gptq_w_q = gptq_pack(q_codes.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
                                     group_size).to(dev)
    marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
                                      group_size).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(q_codes.device)

    output = ops.gptq_marlin_gemm(a_input,
                                  None,
                                  marlin_w_q,
                                  marlin_s,
                                  marlin_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  quant_type,
                                  a_input.shape[0],
                                  q_codes.shape[0],
                                  a_input.shape[1],
                                  is_k_full=True,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=True)

    # Reference dequantization. Each row of q_codes splits into num_groups
    # consecutive chunks of group_size values, and each chunk gets its own
    # scale and zero.
    grouped = q_codes.reshape(-1, group_size)
    dequant = (grouped - zero.reshape(-1, 1)) * scale.reshape(-1, 1)
    w_ref = dequant.reshape(q_codes.shape).transpose(1, 0)

    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare the QQQ Marlin GEMM (int8 activations) against a matmul.

    Each activation row is quantized to int8 with a symmetric per-row scale.
    The reference multiplies the dequantized activations by ``w_ref``.
    """
    i8 = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Per-row symmetric int8 quantization of the activations.
    row_absmax = a_input.abs().max(dim=-1, keepdim=True).values
    s_a = row_absmax.div(i8.max).to(torch.float)
    q_a = torch.clamp(torch.round(a_input / s_a), i8.min,
                      i8.max).to(torch.int8)

    w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = (
        marlin_qqq_quantize(b_weight, num_bits, group_size))

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    gemm_args = (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
                 marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
                 b_weight.shape[1], a_input.shape[1])

    opcheck(torch.ops._C.marlin_qqq_gemm, gemm_args)

    output = ops.marlin_qqq_gemm(*gemm_args)
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)

    torch.cuda.synchronize()

    assert compute_max_diff(output, output_ref) < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
def test_marlin_gemm_subset_input():
    """Run ``ops.gptq_marlin_gemm`` on a non-contiguous slice of activations.

    ``a_input`` is a (size_m, size_k) window at offset (8, 8) of a
    (big_m, big_k) tensor, so consecutive rows are ``big_k`` elements apart
    rather than ``size_k``. The kernel output must still be within a
    relative mean error of 0.04 of a dense matmul reference.

    The skipif guard matches the other Marlin GEMM tests, so unsupported
    GPUs skip this test instead of failing it.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False)

    # No zero points for this quant type; an empty placeholder is passed.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
def test_marlin_gemm_opcheck():
    """Check that ``marlin_gemm`` is deterministic and passes ``opcheck``.

    Random packed 4-bit weights and constant group scales are used. Only
    run-to-run equality and ``opcheck`` are verified, not numerical accuracy.

    Shapes are derived from size_m / size_n / size_k, so they stay
    consistent if size_n != size_k. Activations are (size_m, size_k) and
    scales are (size_k // group_size, size_n). For the sizes used here they
    are identical to the previous hard-coded values.
    """
    size_m = 2048
    size_n = 4096
    size_k = 4096
    group_size = 128
    # Packed-weight geometry, assumed to be the 4-bit Marlin layout:
    # (size_k // 16) rows of 16-wide tiles, 8 values per int32.
    # TODO(review): confirm against the kernel's expected weight shape.
    tile = 16
    pack_factor = 32 // 4

    a = torch.rand((size_m, size_k), device='cuda', dtype=torch.float16)
    w = torch.randint(-5,
                      5, (size_k // tile, size_n * tile // pack_factor),
                      device='cuda',
                      dtype=torch.int32)
    s = torch.full((size_k // group_size, size_n),
                   0.125,
                   device='cuda',
                   dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))