benchmark_moe.py 30.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import gc
6
import json
7
import os
8
import time
9
from contextlib import nullcontext
10
from datetime import datetime
11
from itertools import product
12
from typing import Any, TypedDict
13
14
15
16
17

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

18
from vllm.model_executor.layers.fused_moe import fused_topk
19
from vllm.model_executor.layers.fused_moe.config import (
20
21
    FusedMoEConfig,
    FusedMoEParallelConfig,
22
    FusedMoEQuantConfig,
23
    RoutingMethodType,
24
25
    _get_config_dtype_str,
)
26
from vllm.model_executor.layers.fused_moe.fused_moe import *
27
28
29
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
30
from vllm.transformers_utils.config import get_config
31
from vllm.triton_utils import triton
32
from vllm.utils.argparse_utils import FlexibleArgumentParser
33
from vllm.utils.torch_utils import set_random_seed
# FP8 dtype reported by the current platform. benchmark_config casts the
# benchmark weights to it when --dtype fp8_w8a8 is requested.
FP8_DTYPE: torch.dtype = current_platform.fp8_dtype()
# Number of benchmarked configs between Triton JIT / CUDA cache clears during
# tuning. Set the environment variable to 0 (or a negative value) to disable
# automatic cache clearing; an unset or blank value uses the default.
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
_DEFAULT_CACHE_CLEAR_INTERVAL = 50


def get_cache_clear_interval(
    env_var: str = _CACHE_CLEAR_INTERVAL_ENV,
    default: int = _DEFAULT_CACHE_CLEAR_INTERVAL,
) -> int:
    """Read the cache-clear interval from the environment variable ``env_var``.

    Args:
        env_var: Name of the environment variable to read.
        default: Value used when the variable is unset or blank.

    Returns:
        The parsed integer (surrounding whitespace is ignored).

    Raises:
        ValueError: If the variable is set to something that is not an
            integer; the message names the variable so the user can fix it.
    """
    raw = os.environ.get(env_var, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"{env_var} must be an integer number of configs, got {raw!r}"
        ) from None


TRITON_CACHE_CLEAR_INTERVAL = get_cache_clear_interval()


def clear_triton_cache():
    """Release memory that accumulates while sweeping many kernel configs.

    Runs Python garbage collection, empties the CUDA caching allocator when a
    GPU is present, and calls ``triton.runtime.cache.clear()`` if this Triton
    build exposes it. Used during tuning to avoid OOM with large models
    (many experts).
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        has_clear_api = (
            hasattr(triton, "runtime")
            and hasattr(triton.runtime, "cache")
            and hasattr(triton.runtime.cache, "clear")
        )
        if has_clear_api:
            triton.runtime.cache.clear()
    except (ImportError, AttributeError):
        # Triton is missing or lacks the expected cache API; clearing is
        # best-effort, so skip it silently.
        pass
    except Exception as e:
        print(f"Warning: Failed to clear Triton cache: {e}")

    # Collect once more now that cache entries may have been dropped.
    gc.collect()

def ensure_divisibility(numerator, denominator, text):
    """Ensure that ``numerator`` is divisible by ``denominator``.

    Args:
        numerator: Quantity being sharded (e.g. number of experts or the
            intermediate size).
        denominator: Parallel size the quantity is split across (tp size).
        text: Human-readable name of ``numerator`` used in the error message.

    Raises:
        ValueError: If ``denominator`` is not positive or does not evenly
            divide ``numerator``. Raised explicitly rather than via ``assert``
            so the check still runs under ``python -O``.
    """
    if denominator <= 0:
        raise ValueError(f"tp size must be a positive integer, got {denominator}.")
    if numerator % denominator != 0:
        raise ValueError(
            f"{text} {numerator} is not divisible by tp {denominator}."
        )


class _BenchmarkConfigBase(TypedDict):
    # Keys present in every fused-MoE Triton kernel config.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


class BenchmarkConfig(_BenchmarkConfigBase, total=False):
    """Triton fused-MoE kernel configuration benchmarked by this script.

    The six base keys are always required. The keys declared here are
    optional; they appear only in configs from the ROCm tuning space and are
    preserved by ``sort_config`` when present.
    """

    waves_per_eu: int
    matrix_instr_nonkdim: int
    kpack: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    """Measure the average latency of one fused-MoE invocation using ``config``.

    Random activations, weights, router logits and (when quantized) scales are
    created on the current default device. The MoE forward is then captured
    several times into one CUDA graph, and the graph is replayed ``num_iters``
    times with a fresh slice of router logits before each replay.

    Args:
        config: Triton kernel configuration applied through ``override_config``.
        num_tokens: Number of tokens (batch size) in the activation tensor.
        num_experts: Number of (local) experts.
        shard_intermediate_size: First dimension of each expert's ``w1``;
            ``w2`` uses half of it as its last dimension.
        hidden_size: Model hidden dimension.
        topk: Number of experts selected per token.
        dtype: Activation dtype (and weight dtype when unquantized).
        use_fp8_w8a8: Benchmark FP8 weights and activations.
        use_int8_w8a16: Benchmark INT8 weights.
        num_iters: Number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for block-wise FP8 scales,
            or ``None`` for per-expert scales.
        use_deep_gemm: Run through the DeepGEMM-capable modular kernel instead
            of ``fused_experts``; forces a ``[128, 128]`` block shape.

    Returns:
        Average latency of a single MoE invocation in microseconds.
    """
    from vllm.model_executor.layers.fused_moe import override_config

    # MoE invocations recorded into one CUDA graph. The replay time is divided
    # by this so the result is the latency of a single invocation.
    num_graph_invocations = 10

    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (num_experts, shard_intermediate_size, hidden_size),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (num_experts, hidden_size, shard_intermediate_size // 2),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    # One set of router logits per timed iteration.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            # One scale per (block_n x block_k) tile, rounding partial tiles up.
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # Static input buffer read by the captured graph; refilled before replays.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    if use_fp8_w8a8:
        # NOTE(review): FP8 is tagged as torch.float8_e4m3fn even where
        # FP8_DTYPE differs; presumably FusedMoEQuantConfig keys its fp8 path on
        # this exact dtype - confirm before switching it to FP8_DTYPE.
        quant_dtype = torch.float8_e4m3fn
    elif use_int8_w8a16:
        quant_dtype = torch.int8
    else:
        quant_dtype = None

    # Built once here rather than inside run(), so these objects are not
    # reconstructed on every invocation, including during graph capture.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=quant_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_quant_shape,
    )

    deep_gemm_experts = None
    if use_deep_gemm:
        deep_gemm_experts = mk.FusedMoEModularKernel(
            prepare_finalize=MoEPrepareAndFinalizeNoEP(),
            fused_experts=TritonOrDeepGemmExperts(
                moe_config=FusedMoEConfig(
                    num_experts=num_experts,
                    experts_per_token=topk,
                    hidden_dim=hidden_size,
                    intermediate_size_per_partition=shard_intermediate_size,
                    num_local_experts=num_experts,
                    activation="silu",
                    moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
                    in_dtype=init_dtype,
                    routing_method=RoutingMethodType.TopK,
                    device="cuda",
                ),
                quant_config=quant_config,
            ),
        )

    def run():
        with override_config(config):
            topk_weights, topk_ids, _ = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )

            inplace = not disable_inplace()
            if use_deep_gemm:
                return deep_gemm_experts(
                    x, w1, w2, topk_weights, topk_ids, inplace=inplace
                )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=inplace,
                quant_config=quant_config,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture several invocations into a single CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(num_graph_invocations):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is in ms; convert to us per single MoE invocation.
    avg = sum(latencies) / (num_iters * num_graph_invocations) * 1000
    graph.reset()
    return avg


def get_rocm_tuning_space(use_fp16):
    """Return the ROCm Triton search space as ``{parameter: candidate values}``.

    Args:
        use_fp16: ``True`` for unquantized 16-bit runs. Otherwise
            ``BLOCK_SIZE_K=16`` is excluded, and the ``matrix_instr_nonkdim``
            and ``kpack`` parameters are omitted entirely.

    Returns:
        An insertion-ordered dict whose Cartesian product forms the space.
    """
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0, 1, 2, 4],
    }
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = [16, 32]
        param_ranges["kpack"] = [1, 2]

    return param_ranges


def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    """Enumerate candidate kernel configs as the Cartesian product of ranges.

    Args:
        use_fp16: ``True`` for unquantized runs. It selects the fp16 ROCm space
            and disables the FP8 block-quantization filter.
        block_quant_shape: ``[block_n, block_k]`` of block-wise quantization,
            or ``None``.

    Returns:
        A list of config dicts. With a block shape and ``use_fp16`` false, only
        configs whose BLOCK_SIZE_N / BLOCK_SIZE_K are multiples of
        block_n / block_k are kept.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys, values = zip(*param_ranges.items())
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, config_values)) for config_values in product(*values)
    ]

    # Remove configs that are not compatible with fp8 block quantization:
    # BLOCK_SIZE_K must be a multiple of block_k and
    # BLOCK_SIZE_N must be a multiple of block_n.
    # A single filtering pass replaces the quadratic copy-and-remove loop.
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        configs = [
            config
            for config in configs
            if config["BLOCK_SIZE_K"] % block_k == 0
            and config["BLOCK_SIZE_N"] % block_n == 0
        ]

    return configs


def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    """Prune ``search_space`` against both GEMMs of the fused MoE on ROCm.

    A config survives if ``prune_rocm_configs`` keeps it for either GEMM:
    the first uses N = shard_intermediate_size and K = hidden_size, and the
    second uses N = hidden_size and K = shard_intermediate_size // 2. Both
    use M = num_tokens * topk.

    Returns:
        The union of the two pruned lists without duplicate configs.
    """
    M = num_tokens * topk
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2

    pruned_space_1 = prune_rocm_configs(M, N1, K1, search_space, is_fp16)
    pruned_space_2 = prune_rocm_configs(M, N2, K2, search_space, is_fp16)
    return merge_unique_dicts(pruned_space_1, pruned_space_2)


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop configs that are invalid or unpromising for an ``M x N x K`` GEMM.

    Args:
        M, N, K: GEMM problem size.
        configs: Candidate config dicts. When ``is_fp16`` is true every config
            must provide ``matrix_instr_nonkdim``. ``SPLIT_K`` is optional and
            defaults to 1.
        is_fp16: ``True`` for 2-byte operands, ``False`` for 1-byte ones.

    Returns:
        The surviving configs, in their original order.
    """
    pruned_configs = []
    # Bytes per element of the A and B operands.
    elem_bytes_a = 2 if is_fp16 else 1
    elem_bytes_b = elem_bytes_a
    # Upper bound on the A+B tile footprint in LDS, in bytes.
    lds_limit = 65536

    # mfma is always 16 or 32 here.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        if K % (SPLIT_K * BLOCK_SIZE_K) != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elem_bytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elem_bytes_b
        )
        if lds > lds_limit:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm and (
            BLOCK_SIZE_M < 64
            or BLOCK_SIZE_N < 64
            or BLOCK_SIZE_K < 64
            or num_warps < 4
        ):
            continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True when split-K is worth trying.

    That is the case for a skinny output (M or N below 64) combined with a
    long reduction dimension (K above 1024).
    """
    skinny_output = min(SIZE_M, SIZE_N) < 64
    long_reduction = SIZE_K > 1024
    return skinny_output and long_reduction


def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, keeping the first of each equal dict.

    Order is preserved and the original dict objects are returned (not
    copies); the inputs are not modified. Dicts whose values are hashable
    (the int-valued configs in this file) are de-duplicated through a set of
    ``frozenset(d.items())`` keys. That makes the merge linear instead of
    quadratic, which matters for search spaces of thousands of configs. Dicts
    with unhashable values fall back to an equality scan of the result.
    """
    result = []
    seen_keys = set()
    for dictionary in [*list1, *list2]:
        try:
            key = frozenset(dictionary.items())
        except TypeError:
            # Unhashable values: use plain dict equality.
            if dictionary not in result:
                result.append(dictionary)
            continue
        if key not in seen_keys:
            seen_keys.add(key)
            result.append(dictionary)
    return result


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that benchmarks or tunes fused-MoE kernels on one GPU."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        """Time the kernel config that is looked up for ``num_tokens``.

        When ``get_moe_configs`` returns a mapping, the entry whose batch-size
        key is closest to ``num_tokens`` is used. Otherwise the config comes
        from ``get_default_config``.

        Returns:
            ``(config, kernel_time_us)``.
        """
        set_random_seed(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        """Benchmark every config in ``search_space`` and return the fastest.

        On ROCm the space is first pruned for this problem size. Configs that
        raise Triton ``OutOfResources`` are skipped. Caches are cleared every
        ``TRITON_CACHE_CLEAR_INTERVAL`` configs (when positive) and once at
        the end.

        Raises:
            RuntimeError: If no config in the search space could be benchmarked.
        """
        # local import to allow serialization by ray
        from vllm.platforms import current_platform

        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for idx, config in enumerate(tqdm(search_space)):
                # Periodically clear Triton JIT cache to prevent OOM; this is
                # especially important for large models with many experts.
                # Done before benchmarking so a config that fails to compile
                # does not cause a scheduled clear to be skipped.
                if (
                    TRITON_CACHE_CLEAR_INTERVAL > 0
                    and idx > 0
                    and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
                ):
                    clear_triton_cache()

                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config

        # Final cleanup after tuning completes
        clear_triton_cache()

        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        if best_config is None:
            raise RuntimeError(
                f"No valid config found for batch_size={num_tokens}: none of the "
                f"{len(search_space)} candidate configs could be benchmarked."
            )
        return best_config


605
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
606
    return {
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **(
            {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
        ),
        **(
            {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
            if "matrix_instr_nonkdim" in config
            else {}
        ),
        **({"kpack": config["kpack"]} if "kpack" in config else {}),
622
623
624
    }


def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    """Write tuned per-batch-size configs to a JSON file in ``save_dir``.

    The file name comes from ``get_config_file_name``, and ``save_dir`` is
    created if needed. The JSON object holds a ``"triton_version"`` entry
    followed by the entries of ``configs``; ``json`` turns the int batch-size
    keys into strings. ``hidden_size`` and ``topk`` are not used in the file
    name; they are kept for interface compatibility.
    """
    dtype_str = _get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)

    print(f"Writing best config to {filename}...")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")


def get_compressed_tensors_block_structure(config, default_value=None):
    """Return the weight ``block_structure`` from a compressed-tensors config dict.

    Only configs with exactly one entry in ``config_groups`` are inspected.
    ``default_value`` is returned otherwise, and also when that group has no
    ``weights`` section or no non-null ``block_structure``. A ``None`` value
    for ``config_groups``, the group, or ``weights`` is treated as a missing
    section instead of crashing.
    """
    config_groups = config.get("config_groups") or {}
    if len(config_groups) != 1:
        return default_value
    group = next(iter(config_groups.values())) or {}
    weights = group.get("weights") or {}
    block_structure = weights.get("block_structure")
    return default_value if block_structure is None else block_structure


def get_weight_block_size_safety(config, default_value=None):
    """Return the block-quantization shape from ``config.quantization_config``.

    ``weight_block_size`` is used if present. Otherwise the compressed-tensors
    ``block_structure`` lookup is tried. ``default_value`` is returned when
    ``config`` has no dict-valued ``quantization_config`` or no block shape
    is found.
    """
    quantization_config = getattr(config, "quantization_config", {})
    if not isinstance(quantization_config, dict):
        return default_value
    if "weight_block_size" in quantization_config:
        return quantization_config["weight_block_size"]
    return get_compressed_tensors_block_structure(quantization_config, default_value)


def get_model_params(config):
    """Extract MoE shape parameters from a Hugging Face model config.

    The attribute names differ per architecture, so dispatch is on
    ``config.architectures[0]``. Unknown architectures fall back to the
    Mixtral layout on ``config.get_text_config()``, which also covers Llama 4.

    Returns:
        ``(num_experts, topk, intermediate_size, hidden_size)``, where
        ``intermediate_size`` is the expert FFN size reported by the config.
    """
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        hidden_size = config.hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    elif arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "GlmMoeDsaForCausalLM",
        "Glm4MoeForCausalLM",
        "Glm4MoeLiteForCausalLM",
        "NemotronHForCausalLM",
        "MistralLarge3ForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif arch in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif arch == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif arch == "HunYuanMoEV1ForCausalLM":
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
        hidden_size = config.hidden_size
    elif arch == "Qwen3OmniMoeForConditionalGeneration":
        text_config = config.thinker_config.text_config
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif arch == "PixtralForConditionalGeneration":
        # Pixtral can contain different LLM architectures,
        # recurse to get their parameters
        return get_model_params(config.get_text_config())
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    return E, topk, intermediate_size, hidden_size


def main(args: argparse.Namespace):
    """Benchmark (default) or tune (``--tune``) fused-MoE kernels for a model.

    MoE shapes are read from the model config and sharded by ``--tp-size``,
    over experts with ``--enable-expert-parallel`` or over the intermediate
    size otherwise. One task per batch size is spread round-robin across one
    Ray ``BenchmarkWorker`` per available GPU. Tuning writes the best configs
    to ``args.save_dir``; benchmarking prints the chosen config and kernel
    time.

    Raises:
        ValueError: If ``--tune`` is combined with ``--use-deep-gemm``, or the
            model shape is not divisible by the tp size.
        RuntimeError: If Ray reports no available GPUs.
    """
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)
    E, topk, intermediate_size, hidden_size = get_model_params(config)
    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128,
            256, 512, 1024, 1536, 2048, 3072, 4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)
    # Reject the unsupported combination before starting Ray.
    if args.tune and use_deep_gemm:
        raise ValueError(
            "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
            "kernels. Please remove the flag."
        )

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        os.environ["ROCR_VISIBLE_DEVICES"] = os.environ.pop("HIP_VISIBLE_DEVICES")

    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        raise RuntimeError(
            "No GPUs are available to Ray; the MoE benchmark needs at least one GPU."
        )
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        """Round-robin ``inputs`` across workers; return results in input order."""
        outputs = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument(
        "--enable-expert-parallel",
        "-enable-ep",
        action="store_true",
        help="Shard experts (instead of the intermediate size) across --tp-size",
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        required=False,
        help="Batch sizes to run (defaults to a built-in list)",
    )
    parser.add_argument(
        "--tune",
        action="store_true",
        help="Search for the best kernel config per batch size and save it",
    )
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument(
        "--model-prefix",
        type=str,
        required=False,
        help="Attribute of the loaded config to use as the model config",
    )
    args = parser.parse_args()

    main(args)