Unverified commit 9e2f7252, authored Sep 11, 2025 by Yi Zhang, committed via GitHub Sep 10, 2025
add dual stream for qwen2_moe (#10252)
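Summary of the change, as read from the diff below: Qwen2MoeSparseMoeBlock.forward is split into a shared-expert helper and a routed-expert helper, and for CUDA batches of at most 1024 tokens the two paths are launched on separate CUDA streams (forward_normal_dual_stream) so their kernels can overlap; a single alternate stream is created at the model level and threaded down to each MoE block.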
parent 21176b00

Showing 1 changed file with 53 additions and 11 deletions:
python/sglang/srt/models/qwen2_moe.py (+53, -11)
python/sglang/srt/models/qwen2_moe.py @ 9e2f7252
@@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers

 logger = logging.getLogger(__name__)

+_is_cuda = is_cuda()
+

 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -122,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -168,14 +172,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -183,11 +180,51 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 shared_output = (
                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                 )
+        return shared_output
+
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+        current_stream.wait_stream(self.alt_stream)
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
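The overlap pattern above is compact enough to show standalone. Below is a minimal, self-contained sketch of the same two-stream technique; it is not sglang code. shared_expert and routed_experts are placeholder nn.Modules standing in for the block's shared-expert and routed-expert paths, and the sizes are arbitrary. The two wait_stream calls are the essential part: the first orders the alternate stream after the work already queued on the default stream, the second makes the default stream wait for the alternate stream's result before consuming it.

import torch
import torch.nn as nn

# Placeholder branches: any two independent GPU workloads on the same input.
shared_expert = nn.Sequential(nn.Linear(1024, 4096), nn.SiLU(), nn.Linear(4096, 1024))
routed_experts = nn.Sequential(nn.Linear(1024, 4096), nn.SiLU(), nn.Linear(4096, 1024))

def dual_stream_forward(x: torch.Tensor, alt_stream: torch.cuda.Stream) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    # Fork: the alternate stream must not read `x` before the kernels that
    # produced it (queued on the default stream) have completed.
    alt_stream.wait_stream(current_stream)

    # Branch 1 runs on the default stream...
    shared_out = shared_expert(x)

    # ...while branch 2 is enqueued on the alternate stream, so the two sets
    # of kernels can execute concurrently when neither saturates the GPU.
    with torch.cuda.stream(alt_stream):
        routed_out = routed_experts(x)

    # Join: the default stream waits for the alternate stream before using
    # `routed_out`.
    current_stream.wait_stream(alt_stream)
    return routed_out + shared_out

if torch.cuda.is_available():
    shared_expert.cuda()
    routed_experts.cuda()
    x = torch.randn(256, 1024, device="cuda")  # small batch, as in the <= 1024-token case
    alt_stream = torch.cuda.Stream()
    y = dual_stream_forward(x, alt_stream)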
@@ -346,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -528,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
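Design note on the hunks above: the alternate stream is allocated once in Qwen2MoeForCausalLM (and only when running on CUDA, via the new _is_cuda flag) and passed down through Qwen2MoeModel and the decoder layers to each sparse MoE block, so all layers share a single extra stream instead of each creating its own. The DUAL_STREAM_TOKEN_THRESHOLD of 1024 presumably restricts the overlap to small batches, where neither the shared-expert nor the routed-expert kernels are likely to saturate the GPU on their own; larger batches take the original sequential path.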