# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import argparse
5
import json
6
import os
7
import time
8
from contextlib import nullcontext
9
from datetime import datetime
10
from itertools import product
11
from typing import Any, TypedDict
12
13
14
15
16

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

17
18
19
20
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    _get_config_dtype_str,
)
21
from vllm.model_executor.layers.fused_moe.fused_moe import *
22
from vllm.platforms import current_platform
23
from vllm.transformers_utils.config import get_config
24
from vllm.triton_utils import triton
25
from vllm.utils.argparse_utils import FlexibleArgumentParser
26

27
# FP8 dtype reported by the current platform; the fp8_w8a8 benchmark weights
# are cast to this dtype before the kernels run.
FP8_DTYPE = current_platform.fp8_dtype()
28
29


30
def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator.

    Args:
        numerator: Value that is about to be sharded.
        denominator: Parallel size the value is split across.
        text: Human-readable name of ``numerator`` used in the error message.

    Raises:
        AssertionError: If ``numerator`` is not a multiple of ``denominator``.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # is still enforced when Python runs with ``-O`` (which strips asserts).
    if numerator % denominator != 0:
        raise AssertionError(
            "{} {} is not divisible by tp {}.".format(text, numerator, denominator)
        )


37
38
39
40
41
42
43
44
45
class BenchmarkConfig(TypedDict):
    """Tuning parameters of one candidate fused-MoE kernel configuration.

    These are the keys every search-space entry carries. Search spaces built
    by ``get_rocm_tuning_space`` additionally carry ``waves_per_eu`` and, for
    fp16, ``matrix_instr_nonkdim`` and ``kpack``; those keys are not declared
    here.
    """

    # Tile sizes along the M, N and K dimensions.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    # Grouping factor along M.
    GROUP_SIZE_M: int
    # Kernel launch parameters.
    num_warps: int
    num_stages: int


46
47
48
49
50
51
52
53
54
55
56
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] | None = None,
    use_deep_gemm: bool = False,
) -> float:
    """Time the fused-MoE kernels under one tuning configuration.

    Random activations, expert weights and (when quantized) scales are
    created, ``fused_topk`` + ``fused_experts`` are captured in a CUDA graph
    and the graph is replayed ``num_iters`` times with fresh gating logits.

    Args:
        config: Kernel configuration applied through ``override_config``.
        num_tokens: Number of rows of the activation tensor.
        num_experts: Number of experts (first dimension of the weights).
        shard_intermediate_size: Output dimension of ``w1``; ``w2`` uses half
            of it as its inner dimension.
        hidden_size: Hidden dimension of the activations.
        topk: Number of experts selected per token.
        dtype: Activation dtype (and weight dtype when not quantized).
        use_fp8_w8a8: Benchmark with fp8 weights and activation scales.
        use_int8_w8a16: Benchmark with int8 weights.
        num_iters: Number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for fp8 block quantization,
            or ``None`` for per-expert scales.
        use_deep_gemm: Allow DeepGEMM; forces a ``[128, 128]`` block shape.

    Returns:
        Average latency in microseconds of a single ``run()`` invocation.
    """
    # Each graph replay executes this many back-to-back invocations.
    invocations_per_graph = 10

    # NOTE: the order of the random tensor creations below is kept stable so
    # that seeded runs keep producing the same inputs.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    # One set of gating logits per timed iteration.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            # Ceil-divide each weight dimension by the quantization block size
            # to get the number of scale tiles.
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # Static buffer read by the captured graph; refreshed before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        if use_fp8_w8a8:
            quant_dtype = torch.float8_e4m3fn
        elif use_int8_w8a16:
            quant_dtype = torch.int8
        else:
            quant_dtype = None

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=quant_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_quant_shape,
        )

        with override_config(config):
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=True,
                quant_config=quant_config,
                allow_deep_gemm=use_deep_gemm,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture several invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(invocations_per_graph):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is in milliseconds; convert the per-invocation mean to us.
    avg = sum(latencies) / (num_iters * invocations_per_graph) * 1000  # us
    graph.reset()
    return avg


206
207
208
209
210
211
212
213
def get_rocm_tuning_space(use_fp16):
    """Return the per-parameter candidate values used for tuning on ROCm.

    Args:
        use_fp16: True when tuning unquantized kernels. In that case
            ``BLOCK_SIZE_K`` may be 16 and the ``matrix_instr_nonkdim`` and
            ``kpack`` parameters are part of the space.

    Returns:
        Dict mapping each tuning parameter name to its list of candidates.
    """
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0, 1, 2, 4]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    # These two parameters are only tuned for fp16; an empty range would
    # otherwise make the cartesian product of the space empty.
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges


234
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    """Build the list of kernel configurations to try during tuning.

    Args:
        use_fp16: True when tuning unquantized kernels.
        block_quant_shape: ``[block_n, block_k]`` of fp8 block quantization,
            or ``None`` when block quantization is not used.

    Returns:
        Every combination of the per-parameter candidate values, minus the
        combinations that are incompatible with ``block_quant_shape``.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys = list(param_ranges)
    configs = [
        dict(zip(keys, config_values))
        for config_values in product(*param_ranges.values())
    ]

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        # A single filtering pass replaces repeated list.remove() calls.
        configs = [
            config
            for config in configs
            if config["BLOCK_SIZE_K"] % block_k == 0
            and config["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs


278
279
280
def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    """Prune the ROCm search space for both GEMMs of the MoE layer.

    The space is pruned once for the first GEMM (N=shard_intermediate_size,
    K=hidden_size) and once for the second (N=hidden_size,
    K=shard_intermediate_size // 2), both with M=num_tokens * topk. A config
    is kept when it survives pruning for at least one of the two GEMMs.

    Returns:
        Deduplicated list of the surviving configurations.
    """
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(
        num_tokens * topk, N1, K1, search_space, is_fp16
    )
    pruned_space_2 = prune_rocm_configs(
        num_tokens * topk, N2, K2, search_space, is_fp16
    )
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop configurations that are not worth benchmarking for an MxNxK GEMM.

    Args:
        M, N, K: GEMM problem sizes.
        configs: Candidate configuration dicts. Each needs ``BLOCK_SIZE_M``,
            ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``GROUP_SIZE_M`` and
            ``num_warps``; ``matrix_instr_nonkdim`` is read when ``is_fp16``
            and ``SPLIT_K`` is optional (defaults to 1).
        is_fp16: True for 2-byte operands, False for 1-byte operands.

    Returns:
        The configurations that pass every pruning rule, in input order.
    """
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        # NOTE: mfma is only ever 16 or 32 here, so this rule never fires; it
        # is retained from the upstream tuning script.
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Tell whether split-K is worthwhile: a deep K with a narrow M or N."""
    narrow_output = SIZE_M < 64 or SIZE_N < 64
    deep_reduction = SIZE_K > 1024
    return narrow_output and deep_reduction


def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, keeping only the first of equal dicts."""
    merged = []
    for source in (list1, list2):
        for entry in source:
            # Dicts are unhashable, so duplicates are found by equality.
            if entry in merged:
                continue
            merged.append(entry)
    return merged


389
390
391
392
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor pinned to one GPU that benchmarks or tunes fused-MoE kernels."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] | None = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        """Benchmark the configuration that would be picked for ``num_tokens``.

        A stored config set is looked up with ``get_moe_configs``; when none
        exists, ``get_default_config`` is used instead.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config``.
        """
        current_platform.seed_everything(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            # Use the stored entry whose batch size is closest to num_tokens.
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        """Benchmark every config in ``search_space`` and return the fastest.

        On ROCm the search space is pruned first. Configs that raise Triton's
        ``OutOfResources`` are skipped.

        Raises:
            AssertionError: If no configuration could be benchmarked.
        """
        best_config = None
        best_time = float("inf")
        need_device_guard = False
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )
            # Select the device explicitly unless ROCR_VISIBLE_DEVICES already
            # names exactly the device assigned to this worker.
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        # The timestamp prefix previously lacked its opening bracket.
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config


515
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in a canonical order.

    The six core tuning keys always come first; ``waves_per_eu``,
    ``matrix_instr_nonkdim`` and ``kpack`` follow in that order when present.
    Any other key is dropped.

    Raises:
        KeyError: If one of the six core keys is missing.
    """
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    optional_keys = ("waves_per_eu", "matrix_instr_nonkdim", "kpack")

    ordered = {key: config[key] for key in required_keys}
    ordered.update({key: config[key] for key in optional_keys if key in config})
    return ordered


535
536
537
538
539
540
541
542
543
def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    """Write the tuned configurations to a JSON file inside ``save_dir``.

    The file name comes from ``get_config_file_name``. The JSON object holds
    a ``"triton_version"`` entry followed by one entry per key of ``configs``.
    ``save_dir`` is created when it does not exist. ``hidden_size`` and
    ``topk`` are accepted but not used here.
    """
    dtype_str = _get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")


564
def get_weight_block_size_safety(config, default_value=None):
    """Read ``weight_block_size`` from ``config.quantization_config``.

    Returns:
        The ``"weight_block_size"`` entry when ``quantization_config`` is a
        dict that contains it; ``default_value`` otherwise (attribute missing,
        not a dict, or key absent).
    """
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


571
572
def main(args: argparse.Namespace):
    """Benchmark or tune the fused-MoE kernels for the model in ``args``.

    The MoE dimensions are read from the model config according to its first
    architecture name, sharded for tensor or expert parallelism, and then
    distributed over one Ray worker per available GPU. With ``--tune`` the
    best config per batch size is saved via ``save_configs``; otherwise the
    selected config and its kernel time are printed per batch size.

    Raises:
        ValueError: If ``--tune`` is combined with ``--use-deep-gemm``.
    """
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    architecture = config.architectures[0]
    if architecture == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        hidden_size = config.hidden_size
    elif architecture == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    elif architecture in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "Glm4MoeForCausalLM",
        "NemotronHForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif architecture in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif architecture == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    # Compare with ``==``: ``in ("...")`` on a parenthesized string is a
    # substring test, not tuple membership.
    elif architecture == "HunYuanMoEV1ForCausalLM":
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
        hidden_size = config.hidden_size
    elif architecture == "Qwen3OmniMoeForConditionalGeneration":
        thinker_text_config = config.thinker_config.text_config
        E = thinker_text_config.num_experts
        topk = thinker_text_config.num_experts_per_tok
        intermediate_size = thinker_text_config.moe_intermediate_size
        hidden_size = thinker_text_config.hidden_size
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size

    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        # Expert parallelism shards the experts, not the intermediate size.
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)
    # Fail fast, before any Ray worker is started.
    if args.tune and use_deep_gemm:
        raise ValueError(
            "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
            "kernels. Please remove the flag."
        )

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Assign the work items to the workers round-robin and wait for all.
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() falls back to its built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    args = parser.parse_args()

    main(args)