benchmark_moe_rocm.py 12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
import argparse
import itertools
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
from transformers import AutoConfig

Lianmin Zheng's avatar
Lianmin Zheng committed
13
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377

# Extra elements added to the last dimension of the benchmark weight tensors
# built in run_timing: 128 when the MOE_PADDING environment variable parses to
# a non-zero integer, 0 otherwise (including when the variable is unset).
padding_size = 128 if int(os.environ.get("MOE_PADDING", "0")) != 0 else 0


def main(model, tp_size, dtype: str, batches):
    """Tune the fused MoE kernel once per requested batch size.

    Args:
        model: model identifier forwarded to ``run_grid``.
        tp_size: tensor-parallel size forwarded to ``run_grid``.
        dtype: dtype label forwarded to ``run_grid``.
        batches: iterable of batch sizes; each entry is converted with
            ``int`` right before its run, so strings such as ``"16"`` work.

    Returns:
        None, so ``sys.exit(main(...))`` exits with status 0 on success.
    """
    for batch_size in map(int, batches):
        run_grid(
            batch_size,
            model=model,
            method=fused_moe,
            tp_size=tp_size,
            dtype=dtype,
        )


def prune_configs(M, N, K, configs, elem_bytes_a=1, elem_bytes_b=1):
    """Filter out kernel configs that are not worth timing for an M x N x K GEMM.

    Args:
        M, N, K: GEMM problem dimensions.
        configs: iterable of config dicts. The keys read here are
            ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``,
            ``GROUP_SIZE_M``, ``num_warps`` and ``matrix_instr_nonkdim``;
            any other keys are ignored.
        elem_bytes_a, elem_bytes_b: bytes per element assumed for the two
            operand tiles in the LDS estimate. The default of 1 keeps the
            value that used to be hard-coded here.
            NOTE(review): the old inline comment said "hard-coded for float16
            (2 bytes)", which contradicts the value 1 -- confirm which size is
            intended for non-fp8 runs.

    Returns:
        A new list holding the surviving config dicts (the same objects, in
        their original order).
    """
    # Largest matrix_instr_nonkdim accepted for this problem size.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    # LDS budget in bytes that one config's operand tiles may occupy.
    lds_limit_bytes = 65536

    pruned_configs = []
    for config in configs:
        block_m = config.get("BLOCK_SIZE_M")
        block_n = config.get("BLOCK_SIZE_N")
        block_k = config.get("BLOCK_SIZE_K")
        group_m = config.get("GROUP_SIZE_M")
        num_warps = config.get("num_warps")
        nonkdim = config.get("matrix_instr_nonkdim")

        if nonkdim > mfma:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if block_m * block_n < 64:
            continue
        if nonkdim > block_m or nonkdim > block_n:
            continue
        if nonkdim >= M and nonkdim != block_m:
            continue
        if nonkdim >= N and nonkdim != block_n:
            continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < block_m and block_m != 16:
            continue
        if N * 2 < block_n and block_n != 16:
            continue
        # Split-K is not tuned here (it is fixed to 1), so the only remaining
        # requirement is that K is a whole number of BLOCK_SIZE_K tiles.
        if K % block_k != 0:
            continue
        # skip large GROUP_M
        if group_m != 1 and group_m * block_m > M:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds_bytes = block_k * block_m * elem_bytes_a + block_k * block_n * elem_bytes_b
        if lds_bytes > lds_limit_bytes:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if block_m < 64 or block_n < 64:
                continue
            if block_k < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def union_of_list_of_dicts(l1, l2):
    """Return the entries of ``l1`` followed by ``l2`` with duplicates removed.

    Two dicts count as duplicates when they compare equal. The first
    occurrence is kept (the same object, not a copy) and input order is
    preserved. Neither input is modified.

    Dicts whose values are all hashable are de-duplicated through a set, so
    the common case is linear instead of quadratic in the number of entries.
    A dict with an unhashable value falls back to a linear scan of the result.
    """
    result = []
    seen = set()
    for item in [*l1, *l2]:
        try:
            # Order-independent, hashable fingerprint of the dict's contents.
            key = frozenset(item.items())
        except TypeError:
            # At least one value is unhashable: compare by equality instead.
            key = None

        if key is None:
            if item not in result:
                result.append(item)
        elif key not in seen:
            seen.add(key)
            result.append(item)

    return result


def _build_full_configs():
    """Return every kernel config in the tuning search space as a list of dicts.

    Configs are generated in the same order as nested loops over the ranges
    below (the last range varies fastest), and each dict lists its keys in the
    order the ranges are declared.
    """
    search_space = {
        "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_N": [16, 32, 64, 128, 256],
        "BLOCK_SIZE_K": [32, 64, 128, 256],  # MUST >= 32
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        # Only one value is swept for now, but it stays an explicit range so
        # that other values can be added later.
        # NOTE(review): the previous comment here mentioned num_stages=0 while
        # the value actually swept is 2 -- confirm which one is intended.
        "num_stages": [2],
        "waves_per_eu": [0, 1, 2, 4, 8],
        # Remove 32 because of triton compiling error
        "matrix_instr_nonkdim": [16],
        "kpack": [1, 2],
    }
    keys = list(search_space)
    return [
        dict(zip(keys, values))
        for values in itertools.product(*search_space.values())
    ]


def run_grid(bs, model, method, tp_size, dtype: str):
    """Tune the fused MoE kernel for one batch size and record the best config.

    Loads the model configuration with ``AutoConfig``, builds the full config
    search space, prunes it for the two GEMM shapes derived below, times every
    surviving config with ``run_timing`` and stores the fastest one under the
    key ``str(bs)`` in the JSON file named by ``get_config_file_name``
    (entries already present in that file are kept).

    Args:
        bs: batch size to tune for.
        model: model identifier passed to ``AutoConfig.from_pretrained``.
        method: callable to benchmark; forwarded to ``run_timing``.
        tp_size: tensor-parallel size used to shard the intermediate size.
        dtype: dtype label; ``"float8"`` selects the fp8 path and the fp8
            config file name, any other value the default one.

    Raises:
        ValueError: if the model config has no (or a falsy)
            ``num_experts_per_tok``.
    """
    model_config = AutoConfig.from_pretrained(model)

    # Validate before reading the other fields so that an unsupported model
    # fails with this message rather than an AttributeError.
    top_k = getattr(model_config, "num_experts_per_tok", None)
    if not top_k:
        raise ValueError(f"Unsupported Mixtral model {model}")

    if model_config.architectures[0] == "Grok1ModelForCausalLM":
        num_total_experts = model_config.num_experts
    else:
        num_total_experts = model_config.num_local_experts

    d_model = model_config.hidden_size
    model_intermediate_size = model_config.intermediate_size
    num_layers = model_config.num_hidden_layers
    hidden_states_dtype = model_config.torch_dtype

    num_warmup_calls = 10
    num_calls = 30

    num_warmup_trials = 1
    num_trials = 1

    full_configs = _build_full_configs()

    # Shapes of the two GEMMs the configs are pruned against.
    # NOTE(review): the factor 2 on the batch size presumably stands for
    # top_k == 2 -- confirm before tuning models with a different top_k.
    M1 = bs * 2
    N1 = model_intermediate_size * 2 // tp_size
    K1 = d_model
    prune_configs_1 = prune_configs(M1, N1, K1, full_configs)

    M2 = bs * 2
    N2 = d_model
    K2 = model_intermediate_size // tp_size
    prune_configs_2 = prune_configs(M2, N2, K2, full_configs)

    configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)

    # Adjacent f-strings instead of a backslash continuation, which embedded
    # the next line's indentation in the printed message.
    print(
        f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | "
        f"{len(prune_configs_2)=} | {len(configs)=}"
    )

    best_config = None
    best_time_us = 1e20

    print(f"{tp_size=} {bs=}")

    # Arguments shared by every warmup and timed run_timing call.
    timing_kwargs = dict(
        bs=bs,
        d_model=d_model,
        num_total_experts=num_total_experts,
        top_k=top_k,
        tp_size=tp_size,
        model_intermediate_size=model_intermediate_size,
        method=method,
        dtype=dtype,
        hidden_states_dtype=hidden_states_dtype,
    )

    for kernel_config in tqdm(configs):
        print(kernel_config)

        # warmup; configs that exceed hardware resources are skipped
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_warmup_calls,
                    config=kernel_config,
                    **timing_kwargs,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                config=kernel_config,
                **timing_kwargs,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = kernel_config
                best_time_us = kernel_dur_us

                tqdm.write(
                    f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
                    f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
                    f"{d_model=} {model_intermediate_size=} {num_layers=}"
                )

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    if best_config is None:
        # Nothing could be timed (every config was pruned or ran out of
        # resources); do not write a null entry into the config file.
        print(f"no runnable config found for {bs=}; config file left unchanged")
        return

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(
        num_total_experts,
        model_intermediate_size // tp_size,
        "float8" if dtype == "float8" else None,
    )
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(
    num_calls: int,
    bs: int,
    d_model: int,
    num_total_experts: int,
    top_k: int,
    tp_size: int,
    model_intermediate_size: int,
    method,
    config,
    dtype: str,
    hidden_states_dtype,
) -> float:
    """Time ``method`` on random inputs and return the mean duration per call.

    Random activations and expert weights are created on ``cuda:0``, then
    ``method`` is invoked ``num_calls`` times between two CUDA events.

    Args:
        num_calls: number of back-to-back invocations to time.
        bs: number of rows in the activation tensor.
        d_model: hidden size of the activations.
        num_total_experts: leading dimension of both weight tensors.
        top_k: passed to ``method`` as ``topk``.
        tp_size: divisor applied to ``model_intermediate_size``.
        model_intermediate_size: unsharded intermediate size.
        method: callable under test.
        config: kernel config passed to ``method`` as ``override_config``.
        dtype: ``"float8"`` converts the weights to ``float8_e4m3fnuz`` and
            supplies all-ones scale tensors; otherwise scales are ``None``.
        hidden_states_dtype: dtype of the activations and initial weights.

    Returns:
        Elapsed milliseconds between the two events divided by ``num_calls``.
    """
    use_fp8 = dtype == "float8"
    shard_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model), device="cuda:0", dtype=hidden_states_dtype
    )
    device = hidden_states.device
    act_dtype = hidden_states.dtype

    # The last dimension of each weight is widened by the module-level
    # padding_size.
    w1 = torch.rand(
        (num_total_experts, 2 * shard_size, d_model + padding_size),
        device=device,
        dtype=act_dtype,
    )
    w2 = torch.rand(
        (num_total_experts, d_model, shard_size + padding_size),
        device=device,
        dtype=act_dtype,
    )

    if use_fp8:
        w1 = w1.to(torch.float8_e4m3fnuz)
        w2 = w2.to(torch.float8_e4m3fnuz)
        w1_scale = torch.ones(num_total_experts, device=device, dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts, device=device, dtype=torch.float32)
        a1_scale = torch.ones(1, device=device, dtype=torch.float32)
        a2_scale = torch.ones(1, device=device, dtype=torch.float32)
    else:
        w1_scale = w2_scale = a1_scale = a2_scale = None

    gating_output = F.softmax(
        torch.rand(
            (num_calls, bs, num_total_experts),
            device=device,
            dtype=torch.float32,
        ),
        dim=-1,
    )

    timer_start = torch.cuda.Event(enable_timing=True)
    timer_stop = torch.cuda.Event(enable_timing=True)

    timer_start.record()
    for _ in range(num_calls):
        # NOTE(review): every call is fed slice 0 of gating_output; the other
        # num_calls - 1 slices are never read -- confirm this is intended.
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[0],
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
        )
    timer_stop.record()
    timer_stop.synchronize()

    return timer_start.elapsed_time(timer_stop) / num_calls


# Command-line entry point: parse the options and tune each requested batch size.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="benchmark_mixtral_moe",
        description="Benchmark and tune the fused_moe kernel",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        # "auto" is the default, so it is also accepted when passed explicitly.
        # Every value other than "float8" takes the non-fp8 path in run_timing.
        choices=["auto", "float8", "float16", "bfloat16"],
        help="Data type used for fused_moe kernel computations",
    )
    parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")

    parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
    # Required: without it args.batches would be None and .split() would fail.
    parser.add_argument(
        "-b",
        "--batches",
        type=str,
        required=True,
        help="Comma-separated list of batch sizes to tune, e.g. 1,2,4",
    )

    args = parser.parse_args()

    batches = args.batches.split(",")

    sys.exit(main(args.model, args.tp_size, args.dtype, batches))