# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark fp8 MoE layers: CUTLASS grouped GEMM vs. Triton.

For each selected model's MoE layer shapes, ``cutlass_moe_fp8`` and the
Triton ``fused_experts`` kernel are timed both eagerly and replayed from a
CUDA graph, and the results are printed as a ``torch.utils.benchmark``
comparison table.
"""

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
)
from vllm.utils import FlexibleArgumentParser

# Keys into WEIGHT_SHAPES_MOE benchmarked when --models is not given.
DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1",
    "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
# Number of tokens (the "m" of MKN) swept per layer shape.
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
# Tensor-parallel sizes; main() divides each layer's n dimension by these.
DEFAULT_TP_SIZES = [1]

# Quantization-flag settings swept by main(); only False is swept by default.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]

def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
33
34
35
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )
36
37


38
39
40
41
42
43
44
45
46
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Time one fp8 MoE configuration with Triton and CUTLASS kernels.

    Builds random fp16 activations and expert weights on the CUDA device,
    quantizes the weights to float8_e4m3fn expert by expert, and appends four
    measurements to ``results``, in this order:

    * ``triton_moe``                   -- eager ``fused_experts`` calls
    * ``triton_moe_cuda_graphs``       -- ``fused_experts`` replayed from a
      CUDA graph
    * ``grouped_gemm_moe``             -- eager ``cutlass_moe_fp8`` calls
    * ``grouped_gemm_moe_cuda_graphs`` -- ``cutlass_moe_fp8`` replayed from a
      CUDA graph

    Args:
        results: list the new ``benchmark.Measurement`` objects are appended
            to (mutated in place).
        model: model name; only used in the printed/recorded sub-label.
        num_experts: number of experts E.
        topk: number of experts each token is routed to.
        per_act_token: passed as ``per_act_token_quant`` to the CUTLASS quant
            config.
        per_out_ch: only recorded in the sub-label. The weight scales built
            below are always one per expert (shape (E, 1, 1)), so this flag
            is not exercised.
        mkn: (m, k, n). Activations are (m, k), w1 is (E, 2n, k) and w2 is
            (E, k, n).
    """
    label = "Quant Matmul"
    sub_label = (
        f"{model}, num_experts={num_experts}, topk={topk}, "
        f"per_act_token={per_act_token}, per_out_ch={per_out_ch}, MKN={mkn}"
    )

    print(f"Testing: {sub_label}")

    m, k, n = mkn
    dtype = torch.half

    # Random fp16 problem; the /10 keeps the values small.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # Activation scale; only the Triton paths consume it (as a1_scale).
    _, a_scale = ops.scaled_fp8_quant(a)

    # Quantize each expert's weights separately, one scale per expert.
    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    # Random router scores -> one top-k routing shared by all four variants.
    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # Per-expert stride tensors handed to the CUTLASS grouped GEMMs.
    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        """Call Triton ``fused_experts`` ``num_repeats`` times (eager)."""
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        """Call ``cutlass_moe_fp8`` ``num_repeats`` times (eager).

        ``a_scale`` is accepted for symmetry with the Triton runner but is
        not put into the quant config here.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        for _ in range(num_repeats):
            cutlass_moe_fp8(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                quant_config=quant_config,
            )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1_q: torch.Tensor,
        w2_q: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        """Run one ``cutlass_moe_fp8`` call under a minimal VllmConfig.

        Used for CUDA-graph capture. ``per_act_token`` is read from the
        enclosing ``bench_run`` call.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp8(
                a,
                w1_q,
                w2_q,
                topk_weights,
                topk_ids,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                quant_config=quant_config,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        """Run one Triton ``fused_experts`` call under a minimal VllmConfig.

        Used for CUDA-graph capture.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        """Replay ``graph`` ``num_repeats`` times, then wait for the device."""
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture one CUTLASS MoE call into a CUDA graph on a side stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            ab_strides1,
            ab_strides2,
            c_strides1,
            c_strides2,
            topk_weights,
            topk_ids,
        )
    torch.cuda.synchronize()

    # Capture one Triton MoE call into a CUDA graph on a side stream.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace in which benchmark.Timer evaluates the `stmt` strings below
    # (named so it does not shadow the `globals` builtin).
    timer_globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        "ab_strides1": ab_strides1,
        "ab_strides2": ab_strides2,
        "c_strides1": c_strides1,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    def time_stmt(stmt: str, description: str) -> None:
        """Time ``stmt`` with blocked_autorange and append it to ``results``."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Triton, eager.
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )
    time_stmt(
        "run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, "
        "w1_scale, w2_scale, a_scale, num_runs)",
        "triton_moe",
    )

    # Triton, CUDA graph.
    replay_graph(triton_graph, num_warmup)
    time_stmt("replay_graph(triton_graph, num_runs)", "triton_moe_cuda_graphs")

    # CUTLASS, eager.
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        ab_strides1,
        ab_strides2,
        c_strides1,
        c_strides2,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )
    time_stmt(
        "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, "
        "ab_strides1, ab_strides2, c_strides1, c_strides2, "
        "topk_weights, topk_ids, per_act_token, num_runs)",
        "grouped_gemm_moe",
    )

    # CUTLASS, CUDA graph.
    replay_graph(cutlass_graph, num_warmup)
    time_stmt(
        "replay_graph(cutlass_graph, num_runs)", "grouped_gemm_moe_cuda_graphs"
    )


def main(args):
    """Benchmark every selected model/layer/batch combination and print a table.

    Each layer entry of ``WEIGHT_SHAPES_MOE[model]`` is read as
    ``(num_experts, topk, k, n)`` (indices 0-3). n is divided by each TP size
    before the --limit-k / --limit-n filters are applied. Every surviving
    shape is run through ``bench_run`` for each batch size in
    ``args.batch_sizes``.

    Args:
        args: namespace produced by this script's argument parser.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    # Honour --batch-sizes; fall back to the default if the namespace lacks it.
    batch_sizes = getattr(args, "batch_sizes", DEFAULT_BATCH_SIZES)
    # NOTE(review): --limit-num-groups, --limit-per-act-token and
    # --limit-per-out-ch are parsed by the CLI but not applied here.

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if args.limit_k and size_k not in args.limit_k:
                    continue

                if args.limit_n and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description=(
            "Benchmark CUTLASS grouped-GEMM fp8 MoE against Triton "
            "fused_experts across specified models/shapes/batches"
        )
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
        help="Models whose MoE layer shapes are taken from WEIGHT_SHAPES_MOE.",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=DEFAULT_TP_SIZES,
        help="Tensor-parallel sizes; each layer's n is divided by the TP size.",
    )
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument(
        "--limit-k",
        nargs="+",
        type=int,
        default=[],
        help="Only benchmark layers whose k is in this list.",
    )
    parser.add_argument(
        "--limit-n",
        nargs="+",
        type=int,
        default=[],
        help="Only benchmark layers whose TP-divided n is in this list.",
    )
    # NOTE(review): the three limits below are parsed but not used by main().
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)