zhaoyu6 / sglang · Commit 61e4433c (unverified)
Authored Mar 14, 2025 by Qingquan Song; committed by GitHub on Mar 14, 2025
Add moe topk softmax templated from vllm (#4302)
Parent: 660305c3
Showing 9 changed files with 716 additions and 6 deletions (+716, −6):
sgl-kernel/benchmark/bench_moe_topk_softmax.py    +120  −0
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu   +505  −0
sgl-kernel/csrc/torch_extension.cc                  +5  −0
sgl-kernel/include/sgl_kernel_ops.h                 +6  −0
sgl-kernel/include/utils.h                         +14  −5
sgl-kernel/python/sgl_kernel/__init__.py            +1  −1
sgl-kernel/python/sgl_kernel/moe.py                +11  −0
sgl-kernel/setup.py                                 +1  −0
sgl-kernel/tests/test_moe_topk_softmax.py          +53  −0
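Before the per-file diffs, a minimal usage sketch of the new op as exposed on the Python side. This is only an illustration assembled from the wrapper signature in sgl-kernel/python/sgl_kernel/moe.py and the new test below; the example shapes, the CUDA device, and the tolerance are assumptions, not part of the commit.

import torch
from sgl_kernel import topk_softmax  # new export added by this commit

num_tokens, num_experts, topk = 16, 64, 2
gating_output = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda")

# The kernel writes into pre-allocated output buffers.
topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
topk_ids = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")

topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)

# Cross-check against a plain torch reference, as the new test does.
ref_weights, ref_ids = torch.topk(torch.softmax(gating_output, dim=-1), topk, dim=-1)
assert torch.allclose(topk_weights, ref_weights, atol=1e-3, rtol=1e-3)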
sgl-kernel/benchmark/bench_moe_topk_softmax.py (new file, mode 100644)
import itertools

import pytest
import torch
import triton
from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops


def vllm_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    torch.ops._moe_C.topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output
    )
    return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
    num_tokens, num_experts = gating_output.shape
    topk_weights = torch.empty(
        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
    )
    topk_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
    )
    topk_softmax(
        topk_weights=topk_weights,
        topk_ids=topk_indices,
        token_expert_indices=token_expert_indices,
        gating_output=gating_output,
    )
    return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
    gating_output = torch.randn((num_tokens, num_experts), device="cuda", dtype=torch.float32)

    weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
    weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)

    weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
    indices_match = torch.equal(indices_vllm, indices_sglang)

    if (
        torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
        and indices_match
    ):
        print("✅ VLLM and SGLang topk_softmax implementations match")
    else:
        print(
            f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
        )


num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang", "vllm"],
        line_names=["SGLang", "VLLM"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="topk-softmax-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    gating_output = torch.randn((num_tokens, num_experts), device="cuda", dtype=torch.float32)

    if provider == "vllm" or provider == "vllm1":
        fn = lambda: vllm_topk_softmax(gating_output, topk)
    elif provider == "sglang" or provider == "sglang1":
        fn = lambda: sglang_topk_softmax(gating_output, topk)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    configs = [
        (20, 256, 4),
        (20, 256, 8),
        (20, 12, 4),
        (20, 12, 1),
        (20, 512, 4),
        (20, 512, 1),
    ]

    for num_tokens, num_experts, topk in configs:
        calculate_diff(num_tokens, num_experts, topk)

    benchmark.run(print_data=True)
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu (new file, mode 100644)
This diff is collapsed (+505 lines).
sgl-kernel/csrc/torch_extension.cc
@@ -117,6 +117,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+
   /*
    * From csrc/speculative
    */
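Registering the schema and the CUDA implementation above makes the kernel reachable through the torch.ops namespace; the thin wrapper added in sgl-kernel/python/sgl_kernel/moe.py (further down) simply forwards to it. A rough sketch of that raw call path, assuming the extension is built and that importing sgl_kernel loads it:

import torch
import sgl_kernel  # assumed to load the compiled extension and register the sgl_kernel ops

gating_output = torch.randn((8, 32), dtype=torch.float32, device="cuda")
topk = 2
topk_weights = torch.empty((8, topk), dtype=torch.float32, device="cuda")
topk_indices = torch.empty((8, topk), dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty((8, topk), dtype=torch.int32, device="cuda")

# Same op the Python wrapper dispatches to; argument order matches the registered schema.
torch.ops.sgl_kernel.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)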
sgl-kernel/include/sgl_kernel_ops.h
@@ -173,6 +173,12 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer);
 
+void topk_softmax(
+    torch::Tensor& topk_weights,
+    torch::Tensor& topk_indices,
+    torch::Tensor& token_expert_indices,
+    torch::Tensor& gating_output);
+
 /*
  * From csrc/speculative
  */
sgl-kernel/include/utils.h
@@ -65,6 +65,15 @@ inline int getSMVersion() {
   return sm_major * 10 + sm_minor;
 }
 
+// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
+#ifndef USE_ROCM
+#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
+#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))
+#else
+#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
+#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
+#endif
+
 #ifndef USE_ROCM
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                              \
@@ -117,11 +126,11 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
 }
 
 __device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));
   return max_value;
 }
sgl-kernel/python/sgl_kernel/__init__.py
@@ -33,7 +33,7 @@ from sgl_kernel.gemm import (
     sgl_per_token_group_quant_fp8,
     sgl_per_token_quant_fp8,
 )
-from sgl_kernel.moe import moe_align_block_size
+from sgl_kernel.moe import moe_align_block_size, topk_softmax
 from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_renorm_prob,
sgl-kernel/python/sgl_kernel/moe.py
@@ -21,3 +21,14 @@ def moe_align_block_size(
         token_cnts_buffer,
         cumsum_buffer,
     )
+
+
+def topk_softmax(
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    token_expert_indices: torch.Tensor,
+    gating_output: torch.Tensor,
+) -> None:
+    torch.ops.sgl_kernel.topk_softmax(
+        topk_weights, topk_ids, token_expert_indices, gating_output
+    )
sgl-kernel/setup.py
@@ -157,6 +157,7 @@ sources = [
     "csrc/gemm/per_token_quant_fp8.cu",
     "csrc/gemm/per_tensor_quant_fp8.cu",
     "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/speculative/speculative_sampling.cu",
     "csrc/torch_extension.cc",
sgl-kernel/tests/test_moe_topk_softmax.py (new file, mode 100644)

import itertools

import pytest
import torch
from sgl_kernel import topk_softmax


@pytest.mark.parametrize(
    "num_tokens, num_experts, topk",
    list(
        itertools.product(
            [1, 16, 128, 512, 1024, 2048],  # num_tokens
            [4, 8, 16, 32, 64, 128, 256],  # num_experts
            [1, 2, 4],  # topk
        )
    ),
)
def test_topk_softmax(num_tokens, num_experts, topk):
    gating_output = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda")
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
    token_expert_indices = torch.empty(
        (num_tokens, topk), dtype=torch.int32, device="cuda"
    )

    topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )

    # Native torch implementation
    softmax_output = torch.softmax(gating_output, dim=-1)
    topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)

    # Verify the top-k weights and indices match the torch native ones
    assert torch.allclose(
        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
    ), f"Weights mismatch: torch={topk_weights_ref} vs SGLang={topk_weights}"
    assert torch.equal(
        topk_indices_ref, topk_indices
    ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"

    print("✅ Native torch and custom kernel implementations match.")


if __name__ == "__main__":
    pytest.main([__file__])