# bench_dsv3_router_gemm.py: throughput benchmark comparing torch F.linear
# against sgl_kernel.dsv3_router_gemm (bf16 and fp32 outputs).
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_bf16_output(num_tokens, impl):
    """Benchmark the router GEMM with bf16 inputs and a bf16 output.

    Times either ``F.linear`` or ``dsv3_router_gemm`` on a
    ``(num_tokens, 7168) x (N, 7168)^T`` product. ``N`` is 256 or 384,
    chosen by the suffix of ``impl``.

    Args:
        num_tokens: Number of activation rows (M).
        impl: One of "torch-256", "sgl-kernel-256", "torch-384",
            "sgl-kernel-384".

    Returns:
        Tuple ``(median, low, high)`` of throughput in TFLOPs. ``low`` comes
        from the 80th-percentile latency and ``high`` from the 20th.

    Raises:
        ValueError: If ``impl`` is not one of the recognised values.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    num_experts_by_impl = {
        "torch-256": 256,
        "sgl-kernel-256": 256,
        "torch-384": 384,
        "sgl-kernel-384": 384,
    }
    if impl not in num_experts_by_impl:
        raise ValueError(f"Unknown impl: {impl}")
    N = num_experts_by_impl[impl]

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    quantiles = [0.5, 0.2, 0.8]

    # impl has been validated above, so it is either a torch-* or an sgl-kernel-* variant.
    if impl.startswith("torch-"):

        def runner():
            F.linear(mat_a, mat_b)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)

    # Latencies come back in the same order as `quantiles`: median, p20, p80.
    median_ms, p20_ms, p80_ms = triton.testing.do_bench_cudagraph(
        runner, quantiles=quantiles
    )

    flops = 2 * M * K * N

    def tflops(t_ms):
        return flops / (t_ms * 1e-3) / 1e12

    # perf_report reads (y, y_min, y_max). A slower (p80) latency gives the
    # lower throughput bound, and a faster (p20) latency gives the upper bound.
    return tflops(median_ms), tflops(p80_ms), tflops(p20_ms)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
        line_names=[
            "torch-256",
            "dsv3_router_gemm-256",
            "torch-384",
            "dsv3_router_gemm-384",
        ],
        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_float_output(num_tokens, impl):
    """Benchmark the router GEMM with bf16 inputs and an fp32 output.

    Times a ``(num_tokens, 7168) x (N, 7168)^T`` product, with ``N`` equal to
    256 or 384 as selected by the suffix of ``impl``. The torch baseline runs
    ``F.linear`` in bf16 and then casts the result to float32. The sgl-kernel
    path asks ``dsv3_router_gemm`` for a float32 output directly.

    Args:
        num_tokens: Number of activation rows (M).
        impl: One of "torch-256", "sgl-kernel-256", "torch-384",
            "sgl-kernel-384".

    Returns:
        Tuple ``(median, low, high)`` of throughput in TFLOPs. ``low`` comes
        from the 80th-percentile latency and ``high`` from the 20th.

    Raises:
        ValueError: If ``impl`` is not one of the recognised values.
    """
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K = num_tokens, 7168

    num_experts_by_impl = {
        "torch-256": 256,
        "sgl-kernel-256": 256,
        "torch-384": 384,
        "sgl-kernel-384": 384,
    }
    if impl not in num_experts_by_impl:
        raise ValueError(f"Unknown impl: {impl}")
    N = num_experts_by_impl[impl]

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()

    quantiles = [0.5, 0.2, 0.8]

    # impl has been validated above, so it is either a torch-* or an sgl-kernel-* variant.
    if impl.startswith("torch-"):

        def runner():
            # The timing includes the extra bf16 -> fp32 cast kernel.
            F.linear(mat_a, mat_b).to(torch.float32)

    else:

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    # Latencies come back in the same order as `quantiles`: median, p20, p80.
    median_ms, p20_ms, p80_ms = triton.testing.do_bench_cudagraph(
        runner, quantiles=quantiles
    )

    flops = 2 * M * K * N

    def tflops(t_ms):
        return flops / (t_ms * 1e-3) / 1e12

    # perf_report reads (y, y_min, y_max). A slower (p80) latency gives the
    # lower throughput bound, and a faster (p20) latency gives the upper bound.
    return tflops(median_ms), tflops(p80_ms), tflops(p20_ms)


if __name__ == "__main__":
    # No options are defined yet. Parsing is kept so that --help works and
    # unexpected command-line arguments are rejected.
    parser = argparse.ArgumentParser(
        description="Benchmark dsv3_router_gemm against torch F.linear."
    )
    parser.parse_args()

    benchmark_bf16_output.run(print_data=True)
    benchmark_float_output.run(print_data=True)