# benchmark_grouped_gemm_cutlass.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from tests.kernels.moe.utils import make_dummy_moe_config
10
11
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
12
13
14
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
15
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
16
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
17
18
19
20
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
)
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22
from vllm.v1.worker.workspace import init_workspace_manager
23
24

# Models benchmarked when --models is not given on the command line.
# Each name is used as a key into WEIGHT_SHAPES_MOE.
DEFAULT_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
# Default values of M (rows of the activation matrix) to sweep.
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
# Default tensor-parallel sizes; each layer's N dimension is divided by this.
DEFAULT_TP_SIZES = [1]

# Quantization-granularity options swept by main(); both are passed through
# to bench_run() for every shape.
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
39
40
41
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )
42
43


44
45
46
47
48
49
50
51
52
def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Benchmark Triton and CUTLASS FP8 MoE kernels for one problem shape.

    Four measurements are appended to ``results``, in this order:
    ``triton_moe``, ``triton_moe_cuda_graphs``, ``grouped_gemm_moe`` and
    ``grouped_gemm_moe_cuda_graphs``. The eager variants call the kernel
    ``num_runs`` times per timed statement; the CUDA-graph variants replay
    a graph captured once up front.

    Args:
        results: List that the new ``benchmark.Measurement`` objects are
            appended to (mutated in place).
        model: Model name; only used in the printed/recorded sub-label.
        num_experts: Number of experts (leading dim of the weight tensors).
        topk: Number of experts selected per token by ``fused_topk``.
        per_act_token: Forwarded as ``per_act_token_quant`` to the CUTLASS
            quant config.
        per_out_ch: Only recorded in the sub-label by this function.
        mkn: ``(m, k, n)`` problem size. Activations are ``(m, k)``, ``w1``
            is ``(num_experts, 2 * n, k)`` and ``w2`` is
            ``(num_experts, k, n)``.
    """
    init_workspace_manager(torch.cuda.current_device())
    label = "Quant Matmul"

    sub_label = (
        f"{model}, num_experts={num_experts}, topk={topk}, "
        f"per_act_token={per_act_token} per_out_ch={per_out_ch}, MKN=({mkn})"
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    # Random half-precision inputs, scaled down by 10.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # Only the activation scale is kept; the quantized activations are unused.
    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    # Quantize each expert's weights separately, storing one scale per expert.
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    # The third return value of fused_topk is not needed by this benchmark.
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        """Call ``fused_experts`` ``num_repeats`` times with an FP8 config."""
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        """Build a CUTLASS FP8 MoE kernel and call it ``num_repeats`` times.

        ``a_scale`` is accepted but not used in this function body.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        # MoE dimensions are derived from the shape of w2.
        moe_config = make_dummy_moe_config(
            num_experts=w2.shape[0],
            hidden_dim=w2.shape[1],
            intermediate_size_per_partition=w2.shape[2],
            in_dtype=a.dtype,
        )

        fn = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp8(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )

        for _ in range(num_repeats):
            fn(a, w1, w2, topk_weights, topk_ids)

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        """Run the CUTLASS FP8 MoE kernel once; used during graph capture.

        ``per_act_token`` is read from the enclosing ``bench_run`` scope and
        ``a_scale`` is accepted but not used in this function body.
        """
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )
        # MoE dimensions are derived from the shape of w2.
        moe_config = make_dummy_moe_config(
            num_experts=w2.shape[0],
            hidden_dim=w2.shape[1],
            intermediate_size_per_partition=w2.shape[2],
            in_dtype=a.dtype,
        )

        fn = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            CutlassExpertsFp8(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )

        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fn(a, w1, w2, topk_weights, topk_ids)

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        """Run ``fused_experts`` once; used during graph capture."""
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        """Replay ``graph`` ``num_repeats`` times, then synchronize."""
        for _ in range(num_repeats):
            graph.replay()
        torch.accelerator.synchronize()

    # Capture one CUTLASS invocation into a CUDA graph on its own stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
        )
    torch.accelerator.synchronize()

    # Capture one Triton invocation into a CUDA graph on its own stream.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.accelerator.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace in which the Timer statements below are evaluated. Named
    # ``timer_globals`` so the ``globals`` builtin is not shadowed.
    timer_globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    def record(stmt: str, description: str) -> None:
        """Time ``stmt`` with ``blocked_autorange`` and append the result."""
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time)
        )

    # Warmup
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )
    record(
        "run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
        "triton_moe",
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)
    record("replay_graph(triton_graph, num_runs)", "triton_moe_cuda_graphs")

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )
    record(
        "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
        "grouped_gemm_moe",
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)
    record("replay_graph(cutlass_graph, num_runs)", "grouped_gemm_moe_cuda_graphs")
362
363
364


def main(args):
    """Run ``bench_run`` over every selected model/shape and print a table.

    For each model in ``args.models`` and each tensor-parallel size in
    ``args.tp_sizes``, every layer entry of ``WEIGHT_SHAPES_MOE[model]`` is
    read as ``(num_experts, topk, size_k, size_n)``, with ``size_n`` divided
    by the tensor-parallel size. Shapes are skipped when they do not match a
    non-empty ``args.limit_k`` / ``args.limit_n`` filter.

    Args:
        args: Parsed command-line namespace. Reads ``models``, ``tp_sizes``,
            ``limit_k`` and ``limit_n``; ``batch_sizes``,
            ``limit_per_act_token`` and ``limit_per_out_ch`` are read when
            present and otherwise fall back to the defaults / no filtering.
    """
    # Initialize workspace manager (required for CUTLASS MoE kernels)
    device = torch.device("cuda:0")
    init_workspace_manager(device)

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results: list[benchmark.Measurement] = []

    # Honor --batch-sizes; previously the loop always used
    # DEFAULT_BATCH_SIZES regardless of what was passed.
    batch_sizes = getattr(args, "batch_sizes", None) or DEFAULT_BATCH_SIZES
    # Optional filters on the quantization options. The flags are parsed as
    # ints; membership tests still work for bools since False == 0, True == 1.
    limit_per_act_token = getattr(args, "limit_per_act_token", None) or []
    limit_per_out_ch = getattr(args, "limit_per_out_ch", None) or []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    if (
                        limit_per_act_token
                        and per_act_token not in limit_per_act_token
                    ):
                        continue

                    for per_out_ch in PER_OUT_CH_OPTS:
                        if limit_per_out_ch and per_out_ch not in limit_per_out_ch:
                            continue

                        for size_m in batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    # Command-line entry point: build the argument parser and run main().
    parser = FlexibleArgumentParser(
        # The description previously said "Marlin"; this script benchmarks
        # the CUTLASS grouped-GEMM MoE kernel against the Triton MoE kernel.
        description=(
            "Benchmark CUTLASS grouped GEMM MoE against Triton MoE "
            "across specified models/shapes/batches"
        )
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)