# NOTE: the following lines are repository-viewer residue, not program text:
#   "vscode:/vscode.git/clone" did not exist on "1f214290d65a6a69b898e058f1408f2d929f8fa7"
#   benchmark_moe.py 19.9 KB
#   Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import argparse
import json
import time
from datetime import datetime
from itertools import product
from typing import Any, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
18
19

# FP8 storage dtype used when benchmarking fp8_w8a8: the "fnuz" variant is
# selected on ROCm, the plain e4m3fn variant everywhere else.
FP8_DTYPE = (torch.float8_e4m3fnuz
             if current_platform.is_rocm() else torch.float8_e4m3fn)
21
22


23
24
25
26
27
28
29
30
31
# Typed description of one kernel configuration explored by the tuner: the
# three tile sizes, the M-grouping factor, and the Triton launch parameters.
# Note that configs produced for ROCm by get_rocm_tuning_space() may carry
# extra keys (waves_per_eu, matrix_instr_nonkdim, kpack) that are not
# declared here.
BenchmarkConfig = TypedDict(
    "BenchmarkConfig",
    {
        "BLOCK_SIZE_M": int,
        "BLOCK_SIZE_N": int,
        "BLOCK_SIZE_K": int,
        "GROUP_SIZE_M": int,
        "num_warps": int,
        "num_stages": int,
    },
)


32
def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    """Measure the average latency of one ``fused_moe`` call under ``config``.

    Random activations, expert weights and router logits are generated, a
    batch of ``fused_moe`` calls is captured into a CUDA graph while
    ``override_config(config)`` is active, and the graph is replayed
    ``num_iters`` times, each replay timed with CUDA events.

    Tensors are created without an explicit device, so they land on torch's
    current default device (``BenchmarkWorker`` sets it to ``"cuda"``).

    Args:
        config: Kernel configuration forced via ``override_config``.
        num_tokens: Number of rows of the activation matrix.
        num_experts: Number of experts (leading dim of ``w1``/``w2``).
        shard_intermediate_size: Second dim of ``w1``; ``w2`` uses half of it.
        hidden_size: Activation width.
        topk: Passed through to ``fused_moe``.
        dtype: Activation dtype (and weight dtype unless quantized).
        use_fp8_w8a8: Benchmark the fp8 path; weights are cast to
            ``FP8_DTYPE`` and per-expert/per-tensor scales are supplied.
        use_int8_w8a16: Benchmark the int8-weight path with random int8
            weights and float32 scales.
        num_iters: Number of timed graph replays.

    Returns:
        Average time per ``fused_moe`` call, in microseconds.
    """
    # Imported here (rather than inside the captured closure) so that the
    # import machinery never runs while a CUDA graph is being captured.
    from vllm.model_executor.layers.fused_moe import override_config

    # Number of fused_moe invocations captured in one CUDA graph. Each timed
    # replay therefore covers this many kernel launches.
    calls_per_graph = 10

    # fp8 weights are first generated in fp16 and converted further below.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1_shape = (num_experts, shard_intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, shard_intermediate_size // 2)
    if use_int8_w8a16:
        w1 = torch.randint(-127, 127, w1_shape, dtype=torch.int8)
        w2 = torch.randint(-127, 127, w2_shape, dtype=torch.int8)
    else:
        w1 = torch.randn(*w1_shape, dtype=init_dtype)
        w2 = torch.randn(*w2_shape, dtype=init_dtype)
    # One set of router logits per timed iteration.
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    # The graph always reads from this fixed buffer; prepare() refreshes its
    # contents before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture `calls_per_graph` invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_graph):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        # Make sure the input copy is not included in the timed region.
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time() is in milliseconds; convert the per-call mean to us.
    avg = sum(latencies) / (num_iters * calls_per_graph) * 1000  # us
    graph.reset()
    return avg


150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def get_rocm_tuning_space(use_fp16):
    """Return the per-parameter candidate values searched on ROCm.

    Args:
        use_fp16: Truthy when tuning the unquantized (fp16) kernel. In that
            case BLOCK_SIZE_K may also be 16 and the ``matrix_instr_nonkdim``
            and ``kpack`` parameters are part of the space.

    Returns:
        A dict mapping each parameter name to its list of candidate values.
        Key order is significant: callers expand it with a Cartesian product.
    """
    mn_blocks = [16, 32, 64, 128, 256]
    # BLOCK_K=16 not supported for fp8
    k_blocks = [16, 32, 64, 128, 256] if use_fp16 else [32, 64, 128, 256]

    space = {
        # M and N deliberately share the same candidate list.
        "BLOCK_SIZE_M": mn_blocks,
        "BLOCK_SIZE_N": mn_blocks,
        "BLOCK_SIZE_K": k_blocks,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0],
    }
    if use_fp16:
        space.update({
            "matrix_instr_nonkdim": [16, 32],
            "kpack": [1, 2],
        })
    return space


178
179
def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
    """Enumerate every kernel configuration in the tuning search space.

    On ROCm the parameter ranges come from ``get_rocm_tuning_space``;
    otherwise a fixed, reduced set of ranges is used. The result is the full
    Cartesian product of those ranges.

    Args:
        use_fp16: Forwarded to ``get_rocm_tuning_space``; unused on other
            platforms.

    Returns:
        One dict per configuration, mapping parameter name to chosen value.
    """
    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [32, 64, 128, 256],
            "BLOCK_SIZE_K": [64, 128, 256],
            "GROUP_SIZE_M": [1, 16, 32, 64],
            "num_warps": [4, 8],
            "num_stages": [2, 3, 4, 5],
        }

    keys = list(param_ranges)
    return [
        dict(zip(keys, config_values))
        for config_values in product(*param_ranges.values())
    ]


210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16):
    """Prune ``search_space`` for both GEMM shapes and merge the survivors.

    ``prune_rocm_configs`` is run once per (N, K) shape:
    ``(shard_intermediate_size, hidden_size)`` and
    ``(hidden_size, shard_intermediate_size // 2)``. A config is kept when it
    survives pruning for at least one of the two shapes; duplicates are
    removed by ``merge_unique_dicts``.

    Returns:
        The merged list of surviving configs.
    """
    # NOTE(review): M is taken as num_tokens * 2; the factor presumably
    # reflects the routing fan-out -- TODO confirm.
    m = num_tokens * 2
    gemm_shapes = (
        (shard_intermediate_size, hidden_size),
        (hidden_size, shard_intermediate_size // 2),
    )
    survivors = [
        prune_rocm_configs(m, n, k, search_space, is_fp16)
        for n, k in gemm_shapes
    ]
    return merge_unique_dicts(survivors[0], survivors[1])


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Filter out configs that are unsuitable for an M x N x K GEMM on ROCm.

    Each config is dropped as soon as it fails any of the heuristics below;
    configs that pass all of them are returned in their original order (the
    same dict objects, not copies).

    Args:
        M, N, K: GEMM problem sizes.
        configs: Iterable of config dicts (``BLOCK_SIZE_*``, ``GROUP_SIZE_M``,
            ``num_warps``, optionally ``SPLIT_K``, and -- for fp16 --
            ``matrix_instr_nonkdim``).
        is_fp16: Selects a 2-byte element size and enables the
            ``matrix_instr_nonkdim`` checks; otherwise 1-byte elements are
            assumed.

    Returns:
        The list of configs that survived pruning.
    """
    elem_bytes_a = 2 if is_fp16 else 1
    elem_bytes_b = 2 if is_fp16 else 1
    max_lds_bytes = 65536

    # Largest matrix_instr_nonkdim accepted for this problem size.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    pruned_configs = []
    for config in configs:
        block_m = config.get("BLOCK_SIZE_M")
        block_n = config.get("BLOCK_SIZE_N")
        block_k = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")
        split_k = config.get("SPLIT_K", 1)
        group_m = config.get("GROUP_SIZE_M")

        if is_fp16:
            nonkdim = config.get("matrix_instr_nonkdim")
            if nonkdim > mfma:
                continue
            if nonkdim > block_m or nonkdim > block_n:
                continue
            if nonkdim >= M and nonkdim != block_m:
                continue
            if nonkdim >= N and nonkdim != block_n:
                continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if block_m * block_n < 64:
            continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < block_m and block_m != 16:
            continue
        if N * 2 < block_n and block_n != 16:
            continue
        # skip large split_k when not necessary
        if split_k != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        if K % (split_k * block_k) != 0:
            continue
        # skip large GROUP_M
        if group_m * block_m > M and group_m != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds_bytes = (block_k * block_m * elem_bytes_a +
                     block_k * block_n * elem_bytes_b)
        if lds_bytes > max_lds_bytes:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm and (block_m < 64 or block_n < 64 or block_k < 64
                           or num_warps < 4):
            continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return whether split-K is considered worthwhile for this GEMM shape.

    That is the case when at least one of M/N is below 64 and K exceeds 1024.
    """
    has_small_dim = SIZE_M < 64 or SIZE_N < 64
    if not has_small_dim:
        return False
    return SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, dropping later duplicates.

    Two dicts are duplicates when they compare equal. The first occurrence is
    kept and the relative order of survivors is preserved; neither input is
    modified.

    Args:
        list1: First sequence of dicts.
        list2: Second sequence of dicts, appended after ``list1``.

    Returns:
        A new list containing each distinct dict once.
    """
    combined = [*list1, *list2]

    try:
        # A frozenset of items is a hashable stand-in for dict equality,
        # letting us deduplicate in linear time instead of scanning the
        # result list for every element.
        keys = [frozenset(d.items()) for d in combined]
    except TypeError:
        # Some value is unhashable: fall back to pairwise comparison.
        result = []
        for dictionary in combined:
            if dictionary not in result:
                result.append(dictionary)
        return result

    seen = set()
    result = []
    for key, dictionary in zip(keys, combined):
        if key not in seen:
            seen.add(key)
            result.append(dictionary)
    return result


316
317
318
319
320
@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor bound to one GPU that benchmarks or tunes fused-MoE kernels."""

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> tuple[dict[str, int], float]:
        """Time the config that would currently be selected for this shape.

        The config is looked up with ``get_moe_configs``; when stored configs
        exist, the entry whose batch-size key is closest to ``num_tokens`` is
        used, otherwise ``get_default_config`` supplies one.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config`` (microseconds).
        """
        # Re-seed so every call sees the same random inputs.
        current_platform.seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens,
                                        num_experts,
                                        shard_intermediate_size,
                                        hidden_size,
                                        topk,
                                        dtype_str,
                                        is_marlin=False)
        else:
            nearest_batch = min(op_config, key=lambda x: abs(x - num_tokens))
            config = op_config[nearest_batch]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8_w8a8,
                                       use_int8_w8a16)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
    ) -> dict[str, int]:
        """Benchmark every config in ``search_space`` and return the fastest.

        On ROCm the search space is first reduced with
        ``prune_rocm_search_space``. Configs that raise Triton's
        ``OutOfResources`` are skipped.

        Raises:
            RuntimeError: If no config in the search space could be run.
        """
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(num_tokens,
                                                   shard_intermediate_size,
                                                   hidden_size, search_space,
                                                   is_fp16)

        with torch.cuda.device(self.device_id):
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(config,
                                                   num_tokens,
                                                   num_experts,
                                                   shard_intermediate_size,
                                                   hidden_size,
                                                   topk,
                                                   dtype,
                                                   use_fp8_w8a8,
                                                   use_int8_w8a16,
                                                   num_iters=20)
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] "
              f"Completed tuning for batch_size={num_tokens}")
        # Raise explicitly rather than assert: this can happen at runtime
        # (empty or fully pruned search space, every config out of
        # resources) and an assert would be stripped under ``python -O``.
        if best_config is None:
            raise RuntimeError(
                f"No runnable config found for batch_size={num_tokens}")
        return best_config


411
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in canonical order.

    The six core keys always come first (a missing one raises ``KeyError``),
    followed by the ROCm-specific keys ``waves_per_eu``,
    ``matrix_instr_nonkdim`` and ``kpack`` when present. Any other key in
    ``config`` is dropped.
    """
    required_keys = ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K",
                     "GROUP_SIZE_M", "num_warps", "num_stages")
    optional_keys = ("waves_per_eu", "matrix_instr_nonkdim", "kpack")

    ordered = {key: config[key] for key in required_keys}
    ordered.update(
        (key, config[key]) for key in optional_keys if key in config)
    return ordered


437
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool,
                 use_int8_w8a16: bool) -> None:
    """Write the tuned configs to the JSON file named by vLLM's convention.

    The file name is derived by ``get_config_file_name`` from the expert
    count, half of ``shard_intermediate_size`` and the dtype string. The file
    is written relative to the current working directory and overwritten if
    it already exists. ``hidden_size`` and ``topk`` are accepted but not used.

    Args:
        configs: Mapping from batch size to the best config found for it.
    """
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        # json.dump turns the integer batch-size keys into strings.
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    """Benchmark or tune fused-MoE kernels for the model named in ``args``.

    The MoE dimensions are read from the model's Hugging Face config, one
    ``BenchmarkWorker`` is started per available GPU, and the batch sizes are
    distributed over the workers round-robin. With ``--tune`` the best config
    per batch size is searched and saved via ``save_configs``; otherwise the
    currently selected config is timed and printed.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    architecture = config.architectures[0]
    if architecture == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif architecture == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif architecture in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    # Same formula for every architecture: twice the intermediate size,
    # divided across the tensor-parallel ranks.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        # Without this check the failure would surface later as an opaque
        # KeyError / ZeroDivisionError.
        raise RuntimeError("No GPU available to Ray; cannot run benchmark.")
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        # Submit the work items round-robin over the workers, then block
        # until all of them have finished. Results keep the input order.
        pending = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(pending)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size",
                        "-tp",
                        "--tensor-parallel-size",
                        type=int,
                        default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() sweeps a built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    # Search for the best config instead of timing the current one.
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)