benchmark_moe_permute_unpermute.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10

import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

11
12
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
13
14
    _moe_permute,
    _moe_unpermute_and_reduce,
15
16
    moe_permute,
    moe_unpermute,
17
)
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser

# FP8 dtype as reported by the current platform object. It is not referenced
# again in the visible portion of this file.
FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    """Typed dictionary holding six integer kernel-configuration fields.

    NOTE(review): the field names look like Triton fused-MoE tuning knobs
    (tile sizes, group size, warp and stage counts) -- confirm against the
    tuning code. Nothing in the visible part of this file uses this type.
    """

    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


34
35
36
37
38
39
40
41
42
43
44
def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Return the average latency of a single MoE permute call, in microseconds.

    A batch of permute calls is captured into a CUDA graph; the graph is
    replayed ``num_iters`` times and each replay is timed with CUDA events.

    Args:
        num_tokens: Number of rows in the random hidden-states tensor.
        num_experts: Number of experts (width of the random gating logits).
        hidden_size: Number of columns in the random hidden-states tensor.
        topk: Number of experts selected per token by ``fused_topk``.
        dtype: dtype of the random hidden states.
        use_fp8_w8a8: When true, the hidden states are passed through
            ``_fp8_quantize`` and a 128-row alignment block size is used.
        use_int8_w8a16: Accepted for signature parity with the caller; this
            function does not read it.
        num_iters: Number of timed graph replays.
        use_customized_permute: When true, benchmark ``moe_permute``;
            otherwise benchmark ``_moe_permute``.

    Returns:
        Mean time per permute call in microseconds.
    """
    # Number of permute calls captured in (and executed by) one graph replay.
    calls_per_replay = 10

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    # One set of gating logits per timed iteration.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    _, topk_ids, _ = fused_topk(qhidden_states, input_gating, topk, False)

    def prepare(i: int):
        # Load the i-th gating logits and recompute the routing. The result is
        # copied *in place* into `topk_ids` because the captured CUDA graph
        # reads that exact tensor; without this refresh every replay would
        # reuse the routing computed before capture.
        input_gating.copy_(gating_output[i])
        _, fresh_topk_ids, _ = fused_topk(qhidden_states, input_gating, topk, False)
        topk_ids.copy_(fresh_topk_ids)

    def run():
        # Only the kernel launches matter for timing; outputs are discarded.
        if use_customized_permute:
            moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
        else:
            _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture `calls_per_replay` invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_replay):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        # Input refresh happens outside the timed region.
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() reports milliseconds; convert the per-call mean to us.
    avg = sum(latencies) / (num_iters * calls_per_replay) * 1000  # us
    graph.reset()
    return avg


126
127
128
129
130
131
132
133
134
135
136
def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    """Return the average latency of a single MoE unpermute call, in microseconds.

    The permute step is executed once, outside the timed region, to produce
    the inputs for unpermute. A batch of unpermute calls is then captured
    into a CUDA graph, which is replayed ``num_iters`` times and timed with
    CUDA events.

    Args:
        num_tokens: Number of rows in the random hidden-states tensor.
        num_experts: Number of experts (width of the random gating logits).
        hidden_size: Number of columns in the random hidden-states tensor.
        topk: Number of experts selected per token by ``fused_topk``.
        dtype: dtype of the random hidden states and of the unpermute input.
        use_fp8_w8a8: When true, the hidden states are passed through
            ``_fp8_quantize`` and a 128-row alignment block size is used.
        use_int8_w8a16: Accepted for signature parity with the caller; this
            function does not read it.
        num_iters: Number of timed graph replays.
        use_customized_permute: When true, benchmark ``moe_unpermute`` (fed
            by ``moe_permute``); otherwise benchmark
            ``_moe_unpermute_and_reduce`` (fed by ``_moe_permute``).

    Returns:
        Mean time per unpermute call in microseconds.
    """
    # Number of unpermute calls captured in (and executed by) one graph replay.
    calls_per_replay = 10

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # Single preallocated destination shared by both code paths, so neither
    # path allocates inside the captured region.
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, _ = fused_topk(qhidden_states, input_gating, topk, False)

    def prepare():
        # Run the matching permute once to build the unpermute inputs.
        if use_customized_permute:
            (
                permuted_hidden_states,
                _,
                first_token_off,
                inv_perm_idx,
                m_indices,
            ) = moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_hidden_states.to(dtype),
                first_token_off,
                inv_perm_idx,
                m_indices,
            )

        (
            permuted_qhidden_states,
            a1q_scale,
            sorted_token_ids,
            expert_ids,
            inv_perm,
        ) = _moe_permute(
            qhidden_states, None, topk_ids, num_experts, None, align_block_size
        )
        # convert to fp16/bf16 as gemm output
        return (
            permuted_qhidden_states.to(dtype),
            a1q_scale,
            sorted_token_ids,
            expert_ids,
            inv_perm,
        )

    def run(permute_result: tuple):
        if use_customized_permute:
            permuted_hidden_states, first_token_off, inv_perm_idx, _ = permute_result
            moe_unpermute(
                output_hidden_states,
                permuted_hidden_states,
                topk_weights,
                inv_perm_idx,
                first_token_off,
            )
        else:
            permuted_hidden_states, _, _, _, inv_perm = permute_result
            _moe_unpermute_and_reduce(
                output_hidden_states,
                permuted_hidden_states,
                inv_perm,
                topk_weights,
                True,
            )

    # JIT compilation & warmup
    permute_result = prepare()
    run(permute_result)
    torch.cuda.synchronize()

    # Capture `calls_per_replay` invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_replay):
            run(permute_result)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() reports milliseconds; convert the per-call mean to us.
    avg = sum(latencies) / (num_iters * calls_per_replay) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor (one GPU each) that runs the permute/unpermute benchmarks."""

    def __init__(self, seed: int) -> None:
        """Make CUDA the default torch device and seed the platform RNGs."""
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[float, float]:
        """Benchmark permute and unpermute for one problem size.

        The RNGs are re-seeded with the worker's seed before measuring, and
        both benchmarks run 100 timed iterations.

        Returns:
            ``(permute_time, unpermute_time)`` as returned by
            ``benchmark_permute`` and ``benchmark_unpermute``.
        """
        current_platform.seed_everything(self.seed)

        # Both benchmarks take the same leading positional arguments.
        problem = (
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
        )

        permute_time = benchmark_permute(
            *problem,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        unpermute_time = benchmark_unpermute(
            *problem,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):
    """Look up ``weight_block_size`` in ``config.quantization_config``.

    Returns ``default_value`` when ``config`` has no ``quantization_config``
    attribute, when that attribute is not a dict, or when the dict has no
    ``"weight_block_size"`` key.
    """
    quant_cfg = getattr(config, "quantization_config", None)
    if not isinstance(quant_cfg, dict):
        # Missing attribute (None) and non-dict values both fall back.
        return default_value
    return quant_cfg.get("weight_block_size", default_value)


def main(args: argparse.Namespace):
    """Benchmark MoE permute/unpermute for a model across batch sizes.

    Loads the model config named by ``args.model``, derives the expert
    count, top-k and hidden size from it, starts one ``BenchmarkWorker`` per
    GPU that Ray reports as available, distributes the batch sizes over the
    workers round-robin, and prints the timings for each batch size.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif arch in (
        "JambaForCausalLM",
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif arch in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    # A cluster without GPUs has no "GPU" entry at all; fail with a clear
    # message instead of a bare KeyError (or a modulo-by-zero further down).
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        raise RuntimeError(
            "No GPU is available in the Ray cluster; "
            "this benchmark needs at least one."
        )
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Submit every task round-robin across the workers, then block until
        # all results are in; results keep the order of `inputs`.
        pending = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(pending)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
                use_customized_permute,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    cli = FlexibleArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16"],
        default="auto",
    )
    cli.add_argument("--use-customized-permute", action="store_true")
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--batch-size", type=int, required=False)
    cli.add_argument("--trust-remote-code", action="store_true")

    args = cli.parse_args()
    main(args)