sglang / Commits / 61e4433c

Commit 61e4433c (unverified), authored Mar 14, 2025 by Qingquan Song, committed by GitHub on Mar 14, 2025
Add moe topk softmax templated from vllm (#4302)
Parent: 660305c3

Showing 9 changed files with 716 additions and 6 deletions (+716 −6)
Files changed:

  sgl-kernel/benchmark/bench_moe_topk_softmax.py   +120  −0
  sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu  +505  −0
  sgl-kernel/csrc/torch_extension.cc               +5    −0
  sgl-kernel/include/sgl_kernel_ops.h              +6    −0
  sgl-kernel/include/utils.h                       +14   −5
  sgl-kernel/python/sgl_kernel/__init__.py         +1    −1
  sgl-kernel/python/sgl_kernel/moe.py              +11   −0
  sgl-kernel/setup.py                              +1    −0
  sgl-kernel/tests/test_moe_topk_softmax.py        +53   −0
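In short, the commit ports vLLM's fused MoE top-k softmax routing kernel into sgl-kernel: a new CUDA source file, a torch library binding plus C++ declaration, portability macros in utils.h, a Python wrapper exported from the sgl_kernel package, and a vLLM-comparison benchmark plus a unit test against native torch.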
sgl-kernel/benchmark/bench_moe_topk_softmax.py (new file, mode 100644)

import itertools

import pytest
import torch
import triton
from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops


def vllm_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    torch.ops._moe_C.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output
    )
    return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    topk_softmax(
        topk_weights=topk_weights,
        topk_ids=topk_indices,
        token_expert_indices=token_expert_indices,
        gating_output=gating_output,
    )
    return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)

    if (
        torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
        and indices_match
    ):
        print("✅ VLLM and SGLang topk_softmax implementations match")
    else:
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )


num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]

configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "vllm"],
        line_names=["SGLang", "VLLM"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    gating_output = torch.randn(
        (num_tokens, num_experts), device="cuda", dtype=torch.float32
    )

    if provider == "vllm" or provider == "vllm1":
        fn = lambda: vllm_topk_softmax(gating_output, topk)
    elif provider == "sglang" or provider == "sglang1":
        fn = lambda: sglang_topk_softmax(gating_output, topk)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    configs = [
        (20, 256, 4),
        (20, 256, 8),
        (20, 12, 4),
        (20, 12, 1),
        (20, 512, 4),
        (20, 512, 1),
    ]

    for num_tokens, num_experts, topk in configs:
        calculate_diff(num_tokens, num_experts, topk)

    benchmark.run(print_data=True)
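Running the file directly (python bench_moe_topk_softmax.py) first checks parity against vLLM via calculate_diff on a handful of shapes and then prints the Triton do_bench latency table; a CUDA device and a vLLM build that provides the _moe_C extension are assumed.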
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu (new file, mode 100644)

This diff is collapsed.
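The kernel source is not expanded here, but its expected behavior can be inferred from the unit test added by this commit: a per-token softmax over the expert logits followed by a top-k selection. A minimal PyTorch sketch of that reference behavior, covering only topk_weights and topk_indices (the token_expert_indices layout is not checked by the test and is omitted here):

import torch


def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # gating_output: (num_tokens, num_experts) router logits
    probs = torch.softmax(gating_output, dim=-1)  # softmax over the expert axis
    topk_weights, topk_indices = torch.topk(probs, topk, dim=-1)  # per-token top-k
    return topk_weights, topk_indices.to(torch.int32)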
sgl-kernel/csrc/torch_extension.cc

@@ -117,6 +117,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+
   /*
    * From csrc/speculative
    */
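In this schema, Tensor! marks arguments the op writes in place and -> () means nothing is returned, so callers pre-allocate topk_weights, topk_indices, and token_expert_indices; the Python wrapper added in sgl_kernel/moe.py below simply forwards those pre-allocated tensors to torch.ops.sgl_kernel.topk_softmax.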
sgl-kernel/include/sgl_kernel_ops.h

@@ -173,6 +173,12 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer);
 
+void topk_softmax(
+    torch::Tensor& topk_weights,
+    torch::Tensor& topk_indices,
+    torch::Tensor& token_expert_indices,
+    torch::Tensor& gating_output);
+
 /*
  * From csrc/speculative
  */
sgl-kernel/include/utils.h

@@ -65,6 +65,15 @@ inline int getSMVersion() {
   return sm_major * 10 + sm_minor;
 }
 
+// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
+#ifndef USE_ROCM
+#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
+#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
+#else
+#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
+#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
+#endif
+
 #ifndef USE_ROCM
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                              \

@@ -117,11 +126,11 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
 }
 
 __device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
   return max_value;
 }
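The macro makes the warp reduction portable: as the #else branch above shows, the ROCm build falls back to __shfl_xor, which does not take the CUDA-style lane mask. For intuition, a small Python sketch (hypothetical, purely illustrative) of the XOR-butterfly pattern warpReduceMax uses — after strides 16, 8, 4, 2, 1 every lane holds the warp-wide maximum:

lanes = [float((7 * i) % 32) for i in range(32)]  # one value per lane (example data)
for stride in (16, 8, 4, 2, 1):
    # each "lane" takes the max of itself and its XOR partner, as __shfl_xor does in hardware
    lanes = [max(lanes[i], lanes[i ^ stride]) for i in range(32)]
assert all(v == 31.0 for v in lanes)  # every lane now holds the maximum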
sgl-kernel/python/sgl_kernel/__init__.py

@@ -33,7 +33,7 @@ from sgl_kernel.gemm import (
     sgl_per_token_group_quant_fp8,
     sgl_per_token_quant_fp8,
 )
-from sgl_kernel.moe import moe_align_block_size
+from sgl_kernel.moe import moe_align_block_size, topk_softmax
 from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_renorm_prob,
sgl-kernel/python/sgl_kernel/moe.py

@@ -21,3 +21,14 @@ def moe_align_block_size(
         token_cnts_buffer,
         cumsum_buffer,
     )
 
+
+def topk_softmax(
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    token_expert_indices: torch.Tensor,
+    gating_output: torch.Tensor,
+) -> None:
+    torch.ops.sgl_kernel.topk_softmax(
+        topk_weights, topk_ids, token_expert_indices, gating_output
+    )
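A short usage sketch for the new wrapper, following the allocation pattern used in the benchmark and test added by this commit (float32 weights, int32 index tensors, all shaped (num_tokens, topk) on the same CUDA device as gating_output); the shape values here are arbitrary examples:

import torch
from sgl_kernel import topk_softmax

num_tokens, num_experts, topk = 16, 64, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

# The op writes its results into the pre-allocated tensors and returns None.
topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)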
sgl-kernel/setup.py

@@ -157,6 +157,7 @@ sources = [
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
     "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
sgl-kernel/tests/test_moe_topk_softmax.py (new file, mode 100644)

import itertools

import pytest
import torch
from sgl_kernel import topk_softmax


@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax(num_tokens, num_experts, topk):
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device="cuda"
    )

    topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )

    # Native torch implementation
    softmax_output = torch.softmax(gating_output, dim=-1)
    topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)

    # Verify the top-k weights and indices match the torch native ones
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_weights_ref} vs SGLang={topk_weights}"
    assert torch.equal(
        topk_indices_ref, topk_indices
    ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"
    print("✅ Native torch and custom kernel implementations match.")


if __name__ == "__main__":
    pytest.main([__file__])
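The test sweeps all 126 combinations of the parametrized shapes against a plain torch.softmax plus torch.topk reference and requires a CUDA device; it can also be launched directly, since the module calls pytest.main on itself.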