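"""Benchmark and tune the Triton fused MoE kernel used by vLLM.

Example invocations (all flags are defined in the argument parser at the bottom
of this file):
    # Exhaustively tune configs for each batch size and save them to a JSON file
    python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2 --tune
    # Benchmark the saved (or default) configs
    python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2
"""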
import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.utils import FlexibleArgumentParser, seed_everything


class BenchmarkConfig(TypedDict):
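    """Triton kernel launch parameters searched over during tuning
    (tile sizes, tile grouping, warps per CTA, and pipeline stages)."""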
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
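    # w1 holds the fused gate/up projections (output dim = shard_intermediate_size);
    # w2 is the down projection, whose input dim is shard_intermediate_size // 2
    # because silu_and_mul halves the intermediate width.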
    if use_int8_w8a16:
        w1 = torch.randint(-127,
                           127, (
                               num_experts,
                               shard_intermediate_size,
                               hidden_size,
                           ),
                           dtype=torch.int8)
        w2 = torch.randint(-127,
                           127, (
                               num_experts,
                               hidden_size,
                               shard_intermediate_size // 2,
                           ),
                           dtype=torch.int8)
    else:
        w1 = torch.randn(num_experts,
                         shard_intermediate_size,
                         hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts,
                         hidden_size,
                         shard_intermediate_size // 2,
                         dtype=init_dtype)
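
    # Pre-generate a separate set of router logits for every timed iteration so
    # each replay sees fresh gating values.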
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
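    # input_gating is the static buffer read by the captured CUDA graph;
    # prepare() refreshes it in place before each timed replay.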

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
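    # Each timed iteration refreshes the gating buffer, replays the captured
    # graph (10 fused_moe calls), and measures the elapsed time with CUDA events.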
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us per call; each replay runs 10 invocations
    graph.reset()
    return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    configs: List[BenchmarkConfig] = []
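    # 4 stage counts * 5 BLOCK_M * 3 BLOCK_K * 4 BLOCK_N * 2 warp counts
    # * 4 group sizes = 1920 candidate configurations.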
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128, 256]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append({
                                "BLOCK_SIZE_M": block_m,
                                "BLOCK_SIZE_N": block_n,
                                "BLOCK_SIZE_K": block_k,
                                "GROUP_SIZE_M": group_size,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })
    return configs


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        seed_everything(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        seed_everything(self.seed)
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8_w8a8,
                                       use_int8_w8a16)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config,
                                               num_tokens,
                                               num_experts,
                                               shard_intermediate_size,
                                               hidden_size,
                                               topk,
                                               dtype,
                                               use_fp8_w8a8,
                                               use_int8_w8a16,
                                               num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }


def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool,
                 use_int8_w8a16: bool) -> None:
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model)
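    # shard_intermediate_size is 2 * intermediate_size // tp_size because w1
    # stores the fused gate and up projections for this tensor-parallel shard.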
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init(address=None,
             ignore_reinit_error=True,
             num_gpus=args.tp_size)
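
    # One BenchmarkWorker actor is created per visible GPU (num_gpus=1 in the
    # decorator above), so Ray pins each worker to its own device.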
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
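        # Assign work items to the workers round-robin, then block until all
        # remote calls have finished.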
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)