# benchmark_moe.py
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, List, Optional, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

# NOTE(review): the wildcard import presumably supplies fused_moe,
# get_config_dtype_str, get_moe_configs, get_default_config and
# get_config_file_name, which are used below without any other import.
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser

# FP8 element type of the current platform; benchmark_config() casts the
# expert weights to this dtype when benchmarking the fp8_w8a8 path.
FP8_DTYPE = current_platform.fp8_dtype()


class _BenchmarkConfigRequired(TypedDict):
    """Triton launch parameters present in every fused-MoE tuning config."""

    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


class BenchmarkConfig(_BenchmarkConfigRequired, total=False):
    """One fused-MoE kernel configuration evaluated by this script.

    ``num_ldmatrixes`` is only added to the search space when nn_moe is
    enabled (see ``get_rocm_tuning_space``), so it is declared as a
    non-required key rather than a required key that may be missing.
    """

    num_ldmatrixes: Optional[int]


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: Optional[List[int]] = None,
    nn_moe: Optional[bool] = False,
) -> float:
    """Measure the average latency of one fused-MoE call under ``config``.

    Random activations, expert weights, gating logits and (for the quantized
    paths) scales are generated. The kernel is JIT-compiled by one warm-up
    call, 10 back-to-back calls are captured into a CUDA graph, and the graph
    is replayed ``num_iters`` times with fresh gating logits each time.

    Args:
        config: Triton kernel parameters, forced via ``override_config``.
        num_tokens: number of input tokens (batch size).
        num_experts: number of experts.
        shard_intermediate_size: per-shard width of the first expert weight
            (``main`` passes ``2 * intermediate_size // tp_size``).
        hidden_size: model hidden size.
        topk: number of experts selected per token.
        dtype: activation dtype (and weight dtype when not quantized).
        use_fp8_w8a8: benchmark the fp8 weight/activation path.
        use_int8_w8a16: benchmark the int8-weight path.
        num_iters: number of timed graph replays.
        block_quant_shape: ``[block_n, block_k]`` for block-wise fp8 weight
            scales, or ``None`` for one scale per expert.
        nn_moe: build weights in the transposed layout and pass
            ``use_nn_moe`` to ``fused_moe``.

    Returns:
        Average latency of a single fused-MoE call, in microseconds.
    """
    from vllm.model_executor.layers.fused_moe import override_config

    # fp8 weights are drawn in fp16 and cast down once the scales exist.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    # nn_moe stores each expert's matrices transposed w.r.t. the default.
    if nn_moe:
        w1_shape = (num_experts, hidden_size, shard_intermediate_size)
        w2_shape = (num_experts, shard_intermediate_size // 2, hidden_size)
    else:
        w1_shape = (num_experts, shard_intermediate_size, hidden_size)
        w2_shape = (num_experts, hidden_size, shard_intermediate_size // 2)

    if use_int8_w8a16:
        w1 = torch.randint(-127, 127, w1_shape, dtype=torch.int8)
        w2 = torch.randint(-127, 127, w2_shape, dtype=torch.int8)
    else:
        w1 = torch.randn(w1_shape, dtype=init_dtype)
        w2 = torch.randn(w2_shape, dtype=init_dtype)

    # One set of routing logits per timed iteration.
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        # NOTE(review): these scale shapes do not obviously match per-channel
        # w1/w2 layouts; they only need to be readable by the kernel for
        # timing — confirm against the kernel's scale indexing.
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        if block_quant_shape:
            # One small float32 scale per (block_n x block_k) weight tile.
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
                                  dtype=torch.float32) * factor_for_scale
            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
                                  dtype=torch.float32) * factor_for_scale
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # Fixed input buffer that the CUDA graph reads; refreshed in place.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                block_shape=block_quant_shape,
                use_nn_moe=nn_moe,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    try:
        # Warm up graph replay.
        for _ in range(5):
            graph.replay()
        torch.cuda.synchronize()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        latencies: list[float] = []
        for i in range(num_iters):
            prepare(i)
            torch.cuda.synchronize()

            start_event.record()
            graph.replay()
            end_event.record()
            end_event.synchronize()
            latencies.append(start_event.elapsed_time(end_event))
    finally:
        # Free the captured graph even if a replay fails.
        graph.reset()

    # elapsed_time() is in ms and each replay runs 10 calls -> us per call.
    return sum(latencies) / (num_iters * 10) * 1000


def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
    """Return the Triton parameter grid searched on ROCm/DCU.

    Args:
        use_fp16: ``True`` for the unquantized 16-bit path, ``False`` for the
            fp8/int8 paths.
        nn_moe: when true, also sweep ``num_ldmatrixes`` (fixed to ``[1]``).

    Returns:
        Mapping from config key to its list of candidate values. Key order
        determines the enumeration order used by the caller.
    """
    block_m_range = [16, 32, 64, 128, 256]
    block_n_range = [32, 64, 128, 256]
    block_k_range = [32, 64, 128, 256]
    if not use_fp16:
        # BLOCK_K=16 is not supported for fp8. Filter rather than
        # list.remove(16): 16 is no longer in the base range, so remove()
        # raised ValueError for every quantized tuning run.
        block_k_range = [k for k in block_k_range if k != 16]

    param_ranges = {
        "BLOCK_SIZE_M": block_m_range,
        "BLOCK_SIZE_N": block_n_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": [1, 16, 32, 64],
        "num_warps": [2, 4, 8],
        "num_stages": [2, 3, 4, 5],
    }
    if nn_moe:
        param_ranges["num_ldmatrixes"] = [1]

    # DCU currently does not support the waves_per_eu, matrix_instr_nonkdim
    # and kpack parameters, so they are intentionally not swept.
    return param_ranges


def get_configs_compute_bound(use_fp16,
                              block_quant_shape,
                              nn_moe: Optional[bool] = False
                              ) -> list[dict[str, int]]:
    """Enumerate every candidate kernel config for the current platform.

    Args:
        use_fp16: ``True`` for the unquantized path.
        block_quant_shape: ``[block_n, block_k]`` for block-wise fp8, or
            ``None``.
        nn_moe: forwarded to ``get_rocm_tuning_space`` on ROCm.

    Returns:
        The cartesian product of the parameter ranges, as a list of dicts.
        On quantized runs with a block shape, only configs whose
        ``BLOCK_SIZE_K``/``BLOCK_SIZE_N`` are multiples of ``block_k``/
        ``block_n`` are kept.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys, values = zip(*param_ranges.items())
    configs: list[BenchmarkConfig] = [
        dict(zip(keys, combo)) for combo in product(*values)
    ]

    # Remove configs that are not compatible with fp8 block quantization:
    # BLOCK_SIZE_K must be a multiple of block_k and BLOCK_SIZE_N a multiple
    # of block_n. A single filtering pass replaces repeated list.remove().
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        configs = [
            cfg for cfg in configs if cfg["BLOCK_SIZE_K"] % block_k == 0
            and cfg["BLOCK_SIZE_N"] % block_n == 0
        ]
    return configs


def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16, topk):
    """Prune ``search_space`` against both GEMMs of the fused-MoE layer.

    Both GEMMs have ``num_tokens * topk`` rows. The first uses
    N = shard_intermediate_size and K = hidden_size; the second uses
    N = hidden_size and K = shard_intermediate_size // 2. A config survives
    if it passes pruning for either GEMM, and duplicates are dropped.
    """
    rows = num_tokens * topk
    gemm_dims = (
        (shard_intermediate_size, hidden_size),
        (hidden_size, shard_intermediate_size // 2),
    )
    survivors = [
        prune_rocm_configs(rows, n_dim, k_dim, search_space, is_fp16)
        for n_dim, k_dim in gemm_dims
    ]
    return merge_unique_dicts(*survivors)


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop Triton GEMM configs that are invalid or unpromising for M x N x K.

    DCU does not support the ``matrix_instr_nonkdim`` parameter, so the
    MFMA-size checks of the ROCm tuning script are not applied here.

    Args:
        M: GEMM rows (``num_tokens * topk`` for the MoE GEMMs).
        N: GEMM output columns.
        K: GEMM reduction dimension.
        configs: candidate dicts with ``BLOCK_SIZE_M/N/K``, ``GROUP_SIZE_M``,
            ``num_warps`` and optionally ``SPLIT_K`` (defaults to 1).
        is_fp16: operands take 2 bytes per element when true, else 1 byte.

    Returns:
        The surviving configs, in their original order.
    """
    elem_bytes = 2 if is_fp16 else 1

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    pruned_configs = []
    for config in configs:
        block_m = config.get("BLOCK_SIZE_M")
        block_n = config.get("BLOCK_SIZE_N")
        block_k = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")
        group_m = config.get("GROUP_SIZE_M")
        split_k = config.get("SPLIT_K", 1)

        # Some layouts do not work properly when the number of elements per
        # thread is less than 1.
        if block_m * block_n < 64:
            continue
        # Skip tiles that are too large compared to M/N unless already small.
        if M * 2 < block_m and block_m != 16:
            continue
        if N * 2 < block_n and block_n != 16:
            continue
        # Skip split-K when it is not needed.
        if split_k != 1 and not need_split_k(M, N, K):
            continue
        # Skip split-K / BLOCK_K combinations that make EVEN_K false.
        if K % (split_k * block_k) != 0:
            continue
        # Skip GROUP_SIZE_M values that group more rows than M has.
        if group_m != 1 and group_m * block_m > M:
            continue
        # Out of shared memory (LDS) for the A and B tiles.
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        if block_k * (block_m + block_n) * elem_bytes > 65536:
            continue
        # For large GEMMs only keep tiles >= 64 and at least 4 warps.
        if large_gemm and (block_m < 64 or block_n < 64 or block_k < 64
                           or num_warps < 4):
            continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return True for a skinny GEMM (M or N below 64) with K above 1024."""
    is_skinny = min(SIZE_M, SIZE_N) < 64
    return is_skinny and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    """Concatenate two lists, keeping only the first of any equal entries.

    Entries are compared with ``in`` (equality), so unhashable dicts work;
    order of first appearance is preserved and the inputs are not modified.
    """
    merged = []
    for candidate in (*list1, *list2):
        if candidate in merged:
            continue
        merged.append(candidate)
    return merged


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that benchmarks or tunes fused-MoE configs on one device."""

    def __init__(self, seed: int, device_id: int) -> None:
        # Make "cuda:<device_id>" torch's default device so the tensors that
        # benchmark_config() allocates land on this worker's device.
        torch.set_default_device("cuda:" + str(device_id))
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        # NOTE(review): this assumes every device is visible to every worker;
        # if Ray limits each num_gpus=1 actor to one visible device,
        # "cuda:<device_id>" is invalid for device_id > 0 — confirm.
        self.device_id = device_id

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: Optional[List[int]] = None,
        nn_moe: Optional[bool] = False,
    ) -> tuple[dict[str, int], float]:
        """Time the config vLLM would currently select for this shape.

        Uses the entry of ``get_moe_configs`` whose batch size is nearest to
        ``num_tokens`` if configs exist, otherwise ``get_default_config``.

        Returns:
            ``(config, kernel_time_us)``.
        """
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts,
                                    shard_intermediate_size // 2,
                                    dtype_str,
                                    use_nn_moe=nn_moe)
        if op_config is None:
            config = get_default_config(num_tokens,
                                        num_experts,
                                        shard_intermediate_size,
                                        hidden_size,
                                        topk,
                                        dtype_str,
                                        is_marlin=False,
                                        use_nn_moe=nn_moe)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config,
                                       num_tokens,
                                       num_experts,
                                       shard_intermediate_size,
                                       hidden_size,
                                       topk,
                                       dtype,
                                       use_fp8_w8a8,
                                       use_int8_w8a16,
                                       num_iters=100,
                                       block_quant_shape=block_quant_shape,
                                       nn_moe=nn_moe)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: Optional[List[int]],
        nn_moe: Optional[bool] = False,
    ) -> dict[str, int]:
        """Return the fastest config in ``search_space`` for this shape.

        On ROCm the search space is first pruned with
        ``prune_rocm_search_space``. Configs that raise Triton's
        ``OutOfResources`` are skipped.

        Raises:
            RuntimeError: if no config could be benchmarked successfully.
        """
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(num_tokens,
                                                   shard_intermediate_size,
                                                   hidden_size, search_space,
                                                   is_fp16, topk)

        device_ctx = (torch.cuda.device(self.device_id)
                      if current_platform.is_rocm() else nullcontext())
        with device_ctx:
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        nn_moe=nn_moe)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config

        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        # Explicit check instead of assert, which is stripped under -O.
        if best_config is None:
            raise RuntimeError(
                f"No valid config found for batch_size={num_tokens}")
        return best_config


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in a canonical order.

    The six core Triton parameters always come first, and a missing one
    raises ``KeyError``. Optional extras follow only when present. Any
    other keys are dropped.
    """
    required_keys = ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K",
                     "GROUP_SIZE_M", "num_warps", "num_stages")
    optional_keys = ("num_ldmatrixes", "waves_per_eu",
                     "matrix_instr_nonkdim", "kpack")
    ordered = {key: config[key] for key in required_keys}
    ordered.update(
        (key, config[key]) for key in optional_keys if key in config)
    return ordered


def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
                 block_quant_shape: Optional[List[int]],
                 use_nn_moe: Optional[bool] = False) -> None:
    """Write tuned ``configs`` (batch size -> config) to a JSON file.

    The file name comes from ``get_config_file_name`` and is built from the
    expert count, ``shard_intermediate_size // 2``, the dtype string,
    ``block_quant_shape`` (may be ``None``) and ``use_nn_moe``.
    ``hidden_size`` and ``topk`` are accepted but not used.
    """
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts,
                                    shard_intermediate_size // 2,
                                    dtype_str,
                                    block_quant_shape,
                                    use_nn_moe=use_nn_moe)

    print(f"Writing best config to {filename}...")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def get_weight_block_size_safety(config, default_value=None):
    """Return ``weight_block_size`` from ``config.quantization_config``.

    Falls back to ``default_value`` in any of these cases: the config has no
    ``quantization_config`` attribute, that attribute is not a dict, or the
    dict lacks a ``weight_block_size`` entry.
    """
    quant_cfg = getattr(config, "quantization_config", None)
    if not isinstance(quant_cfg, dict):
        return default_value
    return quant_cfg.get("weight_block_size", default_value)


def _get_moe_dims(config, tp_size: int) -> tuple[int, int, int]:
    """Return ``(num_experts, topk, shard_intermediate_size)`` for a model.

    ``shard_intermediate_size`` is ``2 * intermediate_size // tp_size``,
    which is the per-rank width of the first expert weight.
    """
    arch = config.architectures[0]
    if arch == "DbrxForCausalLM":
        num_experts = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        num_experts = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif arch in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        num_experts = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif arch == "Qwen2MoeForCausalLM":
        num_experts = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral.
        num_experts = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    return num_experts, topk, 2 * intermediate_size // tp_size


def main(args: argparse.Namespace):
    """Benchmark or tune fused-MoE kernel configs for ``args.model``.

    With ``--tune``, every batch size is tuned over the search space and the
    winners are written with ``save_configs``. Otherwise the config vLLM
    currently selects is timed for each batch size. Work is distributed
    round-robin across one ``BenchmarkWorker`` Ray actor per GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    E, topk, shard_intermediate_size = _get_moe_dims(config, args.tp_size)
    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    # Read the checkpoint's block-quant shape for every architecture; it is
    # None unless quantization_config declares weight_block_size. Previously
    # a mis-merged duplicate Qwen2-MoE branch meant only Qwen2-MoE read it.
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init(address=None, ignore_reinit_error=True, num_gpus=args.num_gpus)
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed, i) for i in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Round-robin the calls over the workers, then wait for all results.
        outputs = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape,
                                                 args.nn_moe)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape,
              args.nn_moe) for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
                     block_quant_shape, use_nn_moe=args.nn_moe)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
              use_fp8_w8a8, use_int8_w8a16, block_quant_shape, args.nn_moe)
             for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
                        help="Model name or path whose config supplies the "
                        "MoE shape.")
    parser.add_argument("--tp-size",
                        "-tp",
                        "--tensor-parallel-size",
                        type=int,
                        default=2,
                        help="Tensor-parallel size; the expert intermediate "
                        "size is divided by it.")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto",
                        help="Quantization path to benchmark; 'auto' means "
                        "unquantized.")
    parser.add_argument("--seed",
                        type=int,
                        default=0,
                        help="Random seed used by every worker.")
    parser.add_argument("--batch-size",
                        type=int,
                        required=False,
                        help="Single batch size to run; by default a "
                        "built-in sweep is used.")
    parser.add_argument("--tune",
                        action="store_true",
                        help="Search for and save the best configs instead "
                        "of timing the current ones.")
    parser.add_argument("--nn-moe",
                        action="store_true",
                        default=False,
                        help="Use the nn_moe weight layout and kernel path.")
    parser.add_argument("--trust-remote-code",
                        action="store_true",
                        help="Passed as trust_remote_code when loading the "
                        "model config.")
    parser.add_argument("--num-gpus",
                        type=int,
                        default=1,
                        help="Number of GPUs passed to ray.init().")
    args = parser.parse_args()

    main(args)