# benchmark_moe.py 27.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import gc
6
import json
7
import os
8
import time
9
from contextlib import nullcontext
10
from datetime import datetime
11
from itertools import product
12
from typing import Any, TypedDict
13
14
15
16
17

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

18
19
20
21
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    _get_config_dtype_str,
)
22
from vllm.model_executor.layers.fused_moe.fused_moe import *
23
from vllm.platforms import current_platform
24
from vllm.transformers_utils.config import get_config
25
from vllm.triton_utils import triton
26
from vllm.utils.argparse_utils import FlexibleArgumentParser
27

28
# FP8 dtype reported by the current platform; benchmark_config casts the
# w1/w2 weights to this dtype when use_fp8_w8a8 is set.
FP8_DTYPE = current_platform.fp8_dtype()
29

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# How many benchmarked configs BenchmarkWorker.tune processes between calls
# to clear_triton_cache() (default: 50). Override with the environment
# variable named below; a value of 0 (or any non-positive value) disables the
# periodic clearing.
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(os.getenv(_CACHE_CLEAR_INTERVAL_ENV, default="50"))


def clear_triton_cache():
    """Release cached memory between tuning runs to lower the risk of OOM.

    Runs Python garbage collection, empties the CUDA caching allocator when
    CUDA is available, and calls ``triton.runtime.cache.clear()`` if the
    installed Triton exposes that attribute. Errors from the Triton step are
    not raised: a missing Triton or missing API is ignored, anything else is
    reported with a printed warning.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    missing = object()
    try:
        import triton

        # Walk triton.runtime.cache.clear defensively; any missing link means
        # this Triton build has no such API and there is nothing to call.
        runtime_mod = getattr(triton, "runtime", None)
        cache_mod = getattr(runtime_mod, "cache", None)
        clear_fn = getattr(cache_mod, "clear", missing)
        if clear_fn is not missing:
            clear_fn()
    except (ImportError, AttributeError):
        # Triton not installed, or its version lacks the expected cache API.
        pass
    except Exception as e:
        print(f"Warning: Failed to clear Triton cache: {e}")

    # Collect again in case clearing caches dropped the last references.
    gc.collect()

70

71
def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator.

    Args:
        numerator: Quantity being sharded (e.g. the number of experts or the
            intermediate size).
        denominator: Number of shards (the tensor-parallel size).
        text: Human-readable name of ``numerator`` used in the error message.

    Raises:
        ValueError: If ``numerator`` is not a multiple of ``denominator``.
    """
    # Explicit check rather than ``assert`` so the validation still runs when
    # Python is started with ``-O``.
    if numerator % denominator != 0:
        raise ValueError(f"{text} {numerator} is not divisible by tp {denominator}.")


78
79
80
81
82
83
84
85
86
# Triton launch configuration for the fused-MoE kernel: tile sizes along the
# M/N/K dimensions, the GROUP_SIZE_M tile-grouping factor, and the warp and
# pipeline-stage counts. All keys are required.
BenchmarkConfig = TypedDict(
    "BenchmarkConfig",
    {
        "BLOCK_SIZE_M": int,
        "BLOCK_SIZE_N": int,
        "BLOCK_SIZE_K": int,
        "GROUP_SIZE_M": int,
        "num_warps": int,
        "num_stages": int,
    },
)


87
88
89
90
91
92
93
94
95
96
97
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    """Time one fused-MoE kernel configuration on the current CUDA device.

    Random activations, expert weights, gating logits and (for quantized
    modes) scales are generated. A fused top-k + fused-experts call is
    captured ``calls_per_graph`` times into one CUDA graph. The graph is then
    replayed ``num_iters`` times, with a fresh gating tensor copied in before
    each replay.

    Args:
        config: Triton tile configuration forced via ``override_config``.
        num_tokens: Number of tokens (rows of the activation tensor).
        num_experts: Number of experts whose weights are generated.
        shard_intermediate_size: Second dim of w1; w2's last dim is half of it.
        hidden_size: Model hidden dimension.
        topk: Experts selected per token.
        dtype: Activation dtype (also the weight dtype when unquantized).
        use_fp8_w8a8: Cast weights to ``FP8_DTYPE`` and supply FP8 scales.
        use_int8_w8a16: Use int8 weights with float32 scales.
        num_iters: Number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for block-wise FP8 weight
            scales, or None for per-expert scales. Replaced by ``[128, 128]``
            when ``use_deep_gemm`` is set.
        use_deep_gemm: Passed to ``fused_experts`` as ``allow_deep_gemm``.

    Returns:
        Average latency of one fused-MoE call, in microseconds.
    """
    from vllm.model_executor.layers.fused_moe import override_config

    # Fused-MoE calls captured per CUDA graph; each replay's latency is divided
    # by this to obtain a per-call time.
    calls_per_graph = 10

    # FP8 weights are drawn in fp16 first and cast to FP8_DTYPE below.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1_shape = (num_experts, shard_intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, shard_intermediate_size // 2)
    if use_int8_w8a16:
        w1 = torch.randint(-127, 127, w1_shape, dtype=torch.int8)
        w2 = torch.randint(-127, 127, w2_shape, dtype=torch.int8)
    else:
        w1 = torch.randn(*w1_shape, dtype=init_dtype)
        w2 = torch.randn(*w2_shape, dtype=init_dtype)
    # One gating tensor per timed replay.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            # One scale per (block_n x block_k) tile of each expert's weight.
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    if use_fp8_w8a8:
        # NOTE(review): weights are cast to FP8_DTYPE above while the quant
        # config is tagged torch.float8_e4m3fn; presumably this acts as a
        # "fp8" marker for FusedMoEQuantConfig — confirm on fnuz platforms.
        quant_dtype = torch.float8_e4m3fn
    elif use_int8_w8a16:
        quant_dtype = torch.int8
    else:
        quant_dtype = None

    # Built once: every input it depends on is fixed for the whole benchmark,
    # so there is no need to rebuild it inside each captured call.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=quant_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_quant_shape,
    )

    # Static buffer read by the captured graph; refreshed before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        with override_config(config):
            topk_weights, topk_ids, _ = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=True,
                quant_config=quant_config,
                allow_deep_gemm=use_deep_gemm,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture several invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        # elapsed_time is in milliseconds.
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * calls_per_graph) * 1000  # us
    graph.reset()
    return avg


247
248
249
250
251
252
253
254
def get_rocm_tuning_space(use_fp16):
    """Return the ROCm Triton search space as ``{config_key: candidate_values}``.

    When ``use_fp16`` is truthy the space additionally contains
    ``matrix_instr_nonkdim`` and ``kpack``. Otherwise ``BLOCK_SIZE_K=16`` is
    dropped because it is not supported for fp8. Key order is significant:
    callers take the Cartesian product in this order.
    """
    mn_sizes = [16, 32, 64, 128, 256]
    # BLOCK_K=16 not supported for fp8
    k_sizes = [16, 32, 64, 128, 256] if use_fp16 else [32, 64, 128, 256]

    space = {
        "BLOCK_SIZE_M": mn_sizes,
        "BLOCK_SIZE_N": mn_sizes,
        "BLOCK_SIZE_K": k_sizes,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0, 1, 2, 4],
    }
    if use_fp16:
        space.update(matrix_instr_nonkdim=[16, 32], kpack=[1, 2])
    return space


275
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    """Enumerate every candidate Triton config for tuning.

    Args:
        use_fp16: True for unquantized (16-bit) tuning; selects the ROCm fp16
            search space and disables the block-quantization filter.
        block_quant_shape: ``[block_n, block_k]`` of the FP8 weight blocks, or
            None when weights are not block-quantized.

    Returns:
        Config dicts in Cartesian-product order of the parameter ranges.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys, values = zip(*param_ranges.items())
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, config_values)) for config_values in product(*values)
    ]

    # Remove configs that are not compatible with fp8 block quantization:
    # BLOCK_SIZE_K must be a multiple of block_k and BLOCK_SIZE_N a multiple
    # of block_n. A single filtering pass keeps the original order without the
    # quadratic cost of repeated list.remove() on thousands of configs.
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        configs = [
            config
            for config in configs
            if config["BLOCK_SIZE_K"] % block_k == 0
            and config["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs


319
320
321
def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    """Prune ``search_space`` for both MoE GEMMs and merge the survivors.

    Each routed token contributes ``topk`` rows, so both GEMMs use
    ``M = num_tokens * topk``. The first GEMM has
    ``(N, K) = (shard_intermediate_size, hidden_size)``. The second has
    ``(N, K) = (hidden_size, shard_intermediate_size // 2)``. A config is kept
    if it survives pruning for either GEMM.
    """
    routed_rows = num_tokens * topk
    gemm_nk = (
        (shard_intermediate_size, hidden_size),
        (hidden_size, shard_intermediate_size // 2),
    )
    first_gemm_space, second_gemm_space = (
        prune_rocm_configs(routed_rows, n_dim, k_dim, search_space, is_fp16)
        for n_dim, k_dim in gemm_nk
    )
    return merge_unique_dicts(first_gemm_space, second_gemm_space)


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop ROCm Triton configs that are unsuitable for an M x N x K GEMM.

    Args:
        M, N, K: GEMM problem dimensions; the LDS estimate treats the operands
            as (M x K) and (K x N).
        configs: Candidate config dicts. When ``is_fp16`` is True every config
            must contain ``matrix_instr_nonkdim``.
        is_fp16: True for 2-byte operands, False for 1-byte (e.g. fp8).

    Returns:
        The configs that pass every heuristic, in their original order.
    """
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    # MFMA instruction size assumed for this problem; always 16 or 32.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True when the GEMM has a narrow output (M or N below 64) and a
    long reduction (K above 1024), i.e. when split-K is worth trying."""
    long_reduction = SIZE_K > 1024
    narrow_output = SIZE_M < 64 or SIZE_N < 64
    return narrow_output and long_reduction


def merge_unique_dicts(list1, list2):
    """Concatenate two sequences of dicts, keeping only the first of any duplicates.

    Duplicates are defined by dict equality, as with ``dict in list``. The
    order of first occurrence is preserved and the inputs are not modified.
    Dicts whose values are all hashable are tracked in a set of fingerprints,
    so the common case is linear rather than quadratic in the total length.
    """
    result = []
    seen = set()
    for dictionary in (*list1, *list2):
        try:
            # Equal dicts have equal item sets regardless of key order.
            fingerprint = frozenset(dictionary.items())
        except TypeError:
            # Unhashable values: fall back to a linear equality scan.
            if dictionary not in result:
                result.append(dictionary)
            continue
        if fingerprint not in seen:
            seen.add(fingerprint)
            result.append(dictionary)
    return result


430
431
432
433
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor, one GPU each, that benchmarks or tunes fused-MoE configs."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        """Benchmark the config vLLM would currently pick for ``num_tokens``.

        Uses the tuned config from ``get_moe_configs`` whose batch-size key is
        closest to ``num_tokens``, falling back to ``get_default_config`` when
        no tuned configs exist.

        Returns:
            ``(config, kernel_time_us)``.
        """
        current_platform.seed_everything(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            # Tuned entry whose batch-size key is closest to num_tokens.
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        """Return the fastest config in ``search_space`` for ``num_tokens``.

        On ROCm the search space is first pruned with
        ``prune_rocm_search_space``. Configs raising Triton's
        ``OutOfResources`` are skipped.

        Raises:
            RuntimeError: If no config in the search space could be benchmarked.
        """
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        # On ROCm, select this worker's GPU explicitly unless
        # ROCR_VISIBLE_DEVICES already names exactly that device.
        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for idx, config in enumerate(tqdm(search_space)):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    # Fall through so the periodic cache clear below still
                    # runs for this index.
                    pass
                else:
                    if kernel_time < best_time:
                        best_time = kernel_time
                        best_config = config

                # Periodically clear Triton JIT cache to prevent OOM
                # This is especially important for large models with many experts
                if (
                    TRITON_CACHE_CLEAR_INTERVAL > 0
                    and idx > 0
                    and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
                ):
                    clear_triton_cache()

        # Final cleanup after tuning completes
        clear_triton_cache()

        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        if best_config is None:
            raise RuntimeError(
                f"No config could be benchmarked for batch_size={num_tokens}; "
                "the search space was empty or every candidate failed."
            )
        return best_config


569
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in canonical order.

    The six core keys always come first (a missing one raises KeyError). They
    are followed by the ROCm-only keys ``waves_per_eu``,
    ``matrix_instr_nonkdim`` and ``kpack``, each only when present. Any other
    keys are dropped.
    """
    core_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    optional_keys = ("waves_per_eu", "matrix_instr_nonkdim", "kpack")

    ordered = {key: config[key] for key in core_keys}
    ordered.update({key: config[key] for key in optional_keys if key in config})
    return ordered


589
590
591
592
593
594
595
596
597
def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    """Write the tuned ``configs`` (keyed by batch size) to a JSON file.

    The file name comes from ``get_config_file_name`` and is placed in
    ``save_dir``, which is created if needed. The JSON object also records
    ``triton.__version__`` under ``"triton_version"``.
    """
    dtype_str = _get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    base_name = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )

    os.makedirs(save_dir, exist_ok=True)
    target_path = os.path.join(save_dir, base_name)
    print(f"Writing best config to {target_path}...")
    with open(target_path, "w") as out:
        payload = {"triton_version": triton.__version__, **configs}
        json.dump(payload, out, indent=4)
        out.write("\n")


618
def get_weight_block_size_safety(config, default_value=None):
    """Return ``config.quantization_config["weight_block_size"]`` if available.

    Falls back to ``default_value`` when the config has no
    ``quantization_config`` attribute, when that attribute is not a dict, or
    when the dict lacks ``"weight_block_size"``.
    """
    quant_cfg = getattr(config, "quantization_config", None)
    if not isinstance(quant_cfg, dict):
        return default_value
    return quant_cfg.get("weight_block_size", default_value)


625
626
def _get_moe_dims(config):
    """Read the MoE dimensions from a model config, by architecture.

    Returns:
        ``(config, num_experts, topk, intermediate_size, hidden_size)``. The
        returned ``config`` is the input config, except in the default
        (Mixtral / Llama4) branch, which returns ``config.get_text_config()``.
        ``main`` uses it for the dtype and weight-block-size lookups.
    """
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        ffn = config.ffn_config
        return (
            config,
            ffn.moe_num_experts,
            ffn.moe_top_k,
            ffn.ffn_hidden_size,
            config.hidden_size,
        )
    if arch == "JambaForCausalLM":
        return (
            config,
            config.num_experts,
            config.num_experts_per_tok,
            config.intermediate_size,
            config.hidden_size,
        )
    if arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "Glm4MoeForCausalLM",
        "NemotronHForCausalLM",
    ):
        return (
            config,
            config.n_routed_experts,
            config.num_experts_per_tok,
            config.moe_intermediate_size,
            config.hidden_size,
        )
    if arch in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        return (
            config,
            config.num_experts,
            config.num_experts_per_tok,
            config.moe_intermediate_size,
            config.hidden_size,
        )
    if arch == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        return (
            config,
            text_config.num_experts,
            text_config.num_experts_per_tok,
            text_config.moe_intermediate_size,
            text_config.hidden_size,
        )
    # Exact comparison: ``arch in ("HunYuanMoEV1ForCausalLM")`` (no trailing
    # comma) would be a substring test against a plain string.
    if arch == "HunYuanMoEV1ForCausalLM":
        return (
            config,
            config.num_experts,
            config.moe_topk[0],
            config.moe_intermediate_size[0],
            config.hidden_size,
        )
    if arch == "Qwen3OmniMoeForConditionalGeneration":
        text_config = config.thinker_config.text_config
        return (
            config,
            text_config.num_experts,
            text_config.num_experts_per_tok,
            text_config.moe_intermediate_size,
            text_config.hidden_size,
        )
    # Default: Mixtral. get_text_config() adds support for llama4.
    text_config = config.get_text_config()
    return (
        text_config,
        text_config.num_local_experts,
        text_config.num_experts_per_tok,
        text_config.intermediate_size,
        text_config.hidden_size,
    )


def main(args: argparse.Namespace):
    """Benchmark (or, with ``--tune``, tune and save) fused-MoE configs.

    Batch sizes are distributed round-robin over one Ray worker per available
    GPU.

    Raises:
        ValueError: If ``--tune`` is combined with ``--use-deep-gemm``, or if
            the experts/intermediate size cannot be split by ``--tp-size``.
    """
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    config, E, topk, intermediate_size, hidden_size = _get_moe_dims(config)

    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        # Expert parallel: each rank holds E / tp experts at full width.
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)
    # Reject the unsupported combination before starting Ray.
    if args.tune and use_deep_gemm:
        raise ValueError(
            "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
            "kernels. Please remove the flag."
        )

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        # NOTE(review): ``logger`` is not defined in this file; presumably it
        # comes from the ``fused_moe`` star import — confirm.
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        """Call ``method`` on the workers round-robin and gather the results."""
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name or path whose config supplies the MoE shapes.",
    )
    parser.add_argument(
        "--tp-size",
        "-tp",
        "--tensor-parallel-size",
        type=int,
        default=2,
        help="Tensor-parallel size used to shard the intermediate size "
        "(or the experts, with --enable-expert-parallel).",
    )
    parser.add_argument(
        "--enable-expert-parallel",
        "-enable-ep",
        action="store_true",
        help="Shard the experts across --tp-size instead of the intermediate size.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16"],
        default="auto",
        help="Quantization mode; 'auto' benchmarks unquantized weights.",
    )
    parser.add_argument(
        "--use-deep-gemm",
        action="store_true",
        help="Allow the DeepGEMM path when benchmarking (not supported with --tune).",
    )
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Random seed for every benchmark worker."
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        required=False,
        help="Batch sizes (token counts) to run; defaults to a built-in sweep.",
    )
    parser.add_argument(
        "--tune",
        action="store_true",
        help="Search for and save the best config per batch size instead of "
        "benchmarking the current one.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code when loading the model config.",
    )
    parser.add_argument(
        "--model-prefix",
        type=str,
        required=False,
        help="Attribute of the loaded config to use as the model config.",
    )
    args = parser.parse_args()

    main(args)