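"""Benchmark sgl_kernel's fused MoE gating kernel (moe_fused_gate) against
SGLang's reference biased_grouped_topk across a range of sequence lengths."""
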
import os

import torch
import triton
from sgl_kernel import moe_fused_gate

from sglang.srt.layers.moe.topk import biased_grouped_topk

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
    # biased_grouped_topk takes (hidden_states, gating_output, correction_bias, ...);
    # the benchmark passes the same scores tensor for both of the first two.
    return biased_grouped_topk(
        scores,
        scores,
        bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        routed_scaling_factor=2.5,  # DeepSeek-R1: 2.5, Kimi K2: 2.872
    )


def biased_grouped_topk_org_fuse_kernel(
    scores, bias, num_expert_group, topk_group, topk
):
    return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)


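# Optional correctness spot check (a sketch, not part of the timed benchmark).
# It assumes both paths return a (topk_weights, topk_ids)-style pair, as the
# SGLang top-k APIs do; only the selected expert ids are compared here, since
# the two paths may differ in weight scaling/normalization conventions.
def sanity_check(seq_length=64):
    device = torch.device("cuda")
    num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
    scores = torch.randn((seq_length, num_experts), device=device)
    bias = torch.rand(num_experts, device=device)
    _, ref_ids = biased_grouped_topk_org(
        scores.clone(), bias.clone(), num_expert_group, topk_group, topk
    )
    _, out_ids = biased_grouped_topk_org_fuse_kernel(
        scores.clone(), bias.clone(), num_expert_group, topk_group, topk
    )
    # Order within the top-k may legitimately differ, so compare sorted ids.
    assert torch.equal(
        torch.sort(ref_ids.to(torch.int64), dim=-1).values,
        torch.sort(out_ids.to(torch.int64), dim=-1).values,
    ), "expert selection mismatch between reference and fused kernel"

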
# CI environment uses simplified parameters
if IS_CI:
    seq_length_range = [5000]  # Only test one sequence length in CI
else:
    seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]

configs = [(sq,) for sq in seq_length_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_length"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["original", "kernel"],
        line_names=["Original", "SGL Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="us",
        plot_name="moe-fused-gate-performance",
        args={},
    )
)
def benchmark(seq_length, provider):
    dtype = torch.float32
    device = torch.device("cuda")
    # DeepSeek-V3/R1-style routing shape: 256 experts split into 8 groups,
    # with the top 4 groups and top 8 experts selected per token.
    num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8

    scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
    bias = torch.rand(num_experts, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "original":
        # Reference path: SGLang's biased_grouped_topk, timed under a CUDA
        # graph so per-launch overhead is excluded from the measurement.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: biased_grouped_topk_org(
                scores.clone(), bias.clone(), num_expert_group, topk_group, topk
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        # Fused path: the moe_fused_gate kernel from sgl_kernel.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: biased_grouped_topk_org_fuse_kernel(
                scores.clone(), bias.clone(), num_expert_group, topk_group, topk
            ),
            quantiles=quantiles,
        )

    # do_bench_cudagraph reports (median, 0.2-quantile, 0.8-quantile) in ms;
    # convert to microseconds and return (median, min, max) for perf_report.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
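    # triton's perf_report runner also accepts show_plots/save_path if plot
    # or CSV output is wanted, e.g.
    #   benchmark.run(print_data=True, save_path="./bench_results")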
    benchmark.run(print_data=True)