sglang · Commits · 65f09131

Commit 65f09131 (Unverified)
refactor qwen moe code, use communicator to support tp+dp (#6581)

Authored May 26, 2025 by Yi Zhang; committed by GitHub May 25, 2025
Parent: fc419b62

Showing 5 changed files with 78 additions and 379 deletions.
Files changed:

  python/sglang/srt/models/deepseek_v2.py    +1   -8
  python/sglang/srt/models/qwen2_moe.py      +35  -185
  python/sglang/srt/models/qwen3_moe.py      +33  -185
  python/sglang/srt/utils.py                 +8   -0
  test/srt/test_disaggregation.py            +1   -1
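For orientation before the per-file diffs: both Qwen MoE decoder layers drop their private _FFNInputMode / _DecoderLayerInfo dispatch and the hand-written forward_ffn_with_full_input / forward_ffn_with_scattered_input paths, and instead delegate the TP/DP gather-scatter decisions to LayerCommunicator. The condensed sketch below simply restates the new forward path from the qwen2_moe.py and qwen3_moe.py hunks that follow (type annotations dropped for brevity); the comments are editorial inferences from how the communicator is constructed, and the communicator internals in sglang.srt.layers.communicator are not shown here.

    def forward(self, positions, hidden_states, forward_batch, residual):
        # The communicator holds input_layernorm, so normalization and any
        # gather/scatter needed for the attention input happen here.
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
        # Redistribute tokens (and apply post_attention_layernorm) for the MoE MLP.
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        hidden_states = self.mlp(hidden_states, forward_batch)
        # Final gather/scatter so the next layer (or the LM head) sees the
        # layout it expects.
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual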
python/sglang/srt/models/deepseek_v2.py

@@ -95,6 +95,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )
@@ -206,14 +207,6 @@ class MoEGate(nn.Module):
         return logits


-def is_non_idle_and_non_empty(forward_mode, hidden_states):
-    return (
-        (forward_mode is not None)
-        and not forward_mode.is_idle()
-        and hidden_states.shape[0] > 0
-    )
-
-
 class DeepseekV2MoE(nn.Module):

     def __init__(
python/sglang/srt/models/qwen2_moe.py

@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -49,7 +50,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -114,22 +115,22 @@ class Qwen2MoeMLP(nn.Module):
 class Qwen2MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        self.experts = MoEImpl(
+        self.experts = get_moe_impl_class()(
+            layer_id=self.layer_id,
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -159,7 +160,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
@@ -276,19 +279,6 @@ class Qwen2MoeAttention(nn.Module):
         return output


-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class Qwen2MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -298,6 +288,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -322,16 +313,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.attn_tp_rank = get_attention_tp_rank()
         self.local_dp_size = get_local_attention_dp_size()

-        self.info = self._compute_info(config, layer_id=layer_id)
-        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+        # Qwen2MoE all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )

-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = Qwen2MoeSparseMoeBlock(
+                layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -348,27 +343,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )

-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int):
-        # WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
-        )
-        is_sparse = (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
-        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )

     def forward(
         self,
@@ -377,108 +356,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-        elif hidden_states.shape[0] != 0:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
-        # TODO: use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
-
-    def forward_ffn_with_scattered_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )

         # Self Attention
         if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
@@ -486,47 +368,15 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 forward_batch=forward_batch,
             )

-        if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-
-        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
-            and hidden_states.shape[0] == 0
-        ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-
-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual
python/sglang/srt/models/qwen3_moe.py

@@ -38,6 +38,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -78,7 +79,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import DeepEPMode, add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

 Qwen3MoeConfig = None
@@ -150,13 +151,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )

     def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not global_server_args_dict["enable_deepep_moe"]:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states, forward_mode)
+            return self.forward_deepep(hidden_states, forward_batch)

     def get_moe_weights(self):
         return [
@@ -180,13 +181,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         return final_hidden_states.view(num_tokens, hidden_dim)

     def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
-        ):
+        forward_mode = forward_batch.forward_mode
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
@@ -356,19 +354,6 @@ class Qwen3MoeAttention(nn.Module):
         return output


-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -378,6 +363,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -408,15 +394,18 @@ class Qwen3MoeDecoderLayer(nn.Module):
         self.attn_tp_rank = get_attention_tp_rank()
         self.local_dp_size = get_local_attention_dp_size()

-        self.info = self._compute_info(config, layer_id=layer_id)
-        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+        # Qwen3MoE all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )

-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
                 layer_id=self.layer_id,
                 config=config,
@@ -436,26 +425,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int):
-        # WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
-        )
-        is_sparse = (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
-        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )

     def forward(
         self,
@@ -464,105 +438,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-        elif hidden_states.shape[0] != 0:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-
-        # TODO: use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
-
-    def forward_ffn_with_scattered_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )

         # Self Attention
         if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
@@ -570,47 +450,15 @@ class Qwen3MoeDecoderLayer(nn.Module):
                 forward_batch=forward_batch,
             )

-        if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-
-        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
-            and hidden_states.shape[0] == 0
-        ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-
-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual
python/sglang/srt/utils.py

@@ -2026,6 +2026,14 @@ class DeepEPMode(Enum):
         return DeepEPMode.normal


+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 def fast_topk(values, topk, dim):
     if topk == 1:
         # Use max along the specified dimension to get both value and index
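The helper pulled into utils.py is dependency-free, so its behaviour is easy to check in isolation. A minimal sketch of its semantics, using a stub in place of sglang's real ForwardMode (which lives in sglang.srt.model_executor.forward_batch_info and is deliberately not imported here):

import torch


class _StubForwardMode:
    # Hypothetical stand-in for sglang's ForwardMode, used only for this illustration.
    def __init__(self, idle: bool):
        self._idle = idle

    def is_idle(self) -> bool:
        return self._idle


def is_non_idle_and_non_empty(forward_mode, hidden_states):
    # Same body as the function added to python/sglang/srt/utils.py above:
    # run the MoE path only for a real, non-idle batch with at least one token.
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )


hidden = torch.zeros(4, 8)
empty = torch.zeros(0, 8)
assert not is_non_idle_and_non_empty(None, hidden)                         # no mode at all
assert not is_non_idle_and_non_empty(_StubForwardMode(idle=True), hidden)  # idle batch
assert not is_non_idle_and_non_empty(_StubForwardMode(idle=False), empty)  # zero tokens
assert is_non_idle_and_non_empty(_StubForwardMode(idle=False), hidden)     # real work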
test/srt/test_disaggregation.py

@@ -146,7 +146,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.62)

     def test_logprob(self):
-        prompt = "The capital of taiwan is "
+        prompt = "The capital of france is "
         response = requests.post(
             self.lb_url + "/generate",
             json={