# benchmark_moe.py: benchmark and tune fused MoE kernel configurations.
import argparse
import itertools
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser


class BenchmarkConfig(TypedDict):
    """Kernel tuning parameters for one fused MoE launch configuration.

    A dict of this shape is handed to ``fused_moe`` as ``override_config``
    by ``benchmark_config``; the tuner enumerates candidate values for every
    key and keeps the fastest combination.
    """
    # NOTE(review): the BLOCK_SIZE_* / GROUP_SIZE_M / num_warps / num_stages
    # keys presumably map one-to-one onto the Triton kernel's launch
    # parameters -- confirm against the fused_moe implementation.
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    """Measure the average latency of one ``fused_moe`` call under ``config``.

    Random activations, expert weights and gating logits are generated, a
    fixed number of ``fused_moe`` invocations is captured in a CUDA graph,
    and the graph is replayed ``num_iters`` times while being timed with
    CUDA events.

    Tensors are created without an explicit ``device``; the caller is
    expected to have selected the target device beforehand (in this file
    ``BenchmarkWorker.__init__`` sets the default device to ``"cuda"``).

    Args:
        config: Kernel configuration forwarded as ``override_config``.
        num_tokens: Number of rows in the activation tensor.
        num_experts: Number of experts (leading dimension of the weights).
        shard_intermediate_size: Second dimension of ``w1``; ``w2`` uses
            half of it as its last dimension.
        hidden_size: Hidden dimension of the activations.
        topk: Number of experts selected per token.
        dtype: Activation dtype (and weight dtype unless quantized).
        use_fp8_w8a8: Benchmark with float8_e4m3fn weights and per-expert /
            per-tensor scale tensors.
        use_int8_w8a16: Benchmark with int8 weights and weight scales.
        num_iters: Number of timed graph replays; must be positive.

    Returns:
        Mean time of a single ``fused_moe`` invocation in microseconds.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        # Fail fast instead of hitting ZeroDivisionError after all GPU work.
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    # Number of fused_moe invocations captured in (and therefore executed by)
    # a single CUDA graph replay.
    calls_per_replay = 10

    # fp8 weights are first created in float16 and cast further below.
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(-127,
                           127, (
                               num_experts,
                               shard_intermediate_size,
                               hidden_size,
                           ),
                           dtype=torch.int8)
        w2 = torch.randint(-127,
                           127, (
                               num_experts,
                               hidden_size,
                               shard_intermediate_size // 2,
                           ),
                           dtype=torch.int8)
    else:
        w1 = torch.randn(num_experts,
                         shard_intermediate_size,
                         hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts,
                         hidden_size,
                         shard_intermediate_size // 2,
                         dtype=init_dtype)
    # One independent set of gating logits per timed iteration.
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        # NOTE(review): these scale shapes are kept exactly as in the
        # original benchmark; confirm they match what fused_moe expects.
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    # The captured graph always reads this buffer; new gating values are
    # copied into it before each replay.
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture several invocations in one CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(calls_per_replay):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    # elapsed_time reports milliseconds; convert the per-call mean to us.
    avg = sum(latencies) / (num_iters * calls_per_replay) * 1000  # us
    graph.reset()
    return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
    """Enumerate every kernel configuration the tuner should try.

    The result is the Cartesian product of the candidate values listed
    below. ``num_stages`` varies slowest and ``GROUP_SIZE_M`` fastest, which
    matches the order of the original nested loops.

    Returns:
        A list of configuration dicts keyed like ``BenchmarkConfig``.
    """
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    num_stages_choices = [2, 3, 4, 5]
    block_m_choices = [16, 32, 64, 128, 256]
    block_k_choices = [64, 128, 256]
    block_n_choices = [32, 64, 128, 256]
    num_warps_choices = [4, 8]
    group_size_choices = [1, 16, 32, 64]

    configs: List[BenchmarkConfig] = [
        {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": group_size,
            "num_warps": num_warps,
            "num_stages": num_stages,
        }
        for (num_stages, block_m, block_k, block_n, num_warps,
             group_size) in itertools.product(
                 num_stages_choices,
                 block_m_choices,
                 block_k_choices,
                 block_n_choices,
                 num_warps_choices,
                 group_size_choices,
             )
    ]
    return configs


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    """Ray actor that benchmarks or tunes the fused MoE kernel.

    Each actor requests one GPU from Ray, makes CUDA the default torch
    device and re-seeds the CUDA RNGs so that runs are reproducible.
    """

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        """Time the kernel with the configuration vLLM would pick itself.

        A stored configuration table is looked up with ``get_moe_configs``.
        If one exists, the entry whose key is closest to ``num_tokens`` is
        used; otherwise ``get_default_config`` supplies the configuration.

        Returns:
            ``(config, kernel_time)`` where ``kernel_time`` is the value
            returned by ``benchmark_config`` (microseconds).
        """
        torch.cuda.manual_seed_all(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
        else:
            # Pick the tuned entry whose batch size is nearest to ours.
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8_w8a8,
                                       use_int8_w8a16)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        """Return the fastest configuration in ``search_space``.

        Every candidate is timed with ``benchmark_config`` using 10 timed
        iterations. Candidates that raise Triton's ``OutOfResources`` are
        skipped.

        Raises:
            RuntimeError: If no candidate could be benchmarked (empty search
                space, or every candidate ran out of resources).
        """
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config,
                                               num_tokens,
                                               num_experts,
                                               shard_intermediate_size,
                                               hidden_size,
                                               topk,
                                               dtype,
                                               use_fp8_w8a8,
                                               use_int8_w8a16,
                                               num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        if best_config is None:
            # An assert would be stripped under ``python -O`` and let None
            # leak to the caller, so fail explicitly instead.
            raise RuntimeError(
                f"No valid configuration found for batch_size={num_tokens}")
        return best_config


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` with its keys in a fixed order.

    Only the six ``BenchmarkConfig`` keys are copied, so any additional
    keys in ``config`` are dropped. The fixed order keeps the JSON files
    written by the tuner stable from run to run.

    Raises:
        KeyError: If ``config`` lacks one of the expected keys.
    """
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }


def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool,
                 use_int8_w8a16: bool) -> None:
    """Write the tuned configurations to a JSON file.

    The file name comes from ``get_config_file_name`` and depends on the
    number of experts, half of ``shard_intermediate_size`` and the dtype
    string. ``hidden_size`` and ``topk`` are accepted for signature parity
    with the other helpers but are not used here.

    Args:
        configs: Mapping from batch size to its best configuration. JSON
            object keys are strings, so the integer batch sizes are written
            as strings.
    """
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    """Benchmark or tune the fused MoE kernel for the requested model.

    Model dimensions are read from the Hugging Face config of ``args.model``.
    One ``BenchmarkWorker`` actor is started per GPU that Ray reports, and
    the batch sizes are spread over the workers round-robin. With
    ``args.tune`` the best configuration per batch size is searched for and
    saved; otherwise the currently selected configuration is timed and
    printed.

    Raises:
        RuntimeError: If Ray reports no available GPU.
    """
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    architecture = config.architectures[0]
    if architecture == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif architecture == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    # Same formula for every architecture: twice the intermediate size,
    # split across the tensor-parallel ranks.
    shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    # Without a GPU resource the original code failed with KeyError (or a
    # modulo-by-zero later on); report the actual problem instead.
    num_gpus = int(ray.available_resources().get("GPU", 0))
    if num_gpus < 1:
        raise RuntimeError(
            "No GPU is available to Ray; at least one GPU is required.")
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        # Submit every task round-robin over the workers, then wait for all
        # results; ray.get preserves the submission order.
        pending = [
            getattr(workers[idx % num_gpus], method).remote(*input_args)
            for idx, input_args in enumerate(inputs)
        ]
        return ray.get(pending)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    # When omitted, main() sweeps its built-in list of batch sizes.
    parser.add_argument("--batch-size", type=int, required=False)
    # Search for the best configuration instead of only timing the current one.
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)