# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    _moe_permute,
    _moe_unpermute_and_reduce,
    moe_permute,
    moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

# FP8 dtype reported by the current platform. Not referenced in the visible
# part of this file.
FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    """Typed dictionary of six integer kernel-configuration fields.

    NOTE(review): the field names look like Triton launch/tile parameters, but
    nothing in the visible part of this file constructs or reads this type --
    confirm whether it is still needed.
    """

    # Block (tile) sizes; presumably along the M, N and K dimensions.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    # Presumably the number of M-blocks grouped together -- TODO confirm.
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


35
36
37
38
39
40
41
42
43
44
45
def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Measure the average latency of one MoE permute call.

    Ten permute invocations are captured into a CUDA graph; the graph is then
    replayed ``num_iters`` times and timed with device events.

    Tensors are created without an explicit device, so the caller is expected
    to have selected the default device beforehand (``BenchmarkWorker`` sets
    it to ``"cuda"``).

    Args:
        num_tokens: Number of rows in the random hidden-state tensor.
        num_experts: Number of experts (columns of the random gating tensor).
        hidden_size: Number of columns in the hidden-state tensor.
        topk: Number of experts selected per token by ``fused_topk``.
        dtype: dtype of the generated hidden states.
        use_fp8_w8a8: Quantize the hidden states with ``_fp8_quantize`` and
            use a 128-row alignment block for the permute.
        use_int8_w8a16: Accepted for signature parity; not used here.
        num_iters: Number of timed graph replays. Must be positive.
        use_customized_permute: Benchmark ``moe_permute`` instead of
            ``_moe_permute``.

    Returns:
        Average latency of a single permute invocation, in microseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    gating_output = torch.randn(
        num_iters, num_tokens, num_experts, dtype=torch.float32
    )

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    # Only the expert ids are needed by the permute kernels below.
    _, topk_ids, _ = fused_topk(qhidden_states, input_gating, topk, False)

    def prepare(i: int):
        # NOTE(review): topk_ids is computed once above and the captured graph
        # does not re-run fused_topk, so refreshing input_gating does not
        # change the work being replayed. Kept to preserve the original
        # behaviour (including RNG consumption); confirm whether it is needed.
        input_gating.copy_(gating_output[i])

    def run():
        # Results are discarded: only the kernel execution time matters.
        if use_customized_permute:
            moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
        else:
            _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    invocations_per_replay = 10
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(invocations_per_replay):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))

    # elapsed_time() reports milliseconds per replay; convert to microseconds
    # per single invocation.
    avg = sum(latencies) / (num_iters * invocations_per_replay) * 1000  # us
    graph.reset()
    return avg


127
128
129
130
131
132
133
134
135
136
137
def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Measure the average latency of one MoE unpermute call.

    The hidden states are permuted once, outside the timed region, to produce
    the unpermute inputs. Ten unpermute invocations are then captured into a
    CUDA graph, which is replayed ``num_iters`` times and timed with device
    events.

    Tensors are created without an explicit device, so the caller is expected
    to have selected the default device beforehand (``BenchmarkWorker`` sets
    it to ``"cuda"``).

    Args:
        num_tokens: Number of rows in the random hidden-state tensor.
        num_experts: Number of experts (columns of the random gating tensor).
        hidden_size: Number of columns in the hidden-state tensor.
        topk: Number of experts selected per token by ``fused_topk``.
        dtype: dtype of the generated hidden states and of the permuted
            tensor fed to the unpermute kernels.
        use_fp8_w8a8: Quantize the hidden states with ``_fp8_quantize`` and
            use a 128-row alignment block for the permute.
        use_int8_w8a16: Accepted for signature parity; not used here.
        num_iters: Number of timed graph replays. Must be positive.
        use_customized_permute: Benchmark ``moe_permute``/``moe_unpermute``
            instead of ``_moe_permute``/``_moe_unpermute_and_reduce``.

    Returns:
        Average latency of a single unpermute invocation, in microseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # Single preallocated destination shared by both code paths so that they
    # are measured under the same conditions.
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, _ = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare():
        """Run the matching permute once and return the unpermute inputs."""
        if use_customized_permute:
            (
                permuted_hidden_states,
                _a1q_scale,
                first_token_off,
                inv_perm_idx,
                m_indices,
            ) = moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_hidden_states.to(dtype),
                first_token_off,
                inv_perm_idx,
                m_indices,
            )

        (
            permuted_qhidden_states,
            a1q_scale,
            sorted_token_ids,
            expert_ids,
            inv_perm,
        ) = _moe_permute(
            qhidden_states, None, topk_ids, num_experts, None, align_block_size
        )
        # convert to fp16/bf16 as gemm output
        return (
            permuted_qhidden_states.to(dtype),
            a1q_scale,
            sorted_token_ids,
            expert_ids,
            inv_perm,
        )

    def run(permuted: tuple):
        if use_customized_permute:
            (
                permuted_hidden_states,
                first_token_off,
                inv_perm_idx,
                _m_indices,
            ) = permuted
            moe_unpermute(
                output_hidden_states,
                permuted_hidden_states,
                topk_weights,
                inv_perm_idx,
                first_token_off,
            )
        else:
            (
                permuted_hidden_states,
                _a1q_scale,
                _sorted_token_ids,
                _expert_ids,
                inv_perm,
            ) = permuted
            _moe_unpermute_and_reduce(
                output_hidden_states,
                permuted_hidden_states,
                inv_perm,
                topk_weights,
                True,
            )

    # JIT compilation & warmup
    permuted_inputs = prepare()
    run(permuted_inputs)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    invocations_per_replay = 10
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(invocations_per_replay):
            run(permuted_inputs)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))

    # elapsed_time() reports milliseconds per replay; convert to microseconds
    # per single invocation.
    avg = sum(latencies) / (num_iters * invocations_per_replay) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor, reserved one GPU, that runs the permute/unpermute benchmarks."""

    def __init__(self, seed: int) -> None:
        """Select CUDA as the default torch device and seed the RNGs.

        Args:
            seed: Seed passed to ``set_random_seed`` here and again at the
                start of every ``benchmark`` call.
        """
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[float, float]:
        """Run both benchmarks for one problem size.

        The RNG is re-seeded first so every call starts from the same random
        state. All arguments are forwarded unchanged to ``benchmark_permute``
        and ``benchmark_unpermute``, each with 100 timed iterations.

        Returns:
            ``(permute_time, unpermute_time)``: the two values returned by
            ``benchmark_permute`` and ``benchmark_unpermute``.
        """
        set_random_seed(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):
    """Return ``weight_block_size`` from a config's quantization settings.

    Args:
        config: Any object; its optional ``quantization_config`` attribute is
            inspected.
        default_value: Value returned when the setting cannot be found.

    Returns:
        ``config.quantization_config["weight_block_size"]`` when
        ``quantization_config`` is a dict containing that key; otherwise
        ``default_value`` (attribute missing, not a dict, or key absent).
    """
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace) -> None:
    """Benchmark permute/unpermute for a model's MoE shape and print timings.

    The expert count, top-k and hidden size are read from the Hugging Face
    config of ``args.model``. One ``BenchmarkWorker`` actor is started per GPU
    that Ray reports, and the batch sizes are distributed across the actors
    round-robin.

    Args:
        args: Parsed command-line arguments; reads ``model``,
            ``trust_remote_code``, ``dtype``, ``use_customized_permute``,
            ``batch_size`` and ``seed``.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif arch in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        # Without this check the failure would surface as an opaque
        # KeyError/IndexError further down.
        raise RuntimeError("No GPU is available to Ray; cannot run the benchmark.")
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Submit every task round-robin across the workers, then wait for all.
        pending = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(pending)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
                use_customized_permute,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the benchmark.
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() sweeps a built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)