Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
8374a96e
"server/text_generation_server/adapters/lora.py" did not exist on "a2a97b05d6c59bb95e54124a68c69b17851d3480"
Unverified
Commit
8374a96e
authored
Oct 21, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Oct 21, 2025
Browse files
piecewise cuda graph support qwen3-moe (#11845)
parent
74de76c6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
6 deletions
+71
-6
python/sglang/srt/layers/communicator.py
python/sglang/srt/layers/communicator.py
+6
-5
python/sglang/srt/layers/quantization/fp8_kernel.py
python/sglang/srt/layers/quantization/fp8_kernel.py
+18
-0
python/sglang/srt/models/qwen2_moe.py
python/sglang/srt/models/qwen2_moe.py
+7
-1
test/srt/test_piecewise_cuda_graph.py
test/srt/test_piecewise_cuda_graph.py
+40
-0
No files found.
python/sglang/srt/layers/communicator.py
View file @
8374a96e
...
...
@@ -212,6 +212,10 @@ class LayerCommunicator:
)
)
self
.
_speculative_algo
=
SpeculativeAlgorithm
.
from_string
(
get_global_server_args
().
speculative_algorithm
)
def
prepare_attn
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -315,13 +319,10 @@ class LayerCommunicator:
def
should_fuse_mlp_allreduce_with_next_layer
(
self
,
forward_batch
:
ForwardBatch
)
->
bool
:
speculative_algo
=
SpeculativeAlgorithm
.
from_string
(
get_global_server_args
().
speculative_algorithm
)
if
(
is_dp_attention_enabled
()
and
speculative_algo
is
not
None
and
speculative_algo
.
is_eagle
()
and
self
.
_
speculative_algo
is
not
None
and
self
.
_
speculative_algo
.
is_eagle
()
):
return
False
...
...
python/sglang/srt/layers/quantization/fp8_kernel.py
View file @
8374a96e
...
...
@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
)
return
result
.
to
(
out_dtype
)
if _is_cuda:
    # Register "fake" (meta) implementations for the sgl_kernel custom ops so
    # that torch.compile / piecewise CUDA graph tracing can handle them without
    # executing the real CUDA kernel. The fakes return nothing — presumably the
    # real ops write their results into output_q / output_s in place, so there
    # is no tensor to fabricate (NOTE(review): confirm against sgl_kernel's op
    # schema).
    if enable_sgl_per_token_group_quant_8bit:
        # 8-bit variant of the per-token-group quantization kernel.
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
        def _(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0):
            return
    else:
        # FP8-only variant used when the generic 8-bit kernel is disabled.
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
        def _(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0):
            return
python/sglang/srt/models/qwen2_moe.py
View file @
8374a96e
...
...
@@ -17,6 +17,7 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import
logging
from
contextlib
import
nullcontext
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
if
residual
is
not
None
else
hidden_states
)
with
get_global_expert_distribution_recorder
().
with_current_layer
(
i
):
ctx
=
(
nullcontext
()
if
get_global_server_args
().
enable_piecewise_cuda_graph
else
get_global_expert_distribution_recorder
().
with_current_layer
(
i
)
)
with
ctx
:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
forward_batch
,
residual
...
...
test/srt/test_piecewise_cuda_graph.py
View file @
8374a96e
...
...
@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
self
.
assertLess
(
prefill_latency
,
0.015
)
class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
    """Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model."""

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for the whole test class with piecewise CUDA
        # graph enabled; the "eager" compiler backend skips heavier compilation.
        cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-piecewise-cuda-graph",
                "--piecewise-cuda-graph-compiler",
                "eager",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server and any children it spawned.
        kill_process_tree(cls.process.pid)

    def test_gsm8k_accuracy(self):
        """Test GSM8K accuracy with 8-shot setting"""
        # NOTE(review): eval_name is "mgsm_en" while the method/print say
        # GSM8K — mgsm_en appears to be the English GSM8K split; confirm
        # against run_eval's supported eval names.
        num_examples = 2000
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=num_examples,
            # Cap concurrency at 1024 threads regardless of example count.
            num_threads=min(num_examples, 1024),
        )
        metrics = run_eval(args)
        print(f"GSM8K Accuracy: {metrics['score']:.3f}")
        # Accuracy must not regress below 0.90 with piecewise CUDA graph on.
        self.assertGreaterEqual(metrics["score"], 0.90)
# Allow running this test file directly via `python test_piecewise_cuda_graph.py`.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment