# bench_dsv3_router_gemm.py
import argparse
import os

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm

# CI environment detection: the run counts as CI when either the generic
# ``CI`` variable or GitHub Actions' ``GITHUB_ACTIONS`` variable is set to
# "true" (compared case-insensitively). Missing variables default to "false".
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# CI environment uses simplified parameters: a single token count and only the
# 256-expert sgl-kernel implementation. A full (non-CI) run sweeps token counts
# 1..16 and compares torch against sgl-kernel for both 256 and 384 experts.
if IS_CI:
    num_tokens_vals = [1]  # Only test 1 value in CI
    line_vals = ["sgl-kernel-256"]  # Only test one implementation in CI
else:
    num_tokens_vals = list(range(1, 17))  # Test 1-16 in full mode
    line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=num_tokens_vals,
        x_log=False,
        line_arg="impl",
        line_vals=line_vals,
        # Labels and colours are derived from ``line_vals`` so the three
        # parallel lists stay aligned whichever sweep (CI or full) is active.
        line_names=[v.replace("sgl-kernel", "dsv3_router_gemm") for v in line_vals],
        styles=[
            {
                "torch-256": ("blue", "-"),
                "sgl-kernel-256": ("orange", "-"),
                "torch-384": ("green", "-"),
                "sgl-kernel-384": ("red", "-"),
            }[v]
            for v in line_vals
        ],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_bf16_output(num_tokens, impl):
    """Measure throughput of a bf16 x bf16 -> bf16 router GEMM.

    Args:
        num_tokens: Number of rows ``M`` of the activation matrix.
        impl: One of ``"torch-256"``, ``"sgl-kernel-256"``, ``"torch-384"``,
            ``"sgl-kernel-384"``. The prefix selects ``F.linear`` or
            ``dsv3_router_gemm``; the suffix selects the expert count ``N``.

    Returns:
        ``(median, low, high)`` TFLOP/s for ``perf_report``. ``low`` is
        computed from the 80th-percentile (slower) latency and ``high`` from
        the 20th-percentile latency.

    Raises:
        ValueError: If ``impl`` is not one of the values listed above.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    if impl in ("torch-256", "sgl-kernel-256"):
        N = 256
    elif impl in ("torch-384", "sgl-kernel-384"):
        N = 384
    else:
        raise ValueError(f"Unknown impl: {impl}")

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    # Returned in this order: median, 20th percentile, 80th percentile (ms).
    quantiles = [0.5, 0.2, 0.8]

    # ``impl`` is validated above, so it is either a torch-* or sgl-kernel-*.
    if impl.startswith("torch-"):

        def runner():
            F.linear(mat_a, mat_b)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)

    def tflops(t_ms):
        # A GEMM of (M, K) x (K, N) performs 2*M*K*N floating-point ops.
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    # Throughput is inversely proportional to latency, so the slowest
    # latency quantile yields the lower throughput bound.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=num_tokens_vals,
        x_log=False,
        line_arg="impl",
        line_vals=line_vals,
        # Labels and colours are derived from ``line_vals`` so the three
        # parallel lists stay aligned whichever sweep (CI or full) is active.
        line_names=[v.replace("sgl-kernel", "dsv3_router_gemm") for v in line_vals],
        styles=[
            {
                "torch-256": ("blue", "-"),
                "sgl-kernel-256": ("orange", "-"),
                "torch-384": ("green", "-"),
                "sgl-kernel-384": ("red", "-"),
            }[v]
            for v in line_vals
        ],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_float_output(num_tokens, impl):
    """Measure throughput of a bf16 x bf16 router GEMM producing fp32 output.

    The torch baseline computes ``F.linear`` in bf16 and then casts the result
    to fp32; the cast is included in the timed region. The sgl-kernel path
    asks ``dsv3_router_gemm`` for ``out_dtype=torch.float32`` directly.

    Args:
        num_tokens: Number of rows ``M`` of the activation matrix.
        impl: One of ``"torch-256"``, ``"sgl-kernel-256"``, ``"torch-384"``,
            ``"sgl-kernel-384"``. The prefix selects the implementation; the
            suffix selects the expert count ``N``.

    Returns:
        ``(median, low, high)`` TFLOP/s for ``perf_report``. ``low`` is
        computed from the 80th-percentile (slower) latency and ``high`` from
        the 20th-percentile latency.

    Raises:
        ValueError: If ``impl`` is not one of the values listed above.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    if impl in ("torch-256", "sgl-kernel-256"):
        N = 256
    elif impl in ("torch-384", "sgl-kernel-384"):
        N = 384
    else:
        raise ValueError(f"Unknown impl: {impl}")

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    # Returned in this order: median, 20th percentile, 80th percentile (ms).
    quantiles = [0.5, 0.2, 0.8]

    # ``impl`` is validated above, so it is either a torch-* or sgl-kernel-*.
    if impl.startswith("torch-"):

        def runner():
            F.linear(mat_a, mat_b).to(torch.float32)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)

    def tflops(t_ms):
        # A GEMM of (M, K) x (K, N) performs 2*M*K*N floating-point ops.
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    # Throughput is inversely proportional to latency, so the slowest
    # latency quantile yields the lower throughput bound.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
    # No options are defined yet; parsing still provides ``--help`` and
    # rejects unrecognized command-line arguments.
    parser = argparse.ArgumentParser(
        description="Benchmark dsv3_router_gemm against torch F.linear."
    )
    parser.parse_args()

    benchmark_bf16_output.run(print_data=True)
    benchmark_float_output.run(print_data=True)