bench_moe_topk_softmax.py 4.66 KB
Newer Older
1
import itertools
2
import os
3
4
5
6
7

import pytest
import torch
import triton
from sgl_kernel import topk_softmax
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# Optional vLLM import
try:
    from vllm import _custom_ops as vllm_custom_ops

    VLLM_AVAILABLE = True
except ImportError:
    vllm_custom_ops = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
23
24
25


def vllm_topk_softmax(gating_output, topk):
26
27
28
29
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation if vLLM is not available
        return sglang_topk_softmax(gating_output, topk)

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    num_tokens, num_experts = gating_output.shape

    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    torch.ops._moe_C.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output
    )
    return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape

    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )

    topk_softmax(
        topk_weights=topk_weights,
        topk_ids=topk_indices,
        gating_output=gating_output,
    )

    return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)

76
77
78
79
    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping comparison")
        return

80
81
82
83
84
85
86
87
88
89
90
    if (
        torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
        and indices_match
    ):
        print("✅ VLLM and SGLang topk_softmax implementations match")
    else:
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )


91
92
93
94
95
96
97
98
99
# CI environment uses simplified parameters
if IS_CI:
    num_tokens_range = [128]  # Single value for CI
    num_experts_range = [32]  # Single value for CI
    topk_range = [2]  # Single value for CI
else:
    num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    num_experts_range = [32, 64, 128, 256, 12, 512]
    topk_range = [1, 2, 4, 8]
100
101
102
103

configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


104
105
106
107
108
109
110
111
112
113
114
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
    line_vals = ["sglang", "vllm"]
    line_names = ["SGLang", "VLLM"]
    styles = [("blue", "-"), ("green", "-")]
else:
    line_vals = ["sglang"]
    line_names = ["SGLang"]
    styles = [("blue", "-")]


115
116
117
118
119
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
120
121
122
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
123
124
125
126
127
128
129
130
131
132
133
134
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):

    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    if provider == "vllm" or provider == "vllm1":
135
136
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
137
138
139
140
141
        fn = lambda: vllm_topk_softmax(gating_output, topk)
    elif provider == "sglang" or provider == "sglang1":
        fn = lambda: sglang_topk_softmax(gating_output, topk)

    quantiles = [0.5, 0.2, 0.8]
142
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
143
144
145
146
147

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
148
149
150
151
152
153
154
155
156
157
158
159
160
161
    # Simplify configs for CI environment
    if IS_CI:
        test_configs = [(20, 32, 2)]  # Single config for CI
    else:
        test_configs = [
            (20, 256, 4),
            (20, 256, 8),
            (20, 12, 4),
            (20, 12, 1),
            (20, 512, 4),
            (20, 512, 1),
        ]

    for num_tokens, num_experts, topk in test_configs:
162
163
        calculate_diff(num_tokens, num_experts, topk)
    benchmark.run(print_data=True)