Unverified Commit 61e4433c authored by Qingquan Song, committed by GitHub

Add moe topk softmax templated from vllm (#4302)

parent 660305c3
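For context, a minimal usage sketch of the new op, mirroring the buffer shapes and dtypes used in the test further below (the concrete sizes here are illustrative, not from the commit):

import torch
from sgl_kernel import topk_softmax

# Illustrative sizes; any (num_tokens, num_experts) float32 gating matrix works.
num_tokens, num_experts, topk = 16, 64, 2
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# The op writes into caller-allocated output buffers.
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)
# topk_weights now holds the softmax(gating_output) values of the top-k experts,
# and topk_ids holds the corresponding expert indices.
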
import itertools

import torch
import triton
from sgl_kernel import topk_softmax

# Importing vllm's custom ops registers the torch.ops._moe_C namespace used below.
from vllm import _custom_ops as vllm_custom_ops


def vllm_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    torch.ops._moe_C.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output
    )
    return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    topk_softmax(
        topk_weights=topk_weights,
        topk_ids=topk_indices,
        token_expert_indices=token_expert_indices,
        gating_output=gating_output,
    )
    return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)
    if (
        torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
        and indices_match
    ):
        print("✅ VLLM and SGLang topk_softmax implementations match")
    else:
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )

num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "vllm"],
        line_names=["SGLang", "VLLM"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )
    if provider == "vllm" or provider == "vllm1":
        fn = lambda: vllm_topk_softmax(gating_output, topk)
    elif provider == "sglang" or provider == "sglang1":
        fn = lambda: sglang_topk_softmax(gating_output, topk)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    configs = [
        (20, 256, 4),
        (20, 256, 8),
        (20, 12, 4),
        (20, 12, 1),
        (20, 512, 4),
        (20, 512, 1),
    ]
    for num_tokens, num_experts, topk in configs:
        calculate_diff(num_tokens, num_experts, topk)

    benchmark.run(print_data=True)
......@@ -117,6 +117,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
* From csrc/speculative
*/
......
......@@ -173,6 +173,12 @@ void moe_align_block_size(
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer);
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
/*
* From csrc/speculative
*/
......
......@@ -65,6 +65,15 @@ inline int getSMVersion() {
return sm_major * 10 + sm_minor;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
......@@ -117,11 +126,11 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
}
__device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
  return max_value;
}
......
......@@ -33,7 +33,7 @@ from sgl_kernel.gemm import (
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
-from sgl_kernel.moe import moe_align_block_size
+from sgl_kernel.moe import moe_align_block_size, topk_softmax
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
......
......@@ -21,3 +21,14 @@ def moe_align_block_size(
token_cnts_buffer,
cumsum_buffer,
)


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    torch.ops.sgl_kernel.topk_softmax(
        topk_weights, topk_ids, token_expert_indices, gating_output
    )
......@@ -157,6 +157,7 @@ sources = [
"csrc/gemm/per_token_quant_fp8.cu",
"csrc/gemm/per_tensor_quant_fp8.cu",
"csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/speculative/eagle_utils.cu",
"csrc/speculative/speculative_sampling.cu",
"csrc/torch_extension.cc",
......
import itertools
import pytest
import torch
from sgl_kernel import topk_softmax


@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax(num_tokens, num_experts, topk):
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device="cuda"
    )

    topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )

    # Native torch implementation
    softmax_output = torch.softmax(gating_output, dim=-1)
    topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)

    # Verify the top-k weights and indices match the torch native ones
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_weights_ref} vs SGLang={topk_weights}"
    assert torch.equal(
        topk_indices_ref, topk_indices
    ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
    print("✅ Native torch and custom kernel implementations match.")


if __name__ == "__main__":
    pytest.main([__file__])