# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser


# Models benchmarked when --models is not given (keys of WEIGHT_SHAPES).
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]

# Default M values: 1, then every power of two from 16 (2**4) up to 8192 (2**13).
DEFAULT_BATCH_SIZES = [1, *(2**exp for exp in range(4, 14))]

# Both values of each flag are swept (subject to the --limit-* filters).
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    act_order: bool,
    is_k_full: bool,
    quant_type: ScalarType,
    group_size: int,
    size_m: int,
    size_k: int,
    size_n: int,
):
    """Benchmark the quantized GEMM kernels for one configuration.

    Creates a random fp16 activation ``a`` of shape (size_m, size_k) and a
    random fp16 weight ``b`` of shape (size_k, size_n) on the GPU, quantizes
    ``b`` for Marlin, Marlin 2:4, GPTQ (input of the repack kernel) and, when
    supported, AllSpark W8A16. It then times the following with
    ``benchmark.Timer.blocked_autorange`` and appends each measurement to
    ``results``:

    * ``torch.matmul`` against the Marlin reference weight (baseline),
    * ``gptq_marlin_gemm`` (two flag variants, "fp16" and "fp32"),
    * ``gptq_marlin_24_gemm`` if ``quant_type`` and ``group_size`` are in the
      Marlin 2:4 supported sets,
    * ``gptq_marlin_repack``,
    * ``allspark_w8a16_gemm`` if the case is supported and the GPU's
      compute capability is 8.x.

    Args:
        results: list that receives one ``benchmark.Measurement`` per timing.
        model: model name; used only in the printed/sub label.
        act_order: quantize with activation reordering.
        is_k_full: forwarded to ``gptq_marlin_gemm``; AllSpark is only
            benchmarked when this is True.
        quant_type: scalar type of the quantized weights.
        group_size: quantization group size; AllSpark is only benchmarked
            for ``-1``.
        size_m: rows of the activation (M).
        size_k: reduction dimension (K).
        size_n: output columns (N).
    """
    label = "Quant Matmul"
    # ``!s`` keeps the exact str(quant_type) rendering of the original label.
    sub_label = (
        f"{model}, act={act_order} k_full={is_k_full}, q={quant_type!s}, "
        f"g={group_size}, MKN=({size_m}x{size_k}x{size_n})"
    )
    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()
    a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, quant_type, group_size, act_order)

    # Marlin_24 quant
    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
        marlin_24_quantize(b, quant_type, group_size)
    )

    # GPTQ quant: the packed weights are taken *before* any act_order sort,
    # since the repack benchmark applies the sort indices itself.
    _, q_w, _, g_idx, _ = gptq_quantize_weights(
        b, quant_type, group_size, act_order
    )
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; only the returned indices are needed (by the repack kernel).
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        _, _, repack_sort_indices = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )
    marlin_24_workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )
    # Integer zero-points with the same shape as marlin_s.
    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

    # AllSpark W8A16 quant: only for channelwise (group_size == -1), no
    # act_order, full K, and compute capability 8.x. Otherwise every AllSpark
    # global stays None and the AllSpark timing is skipped.
    allspark_globals = dict.fromkeys(
        (
            "qw_reorder",
            "s_reorder",
            "zp_reorder",
            "sm_count",
            "sm_version",
            "CUBLAS_M_THRESHOLD",
        )
    )
    as_supported_case = (
        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
        and group_size == -1
        and not act_order
        and is_k_full
    )
    if as_supported_case:
        properties = torch.cuda.get_device_properties(b.device.index)
        sm_version = properties.major * 10 + properties.minor
        as_supported_case = 80 <= sm_version < 90
        if as_supported_case:
            has_zp = False
            _, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
            qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
                qw.to(torch.uint8), s, zp, has_zp
            )
            allspark_globals.update(
                qw_reorder=qw_reorder,
                s_reorder=s_reorder,
                zp_reorder=zp_reorder,
                sm_count=properties.multi_processor_count,
                sm_version=sm_version,
                CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            )

    # Namespace the timed statements execute in (named so it does not shadow
    # the ``globals`` builtin).
    bench_globals = {
        # Gen params
        "quant_type": quant_type,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "a_tmp": a_tmp,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # AllSpark W8A16 params (all None when unsupported)
        **allspark_globals,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
    }

    min_run_time = 1

    def _timeit(stmt: str, description: str) -> None:
        """Time ``stmt`` in ``bench_globals`` and append the measurement."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=bench_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    _timeit("torch.matmul(a, marlin_w_ref)", "pytorch_gemm")

    # NOTE(review): the three booleans after ``is_k_full`` are passed
    # positionally to ops.gptq_marlin_gemm; the "fp16"/"fp32" descriptions
    # assume the middle one toggles fp32 reduction - confirm against the
    # current ops signature.
    _timeit(
        "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
        "gptq_marlin_gemm_fp16",
    )
    _timeit(
        "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
        "gptq_marlin_gemm_fp32",
    )

    if (
        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
    ):
        _timeit(
            "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
            "gptq_marlin_24_gemm",
        )

    _timeit(
        "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
        "gptq_marlin_repack",
    )

    if as_supported_case:
        # NOTE(review): the trailing ``False, True`` are positional flags of
        # ops.allspark_w8a16_gemm; confirm they match the "_fp32" description.
        _timeit(
            "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
            "allspark_w8a16_gemm_fp32",
        )


def _within_limits(value, limits) -> bool:
    """Return True if ``limits`` is empty (no restriction) or contains ``value``.

    Membership uses ``==``, so integer limits such as ``[0]`` / ``[1]`` also
    match ``False`` / ``True``.
    """
    return not limits or value in limits


def main(args):
    """Benchmark every selected configuration and print a comparison table.

    Sweeps models -> layer (K, N) shapes -> act_order -> is_k_full ->
    quant type -> group size -> batch size (M), calling ``bench_run`` for each
    combination, then prints ``benchmark.Compare`` over all measurements.

    Each ``--limit-*`` argument restricts its dimension when non-empty; an
    empty list means no restriction. act_order combinations are skipped when
    the group size is -1 or not smaller than K.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    # Per-dimension option lists, filtered once by the --limit-* arguments.
    act_orders = [
        v for v in ACT_ORDER_OPTS if _within_limits(v, args.limit_act_order)
    ]
    k_fulls = [v for v in K_FULL_OPTS if _within_limits(v, args.limit_k_full)]
    quant_types = [
        qt
        for qt in query_marlin_supported_quant_types(False)
        if _within_limits(qt.size_bits, args.limit_num_bits)
    ]
    group_sizes = [
        g
        for g in MARLIN_SUPPORTED_GROUP_SIZES
        if _within_limits(g, args.limit_group_size)
    ]

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k, size_n = layer[0], layer[1]
            if not (
                _within_limits(size_k, args.limit_k)
                and _within_limits(size_n, args.limit_n)
            ):
                continue

            # product() yields combinations in the same order as the
            # equivalent nested loops (last iterable varies fastest).
            for act_order, is_k_full, quant_type, group_size in itertools.product(
                act_orders, k_fulls, quant_types, group_sizes
            ):
                # For act_order, the group_size must be less than size_k
                # (channelwise -1 is excluded as well).
                if act_order and (group_size == -1 or group_size >= size_k):
                    continue

                for size_m in args.batch_sizes:
                    bench_run(
                        results,
                        model,
                        act_order,
                        is_k_full,
                        quant_type,
                        group_size,
                        size_m,
                        size_k,
                        size_n,
                    )

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa: E501
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    # Optional integer filters; an empty list (the default) means "no limit".
    # --limit-act-order / --limit-k-full take 0/1 values.
    for limit_flag in (
        "--limit-k",
        "--limit-n",
        "--limit-group-size",
        "--limit-num-bits",
        "--limit-act-order",
        "--limit-k-full",
    ):
        parser.add_argument(limit_flag, nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)