# benchmark_marlin.py: benchmark Marlin quantized GEMM kernels.
# SPDX-License-Identifier: Apache-2.0

import itertools
from typing import List

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser

# Models benchmarked when --models is not given (keys of WEIGHT_SHAPES).
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]

# Batch sizes (the M dimension) swept when --batch-sizes is not given:
# 1 followed by the powers of two from 16 up to 512.
DEFAULT_BATCH_SIZES = [1, *(2**exp for exp in range(4, 10))]

# Both settings of each boolean knob are swept unless narrowed on the
# command line via --limit-act-order / --limit-k-full.
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = list(ACT_ORDER_OPTS)
def bench_run(results: List[benchmark.Measurement], model: str,
              act_order: bool, is_k_full: bool, quant_type: ScalarType,
              group_size: int, size_m: int, size_k: int, size_n: int):
    """Time the Marlin kernels for a single benchmark configuration.

    A random ``(size_m, size_k)`` fp16 activation and ``(size_k, size_n)``
    fp16 weight are created on the GPU, the weight is quantized for each
    kernel under test, and the following statements are timed. Their
    measurements are appended to ``results`` in this order:

    * ``pytorch_gemm``: dense ``torch.matmul`` against the Marlin reference
      weight (baseline);
    * ``gptq_marlin_gemm_fp16`` / ``gptq_marlin_gemm_fp32``: the Marlin GEMM,
      the two differing only in the second-to-last boolean argument;
    * ``gptq_marlin_24_gemm``: only when ``quant_type`` and ``group_size``
      are in the Marlin 2:4 supported sets;
    * ``gptq_marlin_repack``: repacking a GPTQ-packed weight.

    Args:
        results: List the ``benchmark.Measurement`` objects are appended to
            (mutated in place).
        model: Model name; only used in the row label.
        act_order: Whether the weight is quantized with act_order.
        is_k_full: Forwarded unchanged to ``gptq_marlin_gemm``.
        quant_type: Weight quantization type.
        group_size: Quantization group size.
        size_m: M dimension (batch size).
        size_k: K dimension (reduction).
        size_n: N dimension (output features).
    """
    label = "Quant Matmul"
    sub_label = (f"{model}, act={act_order} k_full={is_k_full}, "
                 f"q={quant_type!s}, g={group_size}, "
                 f"MKN=({size_m}x{size_k}x{size_n})")
    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    # The 2:4 kernel is only benchmarked for its supported configurations,
    # so its quantization/workspace setup is skipped otherwise.
    use_marlin_24 = (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
                     and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)

    # Marlin quant (the trailing random-permutation output is not needed).
    (marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices,
     _) = marlin_quantize(b, quant_type, group_size, act_order)

    # Marlin_24 quant
    if use_marlin_24:
        (_, marlin_24_q_w_comp, marlin_24_meta,
         marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)

    # GPTQ quant. The weight is packed *before* the act_order sort: only the
    # permutation from sort_weights is used, and it is handed to
    # gptq_marlin_repack alongside the unsorted packed weight.
    _, q_w, _, g_idx, _ = gptq_quantize_weights(b, quant_type, group_size,
                                                act_order)
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        _, _, repack_sort_indices = sort_weights(q_w, g_idx)

    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)
    # Zero-point argument for gptq_marlin_gemm, shaped like the scales.
    # NOTE(review): the timed statements pass False right after is_k_full,
    # presumably meaning "no zero points" -- confirm against the op.
    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

    # Namespace the timed ``stmt`` strings are executed against.
    timer_globals = {
        # Gen params
        "quant_type": quant_type,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "is_k_full": is_k_full,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_zp": marlin_zp,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_workspace": marlin_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
    }
    if use_marlin_24:
        timer_globals.update({
            "marlin_24_q_w_comp":
            marlin_24_q_w_comp,
            "marlin_24_meta":
            marlin_24_meta,
            "marlin_24_s":
            marlin_24_s,
            "marlin_24_workspace":
            MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                            GPTQ_MARLIN_24_MAX_PARALLEL),
            "gptq_marlin_24_gemm":
            ops.gptq_marlin_24_gemm,
        })

    min_run_time = 1

    def time_stmt(stmt: str, description: str) -> None:
        """Time ``stmt`` and append the measurement to ``results``."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time))

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    time_stmt("torch.matmul(a, marlin_w_ref)", "pytorch_gemm")

    time_stmt(
        "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
        "gptq_marlin_gemm_fp16")

    time_stmt(
        "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
        "gptq_marlin_gemm_fp32")

    if use_marlin_24:
        time_stmt(
            "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
            "gptq_marlin_24_gemm")

    time_stmt(
        "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
        "gptq_marlin_repack")


def _is_selected(value, allowed) -> bool:
    """Return whether ``value`` passes a ``--limit-*`` command-line filter.

    An empty ``allowed`` sequence means the filter was not supplied, in which
    case every value passes.
    """
    return not allowed or value in allowed


def main(args):
    """Benchmark every selected configuration and print a comparison table.

    For each model in ``args.models`` and each ``[K, N]`` weight shape in
    ``WEIGHT_SHAPES[model]``, sweeps act_order, is_k_full, the Marlin
    supported quant types and group sizes, and ``args.batch_sizes`` (M).
    Each ``args.limit_*`` filter narrows one axis; an empty filter keeps
    every value. Configurations that combine act_order with a group size
    equal to K or -1 are skipped.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: List[benchmark.Measurement] = []

    # None of these candidate lists depend on the model or weight shape, so
    # they are filtered once instead of inside the per-layer sweep.
    act_orders = [
        act_order for act_order in ACT_ORDER_OPTS
        if _is_selected(act_order, args.limit_act_order)
    ]
    k_fulls = [
        is_k_full for is_k_full in K_FULL_OPTS
        if _is_selected(is_k_full, args.limit_k_full)
    ]
    quant_types = [
        quant_type
        for quant_type in query_marlin_supported_quant_types(False)
        if _is_selected(quant_type.size_bits, args.limit_num_bits)
    ]
    group_sizes = [
        group_size for group_size in MARLIN_SUPPORTED_GROUP_SIZES
        if _is_selected(group_size, args.limit_group_size)
    ]

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k, size_n = layer[0], layer[1]
            if not (_is_selected(size_k, args.limit_k)
                    and _is_selected(size_n, args.limit_n)):
                continue

            # Same iteration order as nesting act_order > is_k_full >
            # quant_type > group_size loops.
            for act_order, is_k_full, quant_type, group_size in (
                    itertools.product(act_orders, k_fulls, quant_types,
                                      group_sizes)):
                # For act_order, the group_size must be less than size_k;
                # -1 is skipped as well.
                if act_order and group_size in (size_k, -1):
                    continue

                for size_m in args.batch_sizes:
                    bench_run(results, model, act_order, is_k_full,
                              quant_type, group_size, size_m, size_k,
                              size_n)

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1  # noqa: E501
#
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
        help="Models whose weight shapes are benchmarked.",
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES,
                        help="Batch sizes (M) to benchmark.")
    # Each --limit-* option restricts one sweep axis; omitted = no limit.
    parser.add_argument("--limit-k",
                        nargs="+",
                        type=int,
                        default=[],
                        help="Only benchmark weight shapes with these K.")
    parser.add_argument("--limit-n",
                        nargs="+",
                        type=int,
                        default=[],
                        help="Only benchmark weight shapes with these N.")
    parser.add_argument("--limit-group-size",
                        nargs="+",
                        type=int,
                        default=[],
                        help="Only benchmark these quantization group sizes.")
    parser.add_argument("--limit-num-bits",
                        nargs="+",
                        type=int,
                        default=[],
                        help="Only benchmark quant types with these bit "
                        "widths.")
    # act_order / k_full are booleans compared against these ints, so only
    # 0 and 1 can ever match.
    parser.add_argument("--limit-act-order",
                        nargs="+",
                        type=int,
                        choices=[0, 1],
                        default=[],
                        help="Only benchmark these act_order settings (0/1).")
    parser.add_argument("--limit-k-full",
                        nargs="+",
                        type=int,
                        choices=[0, 1],
                        default=[],
                        help="Only benchmark these is_k_full settings (0/1).")

    args = parser.parse_args()
    main(args)