"vllm/entrypoints/pooling/typing.py" did not exist on "440f0e7dc6cb0adfc9c3c98076939668b90c4bf2"
test_marlin_gemm.py 12.7 KB
Newer Older
1
2
3
4
5
6
7
"""Tests for the marlin kernel.

Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
"""
import pytest
import torch

8
from tests.quantization.utils import is_quant_method_supported
9
from vllm import _custom_ops as ops
10
11
12
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
13
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
14
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
15
16
    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
    marlin_make_empty_g_idx, marlin_permute_scales)
17
18
19
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
20
21
    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
    marlin_weights)
22
23
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
24
from vllm.model_executor.layers.quantization.utils.quant_utils import (
25
26
    awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
    sort_weights)
27
28
29
30

# Values swept by the `act_order` and `is_k_full` test parameters.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

# Base tile sizes for the Marlin tests; each test multiplies these by the
# k/n entries of MNK_FACTORS to obtain the actual problem size.
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256]

# Base tile sizes for the Marlin 2:4 test.
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

# (m_factor, n_factor, k_factor) triples. Tests use size_m = m_factor,
# size_n = n_chunk * n_factor and size_k = k_chunk * k_factor.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]

# Activation/weight dtypes swept by the fp8 GEMM test.
DTYPES = [torch.float16, torch.bfloat16]
47

48

49
50
51
52
53
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name, this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``, returned as a
    0-dim tensor.
    """
    mean_abs_err = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_err / mean_abs_ref


54
55
def rand_data(shape, dtype=torch.float16):
    """Return a CUDA tensor of `shape` and `dtype` drawn from torch.randn."""
    sample = torch.randn(shape, device="cuda", dtype=dtype)
    return sample
56
57


58
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                            mnk_factors):
    """Check the GPTQ -> Marlin repack kernel against the reference packing.

    Random weights are quantized, packed to GPTQ format and repacked by
    `ops.gptq_marlin_repack`; the result must match what `marlin_weights`
    produces from the same quantized weights.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Filter act_order: combinations with group_size == -1 or a single
    # group covering all of K are not exercised.
    if act_order and group_size in (-1, size_k):
        return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided). Only the quantized weights
    # and the group index are needed for this test.
    _, q_w, _, g_idx, _ = quantize_weights(b_weight, num_bits, group_size,
                                           act_order)

    # Pack to GPTQ format (must use q_w before it is sorted below)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, _, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        num_bits,
    )
    torch.cuda.synchronize()

    # Packed words must match exactly. torch.allclose would apply a relative
    # tolerance to the integer values and could hide low-bit mismatches.
    assert torch.equal(marlin_q_w_1, marlin_q_w_2)


121
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
                           mnk_factors):
    """Check the AWQ -> Marlin repack kernel against the reference packing.

    Random weights are quantized with zero points, packed to AWQ format and
    repacked by `ops.awq_marlin_repack`; the result must match what
    `marlin_weights` produces from the same quantized weights.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize; only the quantized weights are needed for this test.
    _, q_w, _, _ = quantize_weights_with_zp(b_weight, num_bits, group_size)

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        num_bits,
    )
    torch.cuda.synchronize()

    # Packed words must match exactly. torch.allclose would apply a relative
    # tolerance to the integer values and could hide low-bit mismatches.
    assert torch.equal(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
):
    """Compare `ops.gptq_marlin_gemm` (no zero points) with a dense matmul.

    The reference is `a_input @ w_ref`, where `w_ref` is the reference
    weight returned by `marlin_quantize`. The relative error reported by
    `compute_max_diff` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # act_order combinations with group_size == -1 or a single group covering
    # all of K are not exercised.
    if act_order and group_size in (-1, size_k):
        return

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, num_bits, group_size, act_order)

    # Placeholder for the zero-point argument, since has_zp=False below. It is
    # built with the empty-g_idx helper (presumably an empty tensor on the
    # given device -- confirm against marlin_utils).
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full,
        has_zp=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print(f"max_diff = {max_diff}")

    assert max_diff < 0.04


238
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
                             mnk_factors):
    """Compare `ops.gptq_marlin_24_gemm` with a dense matmul reference.

    The reference is `a_input @ w_24_ref`, where `w_24_ref` is the reference
    weight returned by `marlin_24_quantize`. The relative error reported by
    `compute_max_diff` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)

    workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                   GPTQ_MARLIN_24_MAX_PARALLEL)

    output_ref = torch.matmul(a_input, w_24_ref)

    output = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        num_bits,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print(f"max_diff = {max_diff}")

    assert max_diff < 0.04
285
286


287
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    """Compare `ops.fp8_marlin_gemm` with a matmul on the original weights.

    Weights are quantized with `ops.scaled_fp8_quant`, packed to int32 and
    repacked to Marlin format. The reference is `a_input @ b_weight` using
    the unquantized weights, and the relative error reported by
    `compute_max_diff` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # WEIGHTS
    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device="cuda"),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
    # Permute scales
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=-1)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=a_input.shape[0],
        size_n=b_weight.shape[1],
        size_k=a_input.shape[1],
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print(f"max_diff = {max_diff}")

    assert max_diff < 0.04
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare `ops.gptq_marlin_gemm` with zero points to a dense matmul.

    Weights, scales and zero points come from `awq_marlin_quantize`. The
    reference is `a_input @ w_ref`, and the relative error reported by
    `compute_max_diff` must stay below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, num_bits, group_size)

    # This path passes empty group-index and sort-index tensors and always
    # runs with is_k_full=True.
    g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    is_k_full = True

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full,
        has_zp=True,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print(f"max_diff = {max_diff}")

    assert max_diff < 0.04