import argparse
import itertools
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)

def main(model, tp_size, gpu, dtype: str):
    """Tune ``fused_moe`` for every benchmarked batch size of one Mixtral model.

    Args:
        model: Mixtral variant name forwarded to ``run_grid`` (``'8x7B'`` or
            ``'8x22B'``).
        tp_size: Tensor-parallel degree the expert weights are sharded for.
        gpu: Index of the physical GPU to benchmark on.
        dtype: ``"float8"`` selects the fp8 weight path; any other value
            benchmarks fp16.
    """
    # Expose only the chosen GPU so that the "cuda:0" device used by the
    # timing code refers to it. NOTE(review): this presumably only takes
    # effect if CUDA has not been initialised in this process yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    batch_sizes = (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024,
                   1536, 2048, 3072, 4096)
    for batch_size in batch_sizes:
        run_grid(batch_size,
                 model=model,
                 method=fused_moe,
                 gpu=gpu,
                 tp_size=tp_size,
                 dtype=dtype)


def _mixtral_dims(model):
    """Return ``(d_model, intermediate_size, num_layers)`` for a Mixtral variant.

    Raises:
        ValueError: if ``model`` is not ``'8x7B'`` or ``'8x22B'``.
    """
    if model == '8x7B':
        return 4096, 14336, 32
    if model == '8x22B':
        return 6144, 16384, 56
    raise ValueError(f'Unsupported Mixtral model {model}')


def _candidate_configs():
    """Return the full grid of fused_moe kernel configs to try, as dicts.

    BLOCK_SIZE_N is the outermost dimension of the sweep, followed by
    BLOCK_SIZE_M, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps and num_stages.
    """
    grid = itertools.product(
        [32, 64, 128, 256],  # BLOCK_SIZE_N
        [16, 32, 64, 128, 256],  # BLOCK_SIZE_M
        [64, 128, 256],  # BLOCK_SIZE_K
        [1, 16, 32, 64],  # GROUP_SIZE_M
        [4, 8],  # num_warps
        [2, 3, 4, 5],  # num_stages
    )
    return [{
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GROUP_SIZE_M": group_size_m,
        "num_warps": num_warps,
        "num_stages": num_stages,
    } for (block_size_n, block_size_m, block_size_k, group_size_m, num_warps,
           num_stages) in grid]


def _save_best_config(filename, bs, best_config):
    """Store ``best_config`` under key ``str(bs)`` in the JSON file ``filename``.

    Entries for other batch sizes already present in the file are preserved.
    """
    # holds Dict[str, Dict[str, int]]
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_grid(bs, model, method, gpu, tp_size, dtype: str):
    """Search the kernel-config grid for batch size ``bs`` and save the fastest.

    Every candidate from ``_candidate_configs`` is warmed up and then timed
    with ``run_timing``; candidates whose warmup raises Triton's
    ``OutOfResources`` are skipped. The fastest config is merged into the
    JSON file named by ``get_config_file_name`` under key ``str(bs)``. If no
    candidate could run, a warning is printed and the file is left untouched.

    Args:
        bs: Batch size, i.e. the number of rows of the benchmark activations.
        model: ``'8x7B'`` or ``'8x22B'``.
        method: The MoE kernel callable to benchmark.
        gpu: Unused here; kept for interface compatibility.
        tp_size: Tensor-parallel degree; the intermediate size is divided by it.
        dtype: ``"float8"`` selects the fp8 path and the fp8 config file name.

    Raises:
        ValueError: for an unsupported ``model``.
    """
    d_model, model_intermediate_size, num_layers = _mixtral_dims(model)
    num_total_experts = 8
    top_k = 2

    num_calls = 100
    num_warmup_trials = 1
    num_trials = 1

    # Arguments shared by the warmup and the measured runs.
    timing_kwargs = {
        "num_calls": num_calls,
        "bs": bs,
        "d_model": d_model,
        "num_total_experts": num_total_experts,
        "top_k": top_k,
        "tp_size": tp_size,
        "model_intermediate_size": model_intermediate_size,
        "method": method,
        "dtype": dtype,
    }

    best_config = None
    best_time_us = float("inf")

    print(f'{tp_size=} {bs=}')

    for config in tqdm(_candidate_configs()):
        # warmup; a config the GPU cannot run raises OutOfResources here
        try:
            for _ in range(num_warmup_trials):
                run_timing(config=config, **timing_kwargs)
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(config=config, **timing_kwargs)

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

                tqdm.write(
                    f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                    f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                    f'{d_model=} {model_intermediate_size=} {num_layers=}')

    if best_config is None:
        # Every candidate raised OutOfResources. Writing None would store a
        # null where the file expects a config dict, so skip the write.
        print(f"no runnable config found for {bs=} {tp_size=}; "
              "config file not updated",
              file=sys.stderr)
        return

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    _save_best_config(filename, bs, best_config)


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    """Return the mean time in ms of one ``method`` call, averaged over calls.

    Random fp16 activations of shape ``(bs, d_model)`` and random expert
    weights sized for one tensor-parallel shard are allocated on ``cuda:0``.
    With ``dtype == "float8"`` the weights are cast to ``torch.float8_e4m3fn``
    and all-ones scales are passed. Each of the ``num_calls`` calls uses its
    own slice of softmaxed random gating values and its result is fed back as
    the next call's ``hidden_states``.

    Args:
        num_calls: Number of timed kernel invocations.
        bs: Number of activation rows.
        d_model: Hidden size.
        num_total_experts: Number of experts.
        top_k: Forwarded to ``method`` as ``topk``.
        tp_size: Tensor-parallel degree; divides ``model_intermediate_size``.
        model_intermediate_size: Unsharded expert intermediate size.
        method: The MoE kernel callable (keyword-only call).
        config: Kernel config forwarded as ``override_config``.
        dtype: ``"float8"`` enables the fp8 path; anything else is fp16.

    Returns:
        Elapsed time between the CUDA events divided by ``num_calls``, in ms.
    """
    device = "cuda:0"
    shard_intermediate_size = model_intermediate_size // tp_size
    use_fp8 = dtype == "float8"

    hidden_states = torch.rand(
        (bs, d_model),
        device=device,
        dtype=torch.float16,
    )
    # w1 holds two stacked intermediate projections per expert (hence the 2x).
    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=device,
        dtype=hidden_states.dtype,
    )
    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if use_fp8:
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        # One unit scale per expert for the weights, a single unit scale for
        # each activation input.
        w1_scale = torch.ones(num_total_experts,
                              device=device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1, device=device, dtype=torch.float32)
        a2_scale = torch.ones(1, device=device, dtype=torch.float32)

    # Pre-generate every call's gating values so no RNG work is timed.
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=top_k,  # honour the caller's value instead of a fixed 2
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
        )
    end_event.record()
    end_event.synchronize()

    return start_event.elapsed_time(end_event) / num_calls


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    # Only "float8" is special-cased downstream; every other value runs the
    # fp16 path. argparse does not check defaults against ``choices``, so the
    # default must itself be one of the listed choices.
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    parser.add_argument('--model',
                        type=str,
                        default='8x7B',
                        choices=['8x7B', '8x22B'],
                        help='The Mixtral model to benchmark')
    parser.add_argument('--tp-size',
                        type=int,
                        default=2,
                        help='Tensor parallel size')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help="GPU ID for benchmarking")
    args = parser.parse_args()
    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))