# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (
    cutlass_moe_fp8,
    fused_experts,
    fused_topk,
)
from vllm.utils import FlexibleArgumentParser

# Models benchmarked when --models is not given on the command line.
DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1",
    "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
# Default for --batch-sizes: the M dimension of each benchmarked (M, K, N)
# problem.
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
# Default for --tp-sizes: each layer's N dimension is divided by this value.
DEFAULT_TP_SIZES = [1]

# Quantization-granularity options swept by main(); only the `False` setting
# is currently listed for each.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
32
33
34
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
) -> None:
    """Benchmark Triton and CUTLASS FP8 MoE kernels for one problem shape.

    Random half-precision activations ``(m, k)`` and expert weights
    ``w1: (num_experts, 2 * n, k)`` / ``w2: (num_experts, k, n)`` are created
    on the CUDA device, the weights are quantized expert-by-expert with
    ``ops.scaled_fp8_quant``, and four measurements are appended to
    *results*, in this order:

    * ``triton_moe``                   - eager ``fused_experts``
    * ``triton_moe_cuda_graphs``       - ``fused_experts`` replayed from a
      captured CUDA graph
    * ``grouped_gemm_moe``             - eager ``cutlass_moe_fp8``
    * ``grouped_gemm_moe_cuda_graphs`` - ``cutlass_moe_fp8`` replayed from a
      captured CUDA graph

    Every timed statement executes the kernel ``num_runs`` times, so the
    reported time covers ``num_runs`` invocations rather than one.

    Args:
        results: List that receives the new ``benchmark.Measurement`` objects.
        model: Model name; used only in the printed/recorded sub-label.
        num_experts: Number of experts (leading dimension of the weights).
        topk: Number of experts selected per token by ``fused_topk``.
        per_act_token: Recorded in the sub-label only; not otherwise used here.
        per_out_ch: Recorded in the sub-label only; not otherwise used here.
        mkn: ``(m, k, n)`` problem dimensions.
    """
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # Only the activation scale is needed; the quantized activations are
    # discarded and the kernels receive the unquantized `a`.
    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    # Per-expert stride vectors handed to the CUTLASS kernel.
    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

    # Quantize each expert's weights independently (one scale per expert).
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    # The Triton path is given the original layout; the CUTLASS path is given
    # transposed views of the same quantized data.
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    # The third value returned by fused_topk is not needed by either kernel.
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        """Launch the Triton fused MoE kernel `num_repeats` times."""
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_scale,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        ab_strides1: torch.Tensor,
        c_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides2: torch.Tensor,
        num_repeats: int,
    ):
        """Launch the CUTLASS FP8 MoE kernel `num_repeats` times."""
        for _ in range(num_repeats):
            cutlass_moe_fp8(
                a,
                w1,
                w2,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                ab_strides1,
                c_strides1,
                ab_strides2,
                c_strides2,
                a1_scale=a_scale,
            )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1_q: torch.Tensor,
        w2_q: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        ab_strides1: torch.Tensor,
        c_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides2: torch.Tensor,
    ):
        """Single CUTLASS MoE call under a vLLM config, for graph capture."""
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp8(
                a,
                w1_q,
                w2_q,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                ab_strides1,
                c_strides1,
                ab_strides2,
                c_strides2,
                a1_scale=a_scale,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        """Single Triton MoE call under a vLLM config, for graph capture."""
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_scale,
            )

    def replay_graph(graph, num_repeats):
        """Replay a captured CUDA graph `num_repeats` times, then synchronize."""
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture one CUTLASS MoE invocation into a CUDA graph.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
            ab_strides1,
            c_strides1,
            ab_strides2,
            c_strides2,
        )
    torch.cuda.synchronize()

    # Capture one Triton MoE invocation into a separate CUDA graph.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q_notransp,
            w2_q_notransp,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace in which benchmark.Timer evaluates each `stmt` string.
    # (Deliberately not named `globals`, which would shadow the builtin.)
    timer_globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_q_notransp,
        w2_q_notransp,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=timer_globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=timer_globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        ab_strides1,
        c_strides1,
        ab_strides2,
        c_strides2,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=timer_globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=timer_globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def main(args):
    """Run the MoE benchmark sweep described by *args* and print a comparison.

    For every requested model and tensor-parallel size, each entry of
    ``WEIGHT_SHAPES_MOE[model]`` is read as
    ``(num_experts, topk, size_k, size_n)``, with ``size_n`` divided by the
    tensor-parallel size. Shapes are skipped when ``args.limit_k`` /
    ``args.limit_n`` are non-empty and do not contain the shape's K / N.
    ``bench_run`` is then called once per batch size in ``args.batch_sizes``
    (and per entry of ``PER_ACT_TOKEN_OPTS`` x ``PER_OUT_CH_OPTS``), and all
    collected measurements are printed via ``benchmark.Compare``.

    Args:
        args: Parsed command-line namespace; this function reads ``models``,
            ``tp_sizes``, ``batch_sizes``, ``limit_k`` and ``limit_n``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                # An empty limit list means "no restriction".
                if args.limit_k and size_k not in args.limit_k:
                    continue

                if args.limit_n and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        # Use the user-selected batch sizes; iterating
                        # DEFAULT_BATCH_SIZES here would silently ignore
                        # the --batch-sizes command-line option.
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    # Command-line entry point: build the argument parser and run the sweep.
    parser = FlexibleArgumentParser(
        description=(
            "Benchmark CUTLASS grouped-GEMM FP8 MoE against Triton fused MoE "
            "across specified models/shapes/batches"
        )
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    # Optional filters; an empty list means "no restriction".
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # NOTE(review): the three options below are parsed, but main() in this
    # file does not appear to read them - confirm whether they should filter
    # the sweep or be removed.
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)