"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
    MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
    marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)


# Boolean switches swept by the parametrized tests below.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

# Base tile sizes along K and N. Each test builds its problem size as
#   size_m = m_factor, size_k = k_chunk * k_factor, size_n = n_chunk * n_factor
# using one (m_factor, n_factor, k_factor) triple from MNK_FACTORS.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256]

# Chunk sizes used by the 2:4-sparse Marlin (marlin_24) GEMM test.
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

# (m_factor, n_factor, k_factor) triples; see the comment above.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]

# Activation dtypes exercised by the FP8 Marlin test.
DTYPES = [torch.float16, torch.bfloat16]


def compute_max_diff(output, output_ref):
    """Return the relative mean absolute error of ``output`` vs ``output_ref``.

    Despite the name, this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``, returned as a
    0-dim tensor. The result is inf/nan if ``output_ref`` is all zeros.
    """
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def rand_data(shape, dtype=torch.float16, device="cuda"):
    """Return a tensor of standard-normal samples (``torch.randn``).

    Args:
        shape: Shape of the tensor to create.
        dtype: Element dtype; defaults to ``torch.float16``.
        device: Device to allocate on; defaults to ``"cuda"`` (the previous
            hard-coded value), so existing callers are unaffected.
    """
    return torch.randn(shape, dtype=dtype, device=device)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            act_order, mnk_factors):
    """Check the GPTQ->Marlin repack kernel against the Python reference.

    The same quantized weights are (a) packed to GPTQ format and repacked with
    ``ops.gptq_marlin_repack`` and (b) packed directly to Marlin format with
    ``marlin_weights``. The two packed tensors must be bit-identical.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Filter act_order: channelwise (-1) and single-group (group_size ==
    # size_k) configurations are not exercised with act_order. The test
    # returns early, so these combinations are reported as passed.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size: -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing. Without act_order an empty tensor is passed as the
    # permutation.
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reference: pack to Marlin format with the test-utility implementation.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    # Packed words are bit patterns, so compare exactly. torch.allclose
    # allows |a - b| <= atol + rtol * |b|, a tolerance that grows with each
    # (large) packed int32 word and could accept mismatched low-order fields.
    assert torch.equal(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                           mnk_factors):
    """Check the AWQ->Marlin repack kernel against the Python reference.

    Weights quantized with zero points are (a) packed to AWQ format and
    repacked with ``ops.awq_marlin_repack`` and (b) packed directly to Marlin
    format with ``marlin_weights``. The two packed tensors must be
    bit-identical.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Normalize group_size: -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize with zero points; only the quantized values are compared here.
    w_ref, q_w, s, zp = quantize_weights(b_weight,
                                         quant_type,
                                         group_size,
                                         zero_points=True)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Reference: pack to Marlin format with the test-utility implementation.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                  weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    # Packed words are bit patterns, so compare exactly. torch.allclose
    # allows |a - b| <= atol + rtol * |b|, a tolerance that grows with each
    # (large) packed int32 word and could accept mismatched low-order fields.
    assert torch.equal(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_fp32_reduce,
):
    """Check ``ops.gptq_marlin_gemm`` (no zero points) against torch.matmul.

    The reference is ``a_input @ w_ref``, using the reference weights returned
    by ``marlin_quantize``. The relative mean error from
    ``compute_max_diff`` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # act_order is skipped for channelwise (-1) and single-group
    # (group_size == size_k) configurations; such cases report as passed.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, act_order)

    # GPTQ has no zero points (has_zp=False below); an empty placeholder
    # tensor is passed in the zero-point slot.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
                             mnk_factors):
    """Check ``ops.gptq_marlin_24_gemm`` against torch.matmul.

    Weights are quantized and compressed with ``marlin_24_quantize``. The
    kernel output is compared with ``a_input @ w_24_ref``. The relative mean
    error from ``compute_max_diff`` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)

    output_ref = torch.matmul(a_input, w_24_ref)

    output = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    """Check ``ops.fp8_marlin_gemm`` against an unquantized torch.matmul.

    Weights are FP8-quantized with ``ops.scaled_fp8_quant``, packed to int32,
    and repacked to Marlin format. The per-tensor scale is expanded to
    channelwise scales. The result is compared with ``a_input @ b_weight``
    (the original, unquantized weights); the relative mean error must stay
    below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # WEIGHTS
    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
    # Repack weights to marlin format (FP8 elements are always 8 bits wide;
    # an empty perm means no act_order reordering).
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device="cuda"),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
    # Permute scales (channelwise, hence group_size=-1)
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=-1)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=a_input.shape[0],
        size_n=b_weight.shape[1],
        size_k=a_input.shape[1],
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    """Check ``ops.gptq_marlin_gemm`` with AWQ zero points against matmul.

    Weights come from ``awq_marlin_quantize``, which also returns zero
    points. The kernel runs with ``has_zp=True`` and no act_order (empty
    ``g_idx`` / ``sort_indices``). The result is compared with
    ``a_input @ w_ref``; the relative mean error must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, quant_type, group_size)

    # No act_order for AWQ: empty group-index and permutation tensors.
    g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    is_k_full = True
    has_zp = True

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        has_zp=has_zp,
        use_fp32_reduce=use_fp32_reduce,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Check ``ops.marlin_qqq_gemm`` (int8 activations) against matmul.

    Activations are quantized per row to int8 with a symmetric scale ``s_a``.
    Weights come from ``marlin_qqq_quantize``. The reference is the
    dequantized activations (in fp16) times ``w_ref``; the relative mean
    error must stay below 0.04.
    """
    i8 = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors
    size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Symmetric per-row int8 quantization: scale = row abs-max / 127.
    row_absmax = a_input.abs().max(dim=-1, keepdim=True).values
    s_a = row_absmax.div(i8.max).to(torch.float)
    q_a = torch.clamp((a_input / s_a).round(), i8.min, i8.max).to(torch.int8)

    # Quantize weights
    quantized = marlin_qqq_quantize(b_weight, num_bits, group_size)
    w_ref, q_w, s_group, s_channel = quantized

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    output = ops.marlin_qqq_gemm(
        q_a,
        q_w,
        s_a,
        s_channel,
        s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    a_dequant = q_a.half() * s_a.half()
    output_ref = torch.matmul(a_dequant, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04