# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
7
8
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
9
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
10
11
12
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_SUPPORTED_QUANT_TYPES,
)
13
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
14
15
16
17
18
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    query_marlin_supported_quant_types,
)
19
20
21
22
23
24
25
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_fp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
26
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
27
    MarlinWorkspace,
28
    awq_marlin_quantize,
29
30
    marlin_quantize,
)
31
from vllm.model_executor.layers.quantization.utils.quant_utils import (
32
33
34
35
36
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
37
from vllm.scalar_type import ScalarType, scalar_types
38
from vllm.utils.argparse_utils import FlexibleArgumentParser
39
40

# Model key (looked up in WEIGHT_SHAPES) benchmarked when --models is not given.
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]

# Batch sizes (the M dimension) swept by default: 1, then every power of two
# from 16 up to 8192.
DEFAULT_BATCH_SIZES = [1] + [2**exponent for exponent in range(4, 14)]

# Both values of each boolean knob are swept unless restricted on the
# command line.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


47
48
49
50
51
52
53
54
55
56
57
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
) -> None:
    """Time the quantized matmul kernels for a single configuration.

    Measurements are appended to ``results`` for ``torch.matmul`` and
    ``marlin_gemm`` (two flag variants), plus ``gptq_marlin_repack`` and
    ``allspark_w8a16_gemm`` when the configuration supports them.

    The configuration is skipped silently (nothing is appended) when:

    * ``act_order`` is set and the group size is ``-1`` or ``size_k``, or the
      quant type is ``uint4``/``uint8``;
    * ``size_k`` is not a multiple of ``group_size``;
    * the Marlin parameter generator reports the combination as unsupported
      (see ``gen_marlin_params``).

    Args:
        results: List that the new measurements are appended to.
        model: Model name; only used in the printed/recorded sub-label.
        act_order: Whether to quantize with activation reordering.
        is_k_full: Forwarded to ``marlin_gemm``; also gates the AllSpark path.
        quant_type: Weight quantization type.
        group_size: Quantization group size (``-1`` is accepted).
        size_m: Rows of the activation matrix.
        size_k: Shared (reduction) dimension.
        size_n: Columns of the weight matrix.
    """
    label = "Quant Matmul"
    sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
        model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
    )
    print(f"Testing: {sub_label}")

    # Cheap validity checks come first so skipped configurations never
    # allocate GPU memory.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    if act_order and (group_size == -1 or group_size == size_k or has_zp):
        return
    # For group_size == -1 the remainder is always 0, so this never skips it.
    if size_k % group_size != 0:
        return

    # Random fp16 operands on the GPU: activations (M x K) and weights (K x N).
    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    repack_supported = group_size in MARLIN_SUPPORTED_GROUP_SIZES
    allspark_supported = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )

    def gen_marlin_params():
        """Quantize ``b`` for Marlin.

        Returns ``None`` when this (quant_type, group_size, act_order)
        combination is not benchmarked, otherwise the 7-tuple
        ``(w_ref, q_w, s, s2, zp, g_idx, sort_indices)``; entries that do not
        apply to the quant type are ``None``.
        """
        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
        if quant_type == scalar_types.float4_e2m1f:
            if group_size != 16 or act_order:
                return None
            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
                b.T, group_size
            )
        elif quant_type == scalar_types.float8_e4m3fn:
            if group_size not in [-1, 128] or act_order:
                return None
            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
        elif group_size == 16:
            # Group size 16 is only benchmarked for the fp4 type above.
            return None
        elif has_zp:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
                b, quant_type, group_size
            )
        else:
            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
                marlin_quantize(b, quant_type, group_size, act_order)
            )
        return (
            marlin_w_ref,
            marlin_q_w,
            marlin_s,
            marlin_s2,
            marlin_zp,
            marlin_g_idx,
            marlin_sort_indices,
        )

    def gen_repack_params():
        """Build the inputs for ``gptq_marlin_repack``; ``(None, None)`` if unsupported."""
        q_w_gptq = None
        repack_sort_indices = None
        if repack_supported:
            # NOTE(review): repack_supported only checks group_size, not
            # quant_type -- confirm gptq_quantize_weights accepts every type
            # returned by query_marlin_supported_quant_types().
            _, q_w, _, g_idx, _ = gptq_quantize_weights(
                b, quant_type, group_size, act_order
            )
            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

            # For act_order, sort the "weights" and "g_idx"
            # so that group ids are increasing
            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
            if act_order:
                _, _, repack_sort_indices = sort_weights(q_w, g_idx)
        return q_w_gptq, repack_sort_indices

    def gen_allspark_params():
        """Build the inputs for ``allspark_w8a16_gemm``.

        Clears ``allspark_supported`` when the device's compute capability is
        outside [80, 90); every returned entry that was not computed is
        ``None``.
        """
        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
            CUBLAS_M_THRESHOLD
        ) = None
        nonlocal allspark_supported
        if allspark_supported:
            properties = torch.cuda.get_device_properties(b.device.index)
            sm_count = properties.multi_processor_count
            sm_version = properties.major * 10 + properties.minor

            supported_arch = sm_version >= 80 and sm_version < 90
            allspark_supported = allspark_supported and supported_arch
            if supported_arch:
                _, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
                qw = qw.to(torch.uint8)

                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                    qw, s, zp, has_zp
                )
                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
        return (
            qw_reorder,
            s_reorder,
            zp_reorder,
            sm_count,
            sm_version,
            CUBLAS_M_THRESHOLD,
        )

    marlin_params = gen_marlin_params()
    if marlin_params is None:
        # Unsupported combination: skip it instead of failing on the tuple
        # unpack below, and before doing the repack/AllSpark preparation.
        return
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        marlin_g_idx,
        marlin_sort_indices,
    ) = marlin_params
    q_w_gptq, repack_sort_indices = gen_repack_params()
    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
        gen_allspark_params()
    )

    # Prepare
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    # Namespace in which the timed statements are evaluated. Named
    # ``timer_globals`` so the ``globals`` builtin is not shadowed.
    timer_globals = {
        # Gen params
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_s2": marlin_s2,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params
        "qw_reorder": qw_reorder,
        "s_reorder": s_reorder,
        "zp_reorder": zp_reorder,
        "sm_count": sm_count,
        "sm_version": sm_version,
        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
        # Kernels
        "marlin_gemm": ops.marlin_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    def record(stmt: str, description: str) -> None:
        """Time ``stmt`` and append the measurement under ``description``."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    record("torch.matmul(a, marlin_w_ref)", "pytorch_gemm")

    record(
        "output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
        "marlin_gemm",
    )

    # Same call with the second-to-last flag enabled (presumably the fp32
    # reduce option, going by the description -- confirm against the op).
    record(
        "output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
        "marlin_gemm_fp32",
    )

    if repack_supported:
        record(
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
            "gptq_marlin_repack",
        )

    if allspark_supported:
        record(
            "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
            "allspark_w8a16_gemm_fp32",
        )
268

269
270
271
272
273

def main(args):
    """Run ``bench_run`` over every selected configuration and print a comparison.

    The sweep covers model layer shapes from ``WEIGHT_SHAPES``, act-order and
    k-full settings, the quant types reported by
    ``query_marlin_supported_quant_types()``, the Marlin and FP4 group sizes,
    and ``args.batch_sizes``. Each ``args.limit_*`` list restricts the
    corresponding dimension; an empty list leaves it unrestricted.
    """
    print("Benchmarking models:")
    for idx, name in enumerate(args.models):
        print(f"[{idx}]  {name}")

    results: list[benchmark.Measurement] = []

    def passes(value, allowed):
        # An empty limit list places no restriction on the sweep.
        return len(allowed) == 0 or value in allowed

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k, size_n = layer[0], layer[1]
            if not (passes(size_k, args.limit_k) and passes(size_n, args.limit_n)):
                continue

            for act_order in ACT_ORDER_OPTS:
                if not passes(act_order, args.limit_act_order):
                    continue

                for is_k_full in K_FULL_OPTS:
                    if not passes(is_k_full, args.limit_k_full):
                        continue

                    for quant_type in query_marlin_supported_quant_types():
                        if not passes(quant_type.size_bits, args.limit_num_bits):
                            continue

                        group_sizes = (
                            MARLIN_SUPPORTED_GROUP_SIZES
                            + FP4_MARLIN_SUPPORTED_GROUP_SIZES
                        )
                        for group_size in group_sizes:
                            if not passes(group_size, args.limit_group_size):
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and group_size in (size_k, -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(
                                    results,
                                    model,
                                    act_order,
                                    is_k_full,
                                    quant_type,
                                    group_size,
                                    size_m,
                                    size_k,
                                    size_n,
                                )

    benchmark.Compare(results).print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )

    # What to benchmark: model keys from WEIGHT_SHAPES and the batch sizes.
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )

    # Optional filters: each takes one or more ints and defaults to its own
    # empty list, which main() treats as "no restriction".
    for limit_flag in (
        "--limit-k",
        "--limit-n",
        "--limit-group-size",
        "--limit-num-bits",
        "--limit-act-order",
        "--limit-k-full",
    ):
        parser.add_argument(limit_flag, nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)