benchmark_grouped_gemm_cutlass.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
10
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
11
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
12
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
13
14
15
16
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
)
17
18
19
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
20
from vllm.utils.argparse_utils import FlexibleArgumentParser
21
from vllm.v1.worker.workspace import init_workspace_manager
22
23

# Models whose MoE layer shapes (from WEIGHT_SHAPES_MOE) are benchmarked by default.
DEFAULT_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
# Token counts (the M dimension) swept for every layer shape.
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
# Tensor-parallel sizes; the layer's N dimension is divided by each of these.
DEFAULT_TP_SIZES = [1]

# Quantization variants swept by main(); only per-tensor scales are used here.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
38
39
40
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )
41
42


43
44
45
46
47
48
49
50
51
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Time Triton fused MoE against the CUTLASS grouped-GEMM MoE for one shape.

    Random fp16 activations and weights are generated on the GPU. The weights
    are quantized to FP8 with one scale per expert. Four variants are then
    timed and one ``benchmark.Measurement`` per variant is appended to
    ``results``:

    * ``triton_moe``: eager Triton ``fused_experts``.
    * ``triton_moe_cuda_graphs``: the same call, captured and replayed as a
      CUDA graph.
    * ``grouped_gemm_moe``: eager CUTLASS ``FusedMoEModularKernel``.
    * ``grouped_gemm_moe_cuda_graphs``: the same kernel, captured and
      replayed as a CUDA graph.

    Args:
        results: Output list that receives the measurements.
        model: Model name; only used in the printed/recorded label.
        num_experts: Number of experts (E).
        topk: Number of experts selected per token.
        per_act_token: Forwarded to the CUTLASS quant config as
            ``per_act_token_quant``.
        per_out_ch: Only recorded in the label; no kernel in this function
            reads it.
        mkn: ``(m, k, n)``, where m is the number of tokens, k is the hidden
            size and n is the intermediate size (see the w2 layout note below).
    """
    # NOTE(review): main() also initializes the workspace manager once;
    # confirm whether this per-call initialization is still required.
    init_workspace_manager(torch.cuda.current_device())

    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # Only the activation scale is kept; it is passed to Triton as a1_scale.
    _, a_scale = ops.scaled_fp8_quant(a)

    # FP8 weights with a single scale per expert (scale shape [E, 1, 1]).
    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    # Routing is computed once and shared by every variant, so all kernels
    # see the same token -> expert assignment. The third output is unused.
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        """Call Triton ``fused_experts`` ``num_repeats`` times (eager mode)."""
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        """Build the CUTLASS modular kernel and call it ``num_repeats`` times.

        ``a_scale`` is accepted to keep the call signature aligned with the
        Timer statement, but it is not read: this quant config sets no
        ``a1_scale``.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )

        fn = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp8(
                out_dtype=a.dtype,
                # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
                e=w2.shape[0],
                n=w2.shape[2],
                k=w2.shape[1],
                quant_config=quant_config,
                device=w1.device,
            ),
        )

        for _ in range(num_repeats):
            fn(a, w1, w2, topk_weights, topk_ids)

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        """Single CUTLASS MoE call, intended to be captured into a CUDA graph.

        ``per_act_token`` is read from the enclosing ``bench_run`` scope.
        ``a_scale`` is not read (no ``a1_scale`` in the quant config).
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )

        fn = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp8(
                out_dtype=a.dtype,
                # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
                e=w2.shape[0],
                n=w2.shape[2],
                k=w2.shape[1],
                quant_config=quant_config,
                device=w1.device,
            ),
        )

        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fn(a, w1, w2, topk_weights, topk_ids)

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        """Single Triton ``fused_experts`` call, intended for CUDA-graph capture."""
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        """Replay ``graph`` ``num_repeats`` times and wait for the GPU."""
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture each implementation into its own CUDA graph on a side stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace used to evaluate the Timer `stmt` strings below. Named so that
    # it does not shadow the builtin `globals()`.
    bench_globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=bench_globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=bench_globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
            globals=bench_globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=bench_globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def _is_selected(value, limits) -> bool:
    """Return True if ``limits`` is empty (no filtering) or contains ``value``."""
    return not limits or value in limits


def main(args):
    """Benchmark every selected (model, TP size, layer, quant option, batch) combo.

    Layer shapes come from ``WEIGHT_SHAPES_MOE[model]`` as
    ``(num_experts, topk, k, n)``, and ``n`` is divided by the TP size.

    * The ``--limit-*`` options restrict the sweep; an empty list means
      "no limit".
    * ``--batch-sizes`` selects the M values.
    * The collected measurements are printed as a ``benchmark.Compare`` table.

    Args:
        args: Parsed CLI namespace. The options added after the original
            interface (``batch_sizes`` and the num-groups / per-act-token /
            per-out-ch limits) fall back to their defaults when absent.
    """
    # Initialize workspace manager (required for CUTLASS MoE kernels)
    device = torch.device("cuda:0")
    init_workspace_manager(device)

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    # Previously --batch-sizes was parsed but ignored in favor of the default
    # list; honor it now (the default is the same list).
    batch_sizes = getattr(args, "batch_sizes", DEFAULT_BATCH_SIZES)
    limit_num_groups = getattr(args, "limit_num_groups", [])
    limit_per_act_token = getattr(args, "limit_per_act_token", [])
    limit_per_out_ch = getattr(args, "limit_per_out_ch", [])

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if not _is_selected(size_k, args.limit_k):
                    continue

                if not _is_selected(size_n, args.limit_n):
                    continue

                # NOTE(review): --limit-num-groups is interpreted as the number
                # of experts (one GEMM group per expert); confirm intent.
                if not _is_selected(num_experts, limit_num_groups):
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    # The CLI supplies ints (0/1); compare the bool as an int.
                    if not _is_selected(int(per_act_token), limit_per_act_token):
                        continue
                    for per_out_ch in PER_OUT_CH_OPTS:
                        if not _is_selected(int(per_out_ch), limit_per_out_ch):
                            continue
                        for size_m in batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    # CLI entry point. The description previously said "Benchmark Marlin",
    # a copy-paste leftover; this script compares CUTLASS grouped-GEMM MoE
    # with Triton fused MoE.
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped GEMM MoE vs. Triton fused MoE "
        "across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)