benchmark_marlin.py 14.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
9
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
10
11
12
13
14
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
15
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
16
17
18
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_SUPPORTED_QUANT_TYPES,
)
19
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
20
21
22
23
24
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    query_marlin_supported_quant_types,
)
25
26
27
28
29
30
31
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_fp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
32
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
33
    MarlinWorkspace,
34
    awq_marlin_quantize,
35
36
    marlin_quantize,
)
37
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
38
39
    marlin_24_quantize,
)
40
from vllm.model_executor.layers.quantization.utils.quant_utils import (
41
42
43
44
45
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
46
from vllm.scalar_type import ScalarType, scalar_types
47
from vllm.utils import FlexibleArgumentParser
48
49

# Model keys (looked up in WEIGHT_SHAPES) benchmarked when --models is omitted.
DEFAULT_MODELS: list[str] = ["meta-llama/Llama-2-7b-hf/TP1"]

# Values used as size_m (the M dimension of the matmul) when --batch-sizes
# is omitted.
DEFAULT_BATCH_SIZES: list[int] = [
    1,
    16,
    32,
    64,
    128,
    256,
    512,
    1024,
    2048,
    4096,
    8192,
]

# Both settings of each boolean switch are swept by main().
ACT_ORDER_OPTS: list[bool] = [False, True]
K_FULL_OPTS: list[bool] = [False, True]


56
57
58
59
60
61
62
63
64
65
66
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
) -> None:
    """Benchmark the quantized matmul kernels for a single configuration.

    Random half-precision operands of shape (size_m, size_k) and
    (size_k, size_n) are created on the CUDA device, the weight is quantized
    for every kernel that supports the configuration, and one
    ``benchmark.Measurement`` per timed kernel is appended to ``results``.

    The configuration is silently skipped (nothing is appended) when:
      * ``act_order`` is set together with a per-channel group size
        (``-1`` or ``size_k``) or a zero-point quant type,
      * ``size_k`` is not a multiple of ``group_size``,
      * the Marlin weight generator rejects the
        quant_type / group_size / act_order combination.

    Args:
        results: Output list; measurements are appended in place.
        model: Model key, used only for the benchmark sub-label.
        act_order: Whether activation reordering is requested.
        is_k_full: Forwarded to ``gptq_marlin_gemm``.
        quant_type: Weight quantization type.
        group_size: Quantization group size (``-1`` for per-channel).
        size_m: M dimension of the matmul.
        size_k: K dimension of the matmul.
        size_n: N dimension of the matmul.
    """
    label = "Quant Matmul"
    sub_label = (
        f"{model}, act={act_order} k_full={is_k_full}, q={quant_type!s}, "
        f"g={group_size}, MKN=({size_m}x{size_k}x{size_n})"
    )
    print(f"Testing: {sub_label}")

    # Reject unsupported combinations before allocating any GPU memory.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    if act_order and (group_size == -1 or group_size == size_k or has_zp):
        return
    if size_k % group_size != 0:
        return

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    marlin_24_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
    )
    # NOTE(review): this checks the Marlin-24 quant-type list rather than the
    # regular Marlin one; kept as in the original -- confirm it is intended.
    repack_supported = (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in MARLIN_SUPPORTED_GROUP_SIZES
    )
    allspark_supported = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )

    def gen_marlin_params():
        """Quantize ``b`` for gptq_marlin_gemm.

        Returns a 7-tuple (w_ref, q_w, s, s2, zp, g_idx, sort_indices), or
        None when the configuration is rejected by the checks below.
        """
        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
        if quant_type == scalar_types.float4_e2m1f:
            if group_size != 16 or act_order:
                return None
            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
                b.T, group_size
            )
        elif quant_type == scalar_types.float8_e4m3fn:
            if group_size not in [-1, 128] or act_order:
                return None
            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
        elif group_size == 16:
            # Group size 16 is only accepted by the FP4 branch above.
            return None
        elif has_zp:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
                b, quant_type, group_size
            )
        else:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
                marlin_quantize(b, quant_type, group_size, act_order)
            )
        return (
            marlin_w_ref,
            marlin_q_w,
            marlin_s,
            marlin_s2,
            marlin_zp,
            marlin_g_idx,
            marlin_sort_indices,
        )

    def gen_marlin_24_params():
        """Quantize ``b`` for gptq_marlin_24_gemm (all None if unsupported)."""
        marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
        if marlin_24_supported:
            (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
                marlin_24_quantize(b, quant_type, group_size)
            )
        return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)

    def gen_repack_params():
        """Build the inputs of gptq_marlin_repack (both None if unsupported)."""
        q_w_gptq = None
        repack_sort_indices = None
        if repack_supported:
            _, q_w, _, g_idx, _ = gptq_quantize_weights(
                b, quant_type, group_size, act_order
            )
            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

            # For act_order, sort the "weights" and "g_idx"
            # so that group ids are increasing
            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
            if act_order:
                _, _, repack_sort_indices = sort_weights(q_w, g_idx)
        return q_w_gptq, repack_sort_indices

    def gen_allspark_params():
        """Build the inputs of allspark_w8a16_gemm (all None if unsupported).

        Also clears ``allspark_supported`` in the enclosing scope when the
        device's compute capability is outside [80, 90).
        """
        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
            CUBLAS_M_THRESHOLD
        ) = None
        nonlocal allspark_supported
        if allspark_supported:
            properties = torch.cuda.get_device_properties(b.device.index)
            sm_count = properties.multi_processor_count
            sm_version = properties.major * 10 + properties.minor

            supported_arch = sm_version >= 80 and sm_version < 90
            allspark_supported = allspark_supported and supported_arch
            if supported_arch:
                _, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
                qw = qw.to(torch.uint8)

                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                    qw, s, zp, has_zp
                )
                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
        return (
            qw_reorder,
            s_reorder,
            zp_reorder,
            sm_count,
            sm_version,
            CUBLAS_M_THRESHOLD,
        )

    marlin_params = gen_marlin_params()
    if marlin_params is None:
        # Unsupported quant_type/group_size/act_order combination: skip it
        # instead of failing while unpacking None.
        return
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        marlin_g_idx,
        marlin_sort_indices,
    ) = marlin_params
    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
        gen_marlin_24_params()
    )
    q_w_gptq, repack_sort_indices = gen_repack_params()
    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
        gen_allspark_params()
    )

    # Prepare
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )
    marlin_24_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    # Namespace in which the timed statements below are evaluated; the keys
    # must match the identifiers used inside the ``stmt`` strings.
    timer_globals = {
        # Gen params
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_s2": marlin_s2,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params
        "qw_reorder": qw_reorder,
        "s_reorder": s_reorder,
        "zp_reorder": zp_reorder,
        "sm_count": sm_count,
        "sm_version": sm_version,
        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    def record(description: str, stmt: str) -> None:
        """Time ``stmt`` and append the measurement to ``results``."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    record("pytorch_gemm", "torch.matmul(a, marlin_w_ref)")

    record(
        "gptq_marlin_gemm",
        "output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
    )

    record(
        "gptq_marlin_gemm_fp32",
        "output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
    )

    if marlin_24_supported:
        record(
            "gptq_marlin_24_gemm",
            "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
        )

    if repack_supported:
        record(
            "gptq_marlin_repack",
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
        )

    # May have been cleared by gen_allspark_params() on unsupported devices.
    if allspark_supported:
        record(
            "allspark_w8a16_gemm_fp32",
            "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
        )
316

317
318
319
320
321

def main(args):
    """Sweep every selected configuration and print a comparison table.

    For each model in ``args.models`` the (K, N) shapes from WEIGHT_SHAPES are
    combined with every act_order / is_k_full / quant type / group size /
    batch size; each combination is handed to ``bench_run``. A non-empty
    ``args.limit_*`` list restricts the matching dimension to the listed
    values. The collected measurements are printed via ``benchmark.Compare``.
    """
    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}]  {name}")

    def selected(value, limits):
        # An empty limit list means "no filtering" for that dimension.
        return len(limits) == 0 or value in limits

    results: list[benchmark.Measurement] = []
    group_sizes = MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k, size_n = layer[0], layer[1]
            if not selected(size_k, args.limit_k):
                continue
            if not selected(size_n, args.limit_n):
                continue

            for act_order in ACT_ORDER_OPTS:
                if not selected(act_order, args.limit_act_order):
                    continue

                for is_k_full in K_FULL_OPTS:
                    if not selected(is_k_full, args.limit_k_full):
                        continue

                    for quant_type in query_marlin_supported_quant_types():
                        if not selected(quant_type.size_bits, args.limit_num_bits):
                            continue

                        for group_size in group_sizes:
                            if not selected(group_size, args.limit_group_size):
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and group_size in (size_k, -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(
                                    results,
                                    model,
                                    act_order,
                                    is_k_full,
                                    quant_type,
                                    group_size,
                                    size_m,
                                    size_k,
                                    size_n,
                                )

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa: E501
#
if __name__ == "__main__":
    # Command-line entry point: build the parser, then run the sweep.
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )

    # Every --limit-* flag takes one or more integers; leaving a flag out
    # yields an empty list.
    for limit_flag in (
        "--limit-k",
        "--limit-n",
        "--limit-group-size",
        "--limit-num-bits",
        "--limit-act-order",
        "--limit-k-full",
    ):
        parser.add_argument(limit_flag, nargs="+", type=int, default=[])

    main(parser.parse_args())