benchmark_router_gemm.py 4.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]

# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]


def get_batch_size_range(max_batch_size):
    return [2**x for x in range(14) if 2**x <= max_batch_size]


def get_model_params(config):
    if config.architectures[0] in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
    ):
        num_experts = config.n_routed_experts
        hidden_size = config.hidden_size
    elif config.architectures[0] in ("GptOssForCausalLM",):
        num_experts = config.num_local_experts
        hidden_size = config.hidden_size
    else:
        raise ValueError(f"Unsupported architecture: {config.architectures}")
    return num_experts, hidden_size


def get_benchmark(model, max_batch_size, trust_remote_code):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        config = get_config(model=model, trust_remote_code=trust_remote_code)
        num_experts, hidden_size = get_model_params(config)

        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        is_hopper_or_blackwell = current_platform.is_device_capability(
            90
        ) or current_platform.is_device_capability_family(100)
        allow_dsv3_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
            and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
        )
        allow_gpt_oss_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
            and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
        )

        has_bias = False
        if allow_gpt_oss_router_gemm:
            has_bias = True

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":

            def runner():
                if has_bias:
                    F.linear(mat_a, mat_b, bias)
                else:
                    F.linear(mat_a, mat_b)
        elif provider == "vllm":

            def runner():
                if allow_dsv3_router_gemm:
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
                elif allow_gpt_oss_router_gemm:
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
                else:
                    raise ValueError("Unsupported router gemm")

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
    parser.add_argument("--max-batch-size", default=16, type=int)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
    # Run performance benchmark
    benchmark.run(print_data=True)