sglang · commit 8374a96e (unverified)

piecewise cuda graph support qwen3-moe (#11845)

Authored Oct 21, 2025 by Xiaoyu Zhang; committed via GitHub on Oct 21, 2025.
Parent: 74de76c6

Showing 4 changed files with 71 additions and 6 deletions:

- python/sglang/srt/layers/communicator.py (+6, -5)
- python/sglang/srt/layers/quantization/fp8_kernel.py (+18, -0)
- python/sglang/srt/models/qwen2_moe.py (+7, -1)
- test/srt/test_piecewise_cuda_graph.py (+40, -0)
python/sglang/srt/layers/communicator.py

```diff
@@ -212,6 +212,10 @@ class LayerCommunicator:
             )
         )

+        self._speculative_algo = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -315,13 +319,10 @@ class LayerCommunicator:
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
     ) -> bool:
-        speculative_algo = SpeculativeAlgorithm.from_string(
-            get_global_server_args().speculative_algorithm
-        )
         if (
             is_dp_attention_enabled()
-            and speculative_algo is not None
-            and speculative_algo.is_eagle()
+            and self._speculative_algo is not None
+            and self._speculative_algo.is_eagle()
         ):
             return False
```
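This change hoists the `SpeculativeAlgorithm.from_string(...)` parse out of `should_fuse_mlp_allreduce_with_next_layer` and into the constructor, so the per-forward check reads a cached attribute instead of re-parsing a server-arg string on every call. That keeps Python-side string handling out of the hot path that piecewise CUDA graph capture traces. A minimal sketch of the pattern, with illustrative names that are not the sglang API:

```python
from enum import Enum


class Algo(Enum):
    NONE = 0
    EAGLE = 1

    @staticmethod
    def from_string(name):
        # Stand-in for SpeculativeAlgorithm.from_string: parse once, reuse.
        return Algo.EAGLE if name == "eagle" else Algo.NONE


class Layer:
    def __init__(self, algo_name):
        # Parse the config string once, at construction time ...
        self._algo = Algo.from_string(algo_name)

    def forward_step(self):
        # ... so the hot path (potentially traced for CUDA graph capture)
        # only does a cheap attribute comparison.
        return self._algo is Algo.EAGLE


layer = Layer("eagle")
print(layer.forward_step())  # True
```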
python/sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
     )

     return result.to(out_dtype)
+
+
+if _is_cuda:
+    if enable_sgl_per_token_group_quant_8bit:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
+
+    else:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
```
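The new block registers fake (meta) implementations for the `sgl_kernel` per-token-group quantization ops via `torch.library.register_fake`, so the `torch.compile`/Dynamo tracing used by piecewise CUDA graphs can reason about the ops without running the real CUDA kernels. Since these ops write into the preallocated `output_q`/`output_s` buffers and return nothing, the fake implementations can simply `return`. A standalone sketch of the same pattern (PyTorch >= 2.4), using a hypothetical `demo::scale` op that is not part of `sgl_kernel`:

```python
import torch


# A hypothetical custom op, used only to illustrate the pattern.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


# The fake implementation only describes output metadata (shape/dtype),
# letting torch.compile trace the op without executing the real kernel.
@torch.library.register_fake("demo::scale")
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)


compiled = torch.compile(lambda x: scale(x, 2.0))
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.])
```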
python/sglang/srt/models/qwen2_moe.py

```diff
@@ -17,6 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
                 if residual is not None
                 else hidden_states
             )
-            with get_global_expert_distribution_recorder().with_current_layer(i):
+            ctx = (
+                nullcontext()
+                if get_global_server_args().enable_piecewise_cuda_graph
+                else get_global_expert_distribution_recorder().with_current_layer(i)
+            )
+            with ctx:
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
                     hidden_states,
                     forward_batch,
                     residual
```
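When piecewise CUDA graph is enabled, the per-layer expert-distribution recorder context is swapped for `contextlib.nullcontext()`, so no recorder bookkeeping runs inside the region that may be compiled and captured; the default path is unchanged. A small self-contained sketch of this conditional context-manager pattern, with a hypothetical `record_layer` standing in for the recorder:

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def record_layer(i):
    # Stand-in for with_current_layer(i): Python-side bookkeeping
    # that must not run inside a captured region.
    print(f"enter layer {i}")
    yield
    print(f"exit layer {i}")


def run_layer(i, piecewise_cuda_graph: bool):
    # Skip the recorder entirely when the region may be graph-captured.
    ctx = nullcontext() if piecewise_cuda_graph else record_layer(i)
    with ctx:
        pass  # the layer's forward pass would run here


run_layer(0, piecewise_cuda_graph=False)  # recorder fires
run_layer(0, piecewise_cuda_graph=True)   # silent
```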
test/srt/test_piecewise_cuda_graph.py

```diff
@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):

         self.assertLess(prefill_latency, 0.015)

+
+class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
+    """Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model"""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-piecewise-cuda-graph",
+                "--piecewise-cuda-graph-compiler",
+                "eager",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k_accuracy(self):
+        """Test GSM8K accuracy with 8-shot setting"""
+        num_examples = 2000
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=num_examples,
+            num_threads=min(num_examples, 1024),
+        )
+
+        metrics = run_eval(args)
+        print(f"GSM8K Accuracy: {metrics['score']:.3f}")
+        self.assertGreaterEqual(metrics["score"], 0.90)
+
+
 if __name__ == "__main__":
     unittest.main()
```
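The new test launches the server with the same arguments a user would pass on the command line; assuming the standard entry point, that corresponds roughly to `python -m sglang.launch_server --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager`. It then runs the `mgsm_en` (GSM8K-style) eval over 2000 examples and requires a score of at least 0.90, which guards against piecewise CUDA graph capture corrupting MoE model outputs.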