# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute,
    moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import set_random_seed

# FP8 dtype of the current platform, as reported by current_platform.fp8_dtype().
# NOTE(review): not referenced by the code in this file.
FP8_DTYPE = current_platform.fp8_dtype()


# Typed mapping of kernel tuning parameters: BLOCK_SIZE_M/N/K, GROUP_SIZE_M,
# num_warps and num_stages, all integers and all required.
# NOTE(review): not referenced by the code in this file.
BenchmarkConfig = TypedDict(
    "BenchmarkConfig",
    {
        "BLOCK_SIZE_M": int,
        "BLOCK_SIZE_N": int,
        "BLOCK_SIZE_K": int,
        "GROUP_SIZE_M": int,
        "num_warps": int,
        "num_stages": int,
    },
)


def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    """Time ``moe_permute`` under CUDA-graph replay.

    Args:
        num_tokens: Number of token rows in the activation matrix.
        num_experts: Number of routed experts.
        hidden_size: Width of each activation row.
        topk: Experts selected per token by ``fused_topk``.
        dtype: Activation dtype before optional FP8 quantization.
        use_fp8_w8a8: Quantize activations with ``_fp8_quantize`` first.
        use_int8_w8a16: Accepted for signature parity; not used here.
        num_iters: Number of timed graph replays.

    Returns:
        Mean latency of a single ``moe_permute`` call, in microseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")
    # Each graph replay runs this many back-to-back calls.
    calls_per_graph = 10

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_fp8_w8a8:
        # NOTE(review): the quantization scale is not forwarded
        # (a1q_scale=None below), so scale permutation is not measured.
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    _, topk_ids, _ = fused_topk(qhidden_states, input_gating, topk, False)

    def prepare() -> None:
        # fused_topk runs outside the graph. New routing must therefore be
        # copied into the topk_ids tensor the graph captured; rebinding the
        # name would be invisible to graph.replay().
        gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
        _, new_topk_ids, _ = fused_topk(qhidden_states, gating, topk, False)
        topk_ids.copy_(new_topk_ids)

    def run() -> None:
        moe_permute(
            qhidden_states,
            a1q_scale=None,
            topk_ids=topk_ids,
            n_expert=num_experts,
            expert_map=None,
        )

    # JIT compilation & warmup
    run()
    torch.accelerator.synchronize()

    # Capture several invocations in one CUDA graph to amortize launch cost.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run()
    torch.accelerator.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.accelerator.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for _ in range(num_iters):
        # Routing refresh happens before the start event, so it is untimed.
        prepare()
        torch.accelerator.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))  # ms
    graph.reset()
    # Per-replay ms -> per-call us.
    return sum(latencies) / (num_iters * calls_per_graph) * 1000


def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    """Time ``moe_unpermute`` under CUDA-graph replay.

    ``moe_permute`` is run once, outside the timed region, to produce the
    permuted activations. These are cast back to ``dtype`` to stand in for an
    expert GEMM output.

    Args:
        num_tokens: Number of token rows in the activation matrix.
        num_experts: Number of routed experts.
        hidden_size: Width of each activation row.
        topk: Experts selected per token by ``fused_topk``.
        dtype: Activation dtype; also the dtype fed to ``moe_unpermute``.
        use_fp8_w8a8: Quantize activations with ``_fp8_quantize`` first.
        use_int8_w8a16: Accepted for signature parity; not used here.
        num_iters: Number of timed graph replays.

    Returns:
        Mean latency of a single ``moe_unpermute`` call, in microseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")
    # Each graph replay runs this many back-to-back calls.
    calls_per_graph = 10

    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_fp8_w8a8:
        qhidden_states, _ = _fp8_quantize(hidden_states, None, None)
    else:
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, _ = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (
            permuted_hidden_states,
            _,
            first_token_off,
            inv_perm_idx,
            _,
        ) = moe_permute(
            qhidden_states,
            a1q_scale=None,
            topk_ids=topk_ids,
            n_expert=num_experts,
            expert_map=None,
        )
        # convert to fp16/bf16 as gemm output
        return permuted_hidden_states.to(dtype), first_token_off, inv_perm_idx

    # One destination buffer shared by every captured call. Its contents are
    # never read here, so reusing it avoids allocating a fresh buffer from
    # the graph pool for each of the captured calls.
    output = torch.empty_like(hidden_states)

    def run(
        unpermute_inputs: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> None:
        permuted_hidden_states, first_token_off, inv_perm_idx = unpermute_inputs
        moe_unpermute(
            output,
            permuted_hidden_states,
            topk_weights,
            inv_perm_idx,
            first_token_off,
        )

    # JIT compilation & warmup
    unpermute_inputs = prepare()
    run(unpermute_inputs)
    torch.accelerator.synchronize()

    # Capture several invocations in one CUDA graph to amortize launch cost.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run(unpermute_inputs)
    torch.accelerator.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.accelerator.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for _ in range(num_iters):
        torch.accelerator.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))  # ms
    graph.reset()
    # Per-replay ms -> per-call us.
    return sum(latencies) / (num_iters * calls_per_graph) * 1000


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that runs the permute/unpermute benchmarks on one GPU."""

    def __init__(self, seed: int) -> None:
        # Tensors created without an explicit device land on the GPU.
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        # NOTE(review): device_id is not read by the code in this file.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> tuple[float, float]:
        """Return ``(permute_us, unpermute_us)`` for one problem size.

        The RNG is re-seeded first, so a job's inputs do not depend on which
        jobs this worker ran earlier.
        """
        set_random_seed(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config: Any, default_value: Any = None) -> Any:
    """Return ``config.quantization_config["weight_block_size"]`` when present.

    Falls back to ``default_value`` in three cases: ``config`` has no
    ``quantization_config`` attribute, that attribute is not a ``dict``, or
    the dict has no ``"weight_block_size"`` key.
    """
    quantization_config = getattr(config, "quantization_config", None)
    if not isinstance(quantization_config, dict):
        return default_value
    return quantization_config.get("weight_block_size", default_value)


def main(args: argparse.Namespace):
    """Benchmark MoE permute/unpermute for the expert layout of ``args.model``.

    The expert count, top-k and hidden size are read from the model's
    Hugging Face config. One benchmark job per batch size is distributed
    round-robin over one Ray worker per available GPU. Per-call latencies
    are printed in microseconds.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif arch in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "Glm4MoeForCausalLM",
        "Glm4MoeLiteForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif arch in ("JambaForCausalLM", "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4: unwrap to the text config first.
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    # ROCm runs are pinned to fp16; elsewhere use the config's dtype.
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    # .get() avoids a bare KeyError when Ray sees no GPU resource at all.
    # A zero count would otherwise surface later as a modulo-by-zero.
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus == 0:
        raise RuntimeError(
            "Ray reports no available GPUs; this benchmark needs at least one."
        )
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Hand jobs to workers round-robin, then block until all complete.
        futures = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(futures)

    outputs = _distribute(
        "benchmark",
        [
            (batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, use_int8_w8a16)
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model whose config supplies the expert count, top-k and hidden size.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16"],
        default="auto",
        help="'fp8_w8a8' quantizes activations to FP8 before permuting; "
        "'int8_w8a16' is accepted but currently unused by the benchmarks.",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Random seed used by every worker."
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        required=False,
        help="Benchmark only this token count (default: sweep 1..4096).",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Forwarded to AutoConfig.from_pretrained.",
    )
    args = parser.parse_args()

    main(args)