benchmark_moe.py 35.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import gc
6
import json
7
import os
8
import time
9
from contextlib import nullcontext
10
from datetime import datetime
11
from itertools import product
12
from typing import Any, TypedDict
13
14
15
16
17

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

18
from vllm.model_executor.layers.fused_moe import fused_topk
19
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
20
21
22
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
23
from vllm.model_executor.layers.fused_moe.config import (
24
25
    FusedMoEConfig,
    FusedMoEParallelConfig,
26
    FusedMoEQuantConfig,
27
    RoutingMethodType,
28
29
    _get_config_dtype_str,
)
30
from vllm.model_executor.layers.fused_moe.fused_moe import *
31
32
33
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
34
from vllm.transformers_utils.config import get_config
35
from vllm.triton_utils import triton
36
from vllm.utils.argparse_utils import FlexibleArgumentParser
37
from vllm.utils.torch_utils import set_random_seed
38

39
# FP8 dtype the benchmark casts quantized weights to on this platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Environment variable holding how many configs ``BenchmarkWorker.tune``
# benchmarks between Triton JIT / memory cache clears (default: 50).
# Any value <= 0 turns the periodic clearing off.
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(
    os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50")
)


def clear_triton_cache():
    """Clear Triton JIT compilation cache and Python/CUDA memory.

    This helps prevent OOM during tuning with large models (many experts).
    Every step is best-effort: failing to clear the Triton cache is reported
    as a warning and never interrupts tuning.
    """
    # Force Python garbage collection so objects kept alive only by reference
    # cycles are released before the device caches are emptied.
    gc.collect()

    # Clear CUDA memory cache
    if torch.cuda.is_available():
        torch.accelerator.empty_cache()

    # Try to clear Triton's runtime cache; the attribute probes skip Triton
    # versions that do not expose ``triton.runtime.cache.clear``.
    try:
        if (
            hasattr(triton, "runtime")
            and hasattr(triton.runtime, "cache")
            and hasattr(triton.runtime.cache, "clear")
        ):
            triton.runtime.cache.clear()
    except (ImportError, AttributeError):
        # Triton not installed, or its cache API differs from the expected one.
        pass
    except Exception as e:
        print(f"Warning: Failed to clear Triton cache: {e}")

    # Additional garbage collection after clearing caches
    gc.collect()

79

80
def ensure_divisibility(numerator, denominator, text):
    """Ensure that ``numerator`` is divisible by ``denominator``.

    Args:
        numerator: Value that must split evenly (e.g. the expert count or
            the intermediate size).
        denominator: Divisor; the callers in this file pass the
            tensor-parallel size.
        text: Human-readable name of ``numerator`` used in the error message.

    Raises:
        ValueError: If ``numerator`` is not a multiple of ``denominator``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if numerator % denominator != 0:
        raise ValueError(
            f"{text} {numerator} is not divisible by tp {denominator}."
        )


87
88
89
90
91
92
93
94
95
# Triton launch parameters of one fused-MoE kernel configuration: tile sizes
# along M/N/K, the number of M tiles grouped together, and the warp and
# software-pipeline stage counts. The ROCm tuning space adds further keys
# (see get_rocm_tuning_space) that are not declared here.
BenchmarkConfig = TypedDict(
    "BenchmarkConfig",
    {
        "BLOCK_SIZE_M": int,
        "BLOCK_SIZE_N": int,
        "BLOCK_SIZE_K": int,
        "GROUP_SIZE_M": int,
        "num_warps": int,
        "num_stages": int,
    },
)


96
97
98
99
100
101
102
103
104
105
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool = False,
    num_iters: int = 100,
    block_quant_shape: list[int] | None = None,
    use_deep_gemm: bool = False,
) -> float:
    """Measure the average latency of one fused-MoE invocation for ``config``.

    Random activations, router logits and (optionally quantized) expert
    weights are allocated on the current default torch device. The MoE call
    is captured 10 times in a CUDA graph, and the graph is replayed
    ``num_iters`` times; before each replay a fresh slice of router logits is
    copied into the buffer the graph reads.

    Args:
        config: Triton kernel configuration installed via ``override_config``.
        num_tokens: Number of tokens (rows of the activation matrix).
        num_experts: Number of experts benchmarked.
        shard_intermediate_size: Output width of w1 on this rank; w2's input
            width is half of it (the size after silu_and_mul).
        hidden_size: Model hidden dimension.
        topk: Experts selected per token.
        dtype: Activation dtype (also the weight dtype when unquantized).
        use_fp8_w8a8: Benchmark FP8 weights with FP8 activations.
        use_int8_w8a16: Benchmark int8 weights with 16-bit activations.
        use_int4_w4a16: Benchmark packed int4 weights; requires
            ``block_quant_shape`` whose second entry is the group size.
        num_iters: Number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for block-quantized
            weights. Forced to ``[128, 128]`` when ``use_deep_gemm`` is set.
        use_deep_gemm: Run through the modular Triton-or-DeepGEMM kernel
            instead of ``fused_experts``.

    Returns:
        Average latency of a single MoE invocation in microseconds.

    Raises:
        ValueError: If ``use_int4_w4a16`` is set without ``block_quant_shape``.
    """
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int4_w4a16:
        # Int4 packed weights: 2 int4 values per uint8 byte, so the K
        # (reduction) dimension of each weight is halved.
        intermediate_size = shard_intermediate_size // 2  # after silu_and_mul
        w1 = torch.randint(
            0,
            255,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size // 2,  # int4 packing
            ),
            dtype=torch.uint8,
        )
        w2 = torch.randint(
            0,
            255,
            (
                num_experts,
                hidden_size,
                intermediate_size // 2,  # int4 packing
            ),
            dtype=torch.uint8,
        )
    elif use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    # One set of router logits per timed iteration.
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int4_w4a16:
        if block_quant_shape is None:
            raise ValueError("block_quant_shape is required for int4_w4a16")
        group_size = block_quant_shape[1]
        # Scales shape: (E, N, K // group_size) in the activation dtype
        w1_scale = torch.rand(
            (num_experts, shard_intermediate_size, hidden_size // group_size),
            dtype=dtype,
        )
        w2_scale = torch.rand(
            (num_experts, hidden_size, intermediate_size // group_size),
            dtype=dtype,
        )
    elif use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            # Number of (block_n x block_k) scale tiles covering each weight.
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            # One weight scale per expert.
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        # Single static activation scale per GEMM.
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # Fixed buffer read by the captured graph; refilled by prepare().
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        if use_fp8_w8a8:
            quant_dtype = torch.float8_e4m3fn
        elif use_int8_w8a16:
            quant_dtype = torch.int8
        else:
            quant_dtype = None

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=quant_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_quant_shape,
            weight_dtype="int4" if use_int4_w4a16 else None,
        )

        deep_gemm_experts = None
        if use_deep_gemm:
            # A bare FusedMoEConfig (the previous trailing comma wrapped it in
            # a 1-tuple, which is not what the kernel constructors expect).
            moe_config = FusedMoEConfig(
                num_experts=num_experts,
                experts_per_token=topk,
                hidden_dim=hidden_size,
                intermediate_size_per_partition=shard_intermediate_size,
                num_local_experts=num_experts,
                num_logical_experts=num_experts,
                activation=MoEActivation.SILU,
                moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
                in_dtype=init_dtype,
                routing_method=RoutingMethodType.TopK,
                device="cuda",
            )
            deep_gemm_experts = mk.FusedMoEKernel(
                prepare_finalize=maybe_make_prepare_finalize(
                    moe=moe_config,
                    quant_config=quant_config,
                    allow_new_interface=True,
                    use_monolithic=False,
                ),
                fused_experts=TritonOrDeepGemmExperts(
                    moe_config=moe_config,
                    quant_config=quant_config,
                ),
                inplace=not disable_inplace(),
            )

        with override_config(config):
            topk_weights, topk_ids, _ = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )

            inplace = not disable_inplace()
            if use_deep_gemm:
                return deep_gemm_experts.apply(
                    x,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    activation=MoEActivation.SILU,
                    global_num_experts=num_experts,
                    apply_router_weight_on_input=False,
                    # No expert-parallel remapping: "no map" is None, not False.
                    expert_map=None,
                )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=inplace,
                quant_config=quant_config,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # Each replay runs the MoE 10 times; elapsed_time is in ms -> convert to us.
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


339
340
341
342
343
344
345
346
def get_rocm_tuning_space(use_fp16):
    """Return the ROCm Triton tuning grid for the fused-MoE kernel.

    Args:
        use_fp16: Whether the kernel runs in 16-bit precision rather than a
            quantized mode. 16-bit kernels additionally search
            BLOCK_SIZE_K=16 and the ``matrix_instr_nonkdim`` / ``kpack`` knobs.

    Returns:
        Mapping from config key to the list of candidate values for it.
    """
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0, 1, 2, 4]

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if use_fp16:
        # MFMA instruction shape and operand packing only matter for fp16.
        param_ranges["matrix_instr_nonkdim"] = [16, 32]
        param_ranges["kpack"] = [1, 2]

    return param_ranges


367
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    """Enumerate every candidate Triton config for the MoE GEMMs.

    Takes the Cartesian product of the per-key value ranges (the ROCm grid on
    ROCm, a reduced grid elsewhere). For non-fp16 kernels with a
    ``block_quant_shape`` of ``[block_n, block_k]``, configs whose
    BLOCK_SIZE_N / BLOCK_SIZE_K are not multiples of the quantization block
    are dropped.

    Returns:
        List of config dicts in product order.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys, values = zip(*param_ranges.items())
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, combo)) for combo in product(*values)
    ]

    # Remove configs that are not compatible with block quantization:
    # BLOCK_SIZE_K must be a multiple of block_k and BLOCK_SIZE_N a multiple
    # of block_n. A single filtering pass replaces the former per-item
    # list.remove(), which was quadratic in the (large) search space.
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        configs = [
            cfg
            for cfg in configs
            if cfg["BLOCK_SIZE_K"] % block_k == 0
            and cfg["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs


411
412
413
def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    """Prune ``search_space`` for both GEMMs of the fused MoE on ROCm.

    The first GEMM uses w1's dimensions (N = shard_intermediate_size,
    K = hidden_size); the second uses w2's (N = hidden_size,
    K = shard_intermediate_size // 2). Both are pruned with
    M = num_tokens * topk. A config is kept if it survives pruning for
    either GEMM, and duplicates are removed.
    """
    gemm_m = num_tokens * topk
    n1, k1 = shard_intermediate_size, hidden_size
    n2, k2 = hidden_size, shard_intermediate_size // 2

    pruned_space_1 = prune_rocm_configs(gemm_m, n1, k1, search_space, is_fp16)
    pruned_space_2 = prune_rocm_configs(gemm_m, n2, k2, search_space, is_fp16)
    return merge_unique_dicts(pruned_space_1, pruned_space_2)


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Filter ``configs`` to those worth benchmarking for an M x N x K GEMM.

    Heuristics drop configs whose tiles are far larger than the problem,
    whose K tiling does not divide K evenly, whose A/B tiles exceed 64 KiB of
    LDS, or that are too small for large GEMMs.

    Args:
        M, N, K: GEMM problem dimensions.
        configs: Candidate config dicts. They must contain BLOCK_SIZE_M/N/K,
            GROUP_SIZE_M and num_warps, plus matrix_instr_nonkdim when
            ``is_fp16``. SPLIT_K is optional and defaults to 1.
        is_fp16: True for 2-byte operands, False for 1-byte operands.

    Returns:
        The surviving configs, in their original order.
    """
    pruned_configs = []
    elem_bytes_a = 2 if is_fp16 else 1
    elem_bytes_b = 2 if is_fp16 else 1

    # Largest usable MFMA instruction size for this problem (16 or 32).
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    for config in configs:
        block_m = config.get("BLOCK_SIZE_M")
        block_n = config.get("BLOCK_SIZE_N")
        block_k = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")
        split_k = config.get("SPLIT_K", 1)
        group_m = config.get("GROUP_SIZE_M")

        if is_fp16:
            nonkdim = config.get("matrix_instr_nonkdim")
            if nonkdim > mfma:
                continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if block_m * block_n < 64:
            continue
        if is_fp16:
            # The MFMA tile must fit inside the block tile, and when it already
            # covers a whole problem dimension the block must match it exactly.
            if nonkdim > block_m or nonkdim > block_n:
                continue
            if nonkdim >= M and nonkdim != block_m:
                continue
            if nonkdim >= N and nonkdim != block_n:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < block_m and block_m != 16:
            continue
        if N * 2 < block_n and block_n != 16:
            continue
        # skip large split_k when not necessary
        if split_k != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        if K % (split_k * block_k) != 0:
            continue
        # skip large GROUP_M
        if group_m * block_m > M and group_m != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds = block_k * block_m * elem_bytes_a + block_k * block_n * elem_bytes_b
        if lds > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if block_m < 64 or block_n < 64:
                continue
            if block_k < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True for a skinny GEMM (M or N below 64) with K above 1024."""
    is_skinny = min(SIZE_M, SIZE_N) < 64
    return is_skinny and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, dropping duplicates.

    The first occurrence of each distinct dict (compared by equality) is
    kept, order is preserved, and neither input is modified.

    Dicts whose values are all hashable are deduplicated through a set of
    their item sets (O(1) per item), which avoids the quadratic cost of
    scanning the result list on large tuning spaces. Anything that cannot be
    hashed that way falls back to the equality scan.
    """
    result = []
    seen = set()
    for source in (list1, list2):
        for dictionary in source:
            try:
                key = frozenset(dictionary.items())
            except (AttributeError, TypeError):
                # Unhashable values (or a non-dict item): linear equality scan.
                if dictionary not in result:
                    result.append(dictionary)
                continue
            if key not in seen:
                seen.add(key)
                result.append(dictionary)
    return result


522
523
524
525
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that benchmarks and tunes fused-MoE configs on one GPU.

    Each actor requests one GPU from Ray and makes CUDA the default torch
    device. ``benchmark`` reseeds the RNG with the actor's seed before every
    call so that repeated runs are reproducible.
    """

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        set_random_seed(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool = False,
        block_quant_shape: list[int] | None = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        """Time the config vLLM would select for this problem shape.

        Configs returned by ``get_moe_configs`` are keyed by batch size, and
        the entry closest to ``num_tokens`` is used. When no configs are
        found, ``get_default_config`` is used instead.

        Returns:
            ``(config, latency_us)`` for the selected config.
        """
        set_random_seed(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype,
            use_int8_w8a16=use_int8_w8a16,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int4_w4a16=use_int4_w4a16,
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            # Use the tuned entry whose batch size is nearest to num_tokens.
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_int4_w4a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int] | None,
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        """Benchmark every config in ``search_space`` and return the fastest.

        On ROCm, the search space is first pruned with
        ``prune_rocm_search_space``. If ROCR_VISIBLE_DEVICES does not equal
        this actor's GPU id, the loop also runs under an explicit
        ``torch.cuda.device`` guard.

        Configs that fail with Triton ``OutOfResources`` are skipped. Caches
        are cleared every ``TRITON_CACHE_CLEAR_INTERVAL`` configs (including
        when the config at that index failed) and once more at the end.

        Raises:
            RuntimeError: If no config in the search space could be benchmarked.
        """
        # local import to allow serialization by ray
        from vllm.platforms import current_platform

        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for idx, config in enumerate(tqdm(search_space)):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        use_int4_w4a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    pass
                else:
                    if kernel_time < best_time:
                        best_time = kernel_time
                        best_config = config

                # Periodically clear Triton JIT cache to prevent OOM
                # This is especially important for large models with many experts
                if (
                    TRITON_CACHE_CLEAR_INTERVAL > 0
                    and idx > 0
                    and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
                ):
                    clear_triton_cache()

        # Final cleanup after tuning completes
        clear_triton_cache()

        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        if best_config is None:
            raise RuntimeError(
                f"No valid config found for batch_size={num_tokens} "
                f"(searched {len(search_space)} candidates)"
            )
        return best_config


673
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in canonical order.

    The six core Triton keys come first and must be present. They are
    followed, in a fixed order, by whichever optional ROCm / split-K knobs
    (waves_per_eu, matrix_instr_nonkdim, kpack, SPLIT_K) are set. Any other
    keys are dropped.

    Raises:
        KeyError: If one of the core keys is missing.
    """
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    optional_keys = ("waves_per_eu", "matrix_instr_nonkdim", "kpack", "SPLIT_K")

    ordered = {key: config[key] for key in required_keys}
    ordered.update({key: config[key] for key in optional_keys if key in config})
    return ordered


694
695
696
697
698
699
700
701
702
def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_quant_shape: list[int] | None,
    save_dir: str,
) -> None:
    """Write tuned configs to a JSON file inside ``save_dir``.

    The file name comes from ``get_config_file_name``, which is called with
    the expert count, the post-activation intermediate size, the dtype string
    and the block shape. The JSON holds ``configs`` plus a
    ``"triton_version"`` entry recording the Triton version used for tuning.
    ``save_dir`` is created if needed.

    ``hidden_size`` and ``topk`` are currently unused.
    """
    dtype_str = _get_config_dtype_str(
        dtype,
        use_int8_w8a16=use_int8_w8a16,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int4_w4a16=use_int4_w4a16,
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )

    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)
    print(f"Writing best config to {filename}...")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")


727
728
729
730
731
732
733
734
735
736
def get_compressed_tensors_block_structure(config, default_value=None):
    """Return the weight ``block_structure`` of a single-group
    compressed-tensors quantization config.

    ``default_value`` is returned when ``config_groups`` does not hold exactly
    one group, or when that group's weights declare no block structure.
    """
    groups = config.get("config_groups", {})
    if len(groups) == 1:
        (sole_group,) = groups.values()
        return sole_group.get("weights", {}).get("block_structure", default_value)
    return default_value


737
def get_weight_block_size_safety(config, default_value=None):
    """Return the weight quantization block size declared by a HF config.

    Checks ``quantization_config["weight_block_size"]`` first, then the
    compressed-tensors ``block_structure``. ``default_value`` is returned when
    ``config`` has no dict-valued ``quantization_config``, or when neither
    lookup yields a value.
    """
    quantization_config = getattr(config, "quantization_config", {})
    if not isinstance(quantization_config, dict):
        return default_value
    if "weight_block_size" in quantization_config:
        return quantization_config["weight_block_size"]
    return get_compressed_tensors_block_structure(quantization_config, default_value)


748
def get_model_params(config):
    """Extract the MoE shape parameters from a HuggingFace model config.

    The attribute names are chosen by ``config.architectures[0]``. Pixtral
    recurses into its text config. Unknown architectures fall back to the
    Mixtral-style attributes of ``config.get_text_config()``, which also
    covers Llama 4.

    Returns:
        ``(E, topk, intermediate_size, hidden_size)``: number of experts,
        experts per token, per-expert intermediate size and hidden size.
    """
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        hidden_size = config.hidden_size
    elif arch == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    elif arch in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "GlmMoeDsaForCausalLM",
        "Glm4MoeForCausalLM",
        "Glm4MoeLiteForCausalLM",
        "NemotronHForCausalLM",
        "MistralLarge3ForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif arch in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif arch == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif arch == "HunYuanMoEV1ForCausalLM":
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
        hidden_size = config.hidden_size
    elif arch == "Qwen3OmniMoeForConditionalGeneration":
        text_config = config.thinker_config.text_config
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif arch == "PixtralForConditionalGeneration":
        # Pixtral can contain different LLM architectures,
        # recurse to get their parameters
        return get_model_params(config.get_text_config())
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    return E, topk, intermediate_size, hidden_size


813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
def get_quantization_group_size(config) -> int | None:
    """Extract the quantization group size from the HF model config.

    This reads directly from the HuggingFace config object (as returned by
    ``get_config()``), not from vLLM's quantization config classes.

    Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
    compressed-tensors configs (nested inside 'config_groups').
    """
    quantization_config = getattr(config, "quantization_config", {})
    if not isinstance(quantization_config, dict):
        return None
    # AWQ / GPTQ style: group_size is a top-level key
    gs = quantization_config.get("group_size")
    if gs is not None:
        return gs
    # compressed-tensors style: group_size is nested in config_groups
    config_groups = quantization_config.get("config_groups", {})
    if not isinstance(config_groups, dict):
        return None
    for group_cfg in config_groups.values():
        if not isinstance(group_cfg, dict):
            continue
        weights = group_cfg.get("weights", {})
        if not isinstance(weights, dict):
            continue
        gs = weights.get("group_size")
        if gs is not None:
            return gs
    return None


845
846
847
848
849
850
851
def main(args: argparse.Namespace):
    """Tune or benchmark the fused-MoE kernel for a model's MoE layer.

    Reads the MoE geometry (number of experts, top-k, intermediate size,
    hidden size) from the model's HF config. It then derives the per-rank
    shard shape from ``--tp-size`` / ``--enable-expert-parallel``. One task
    per batch size is fanned out round-robin across Ray ``BenchmarkWorker``
    actors, one actor per GPU reported by Ray.

    With ``--tune``, the search space from ``get_configs_compute_bound`` is
    tuned for every batch size. The best configs are written with
    ``save_configs`` to ``--save-dir``. Without ``--tune``, each batch size
    is benchmarked and the config used plus the kernel time are printed.

    Raises:
        ValueError: if ``--tune`` is combined with ``--use-deep-gemm``, or if
            ``--dtype int4_w4a16`` is requested but no group size can be read
            from the model config.
        RuntimeError: if Ray reports no GPUs.
    """
    print(args)

    use_deep_gemm = bool(args.use_deep_gemm)
    # Reject unsupported flag combinations before loading the config or
    # starting Ray, so the user gets the error immediately.
    if args.tune and use_deep_gemm:
        raise ValueError(
            "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
            "kernels. Please remove the flag."
        )

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        # Use the sub-config stored under the attribute named by --model-prefix.
        config = getattr(config, args.model_prefix)
    E, topk, intermediate_size, hidden_size = get_model_params(config)

    # The factor 2 below presumably covers the fused gate/up projection.
    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        # Expert parallel: experts are split across ranks; each expert's
        # intermediate dimension stays whole.
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        # Tensor parallel: every rank keeps all experts but only a slice of
        # the intermediate dimension.
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_int4_w4a16 = args.dtype == "int4_w4a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if use_int4_w4a16:
        group_size = get_quantization_group_size(config)
        if group_size is None:
            raise ValueError(
                "Could not determine group_size from model config. "
                "The model's quantization_config must contain a 'group_size' "
                "field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
                "(compressed-tensors)."
            )
        # For int4_w4a16, block_shape = [0, group_size]
        # block_shape[0]=0 means no block quantization on N dimension
        block_quant_shape = [0, group_size]

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        os.environ["ROCR_VISIBLE_DEVICES"] = os.environ.pop("HIP_VISIBLE_DEVICES")

    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus == 0:
        # Without this check the failure surfaces as a KeyError or an
        # IndexError inside _distribute.
        raise RuntimeError(
            "Ray reports no available GPUs; at least one GPU is required to "
            "tune or benchmark the MoE kernel."
        )
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        """Call ``method`` on the workers round-robin; return results in input order."""
        futures = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(futures)

    if args.tune:
        # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
        # search space generation (no matrix_instr_nonkdim/kpack exploration).
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
        # For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
        # apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
        # of group_size. Skip block_quant_shape filtering to keep the full
        # search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
        tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
        search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
        if use_int4_w4a16:
            # SPLIT_K is a required kernel constexpr for gptq_awq kernel;
            # only SPLIT_K=1 is used at runtime, so fix it during tuning.
            for cfg in search_space:
                cfg["SPLIT_K"] = 1
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    use_int4_w4a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(best) for M, best in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    use_int4_w4a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (used_config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {used_config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    # Command-line entry point: parse the benchmark/tuning options and run main().
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name or path whose config supplies the MoE dimensions.",
    )
    parser.add_argument(
        "--tp-size",
        "-tp",
        "--tensor-parallel-size",
        type=int,
        default=2,
        help="Number of ranks the MoE layer is sharded across.",
    )
    parser.add_argument(
        "--enable-expert-parallel",
        "-enable-ep",
        action="store_true",
        help="Split experts across --tp-size ranks instead of splitting the "
        "intermediate dimension.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
        default="auto",
        help="Weight/activation quantization scheme to tune or benchmark.",
    )
    parser.add_argument(
        "--use-deep-gemm",
        action="store_true",
        help="Enable the DeepGEMM path in the workers (not supported with --tune).",
    )
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Random seed passed to each worker."
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        required=False,
        help="Batch sizes to run; defaults to a built-in sweep from 1 to 4096.",
    )
    parser.add_argument(
        "--tune",
        action="store_true",
        help="Search for and save the best kernel config per batch size "
        "instead of benchmarking.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Forwarded to get_config when loading the model config.",
    )
    parser.add_argument(
        "--model-prefix",
        type=str,
        required=False,
        help="Attribute of the loaded config to use as the model config.",
    )
    args = parser.parse_args()

    main(args)