change / sglang · Commits · fc0e3b91
"tutorials/vscode:/vscode.git/clone" did not exist on "7fe6d0c85732d57a95cd2260fce1a2e1fd93489c"
Commit fc0e3b91 (Unverified)
Support qwen3 deepep (#6120)
Authored May 23, 2025 by lukec; committed by GitHub on May 22, 2025.
Parent: d71f3f0a
Showing 2 changed files with 125 additions and 8 deletions (+125 / -8).
python/sglang/srt/models/qwen2_moe.py  (+4 / -1)
python/sglang/srt/models/qwen3_moe.py  (+121 / -7)
python/sglang/srt/models/qwen2_moe.py
```diff
@@ -607,7 +607,10 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             if hidden_states.shape[0] != 0:
-                hidden_states, _ = self.norm(hidden_states, residual)
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
```
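The new branch guards the case where `residual` is still `None` when the final norm runs, which the DeepEP execution path can now produce. sglang's RMSNorm returns a bare tensor when called without a residual and an `(output, residual)` tuple when called with one, so the two call shapes must be kept apart. A minimal, self-contained sketch of that calling convention (`ToyRMSNorm` is a stand-in, not sglang's implementation):

```python
import torch
import torch.nn as nn

# Stand-in mirroring the RMSNorm calling convention the patch relies on:
# no residual -> returns a tensor; residual given -> returns (output, residual).
class ToyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is not None:
            x = x + residual  # fused add + norm path
            residual = x
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return out if residual is None else (out, residual)


norm = ToyRMSNorm(8)
h = torch.randn(4, 8)

hidden_states = norm(h)                        # residual is None -> plain tensor
hidden_states, _ = norm(h, torch.randn(4, 8))  # residual present -> tuple, as in the old line
```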
python/sglang/srt/models/qwen3_moe.py
```diff
@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
```
```diff
@@ -54,8 +55,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
```
```diff
@@ -65,11 +68,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix

 Qwen3MoeConfig = None
```
```diff
@@ -92,7 +99,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.num_experts,
```
```diff
@@ -102,6 +113,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

         self.gate = ReplicatedLinear(
```
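The `**(... if ... else {})` idiom only forwards `deepep_mode` when DeepEP is enabled, so `EPMoE` and `FusedMoE` never receive a keyword they do not accept. A minimal illustration of the pattern (the numeric values and the `"auto"` string are placeholders, not values taken from the diff):

```python
# Minimal illustration of the conditional-kwargs pattern: the extra keyword is
# only present when the feature flag is on. Values here are placeholders.
def make_experts(**kwargs):
    return kwargs

enable_deepep_moe = True
kwargs = make_experts(
    num_experts=128,
    top_k=8,
    **({"deepep_mode": "auto"} if enable_deepep_moe else {}),
)
print(kwargs)  # {'num_experts': 128, 'top_k': 8, 'deepep_mode': 'auto'}
```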
```diff
@@ -112,7 +128,37 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = config.num_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
```
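`num_local_experts=config.num_experts // self.tp_size`, together with the `# TODO: we will support tp < ep in the future` comment, implies the expert-parallel group size is taken to be the tensor-parallel size here, so each rank owns an equal slice of the experts. A sketch of that arithmetic (the helper is hypothetical, and 128 experts over 8 ranks is just an example):

```python
# Hypothetical helper illustrating the sharding implied by
# num_local_experts = num_experts // tp_size (ep_size assumed equal to tp_size).
def local_expert_range(num_experts: int, tp_size: int, rank: int) -> range:
    assert num_experts % tp_size == 0, "experts must split evenly across ranks"
    per_rank = num_experts // tp_size
    return range(rank * per_rank, (rank + 1) * per_rank)

# Example only: 128 experts across 8 ranks -> 16 experts per rank.
print(list(local_expert_range(128, 8, 0)))   # rank 0 owns experts 0..15
print(list(local_expert_range(128, 8, 7)))   # rank 7 owns experts 112..127
```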
```diff
@@ -126,6 +172,68 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         return final_hidden_states.view(num_tokens, hidden_dim)

+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+

 class Qwen3MoeAttention(nn.Module):
     def __init__(
```
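`forward_deepep` follows a route → dispatch → expert MLPs → combine pipeline; the empty-tensor branch lets idle ranks still take part in the collective dispatch/combine calls. The toy, single-process sketch below mirrors only the control flow (top-k routing, per-expert compute, weighted combine); there is no real all-to-all or DeepEP here, and all names are stand-ins:

```python
import torch

# Toy, single-process sketch of the route -> dispatch -> expert -> combine flow.
def toy_moe(hidden_states: torch.Tensor,
            gate_w: torch.Tensor,
            expert_ws: list[torch.Tensor],
            top_k: int = 2) -> torch.Tensor:
    # Route: top-k experts per token with renormalized weights, as in select_experts.
    logits = hidden_states @ gate_w                                   # (tokens, experts)
    weights, idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    # "Dispatch" tokens to each expert, run its projection, then "combine" the results.
    for e, w in enumerate(expert_ws):
        rows, slots = (idx == e).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue
        out[rows] += weights[rows, slots, None] * (hidden_states[rows] @ w)
    return out

tokens = torch.randn(5, 16)
gate = torch.randn(16, 4)
experts = [torch.randn(16, 16) for _ in range(4)]
print(toy_moe(tokens, gate, experts).shape)   # torch.Size([5, 16])
```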
```diff
@@ -403,7 +511,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             )

         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         # TODO: use reduce-scatter in MLP to avoid this scatter
         # Scatter
```
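Because the sparse block's `forward` now takes `forward_mode: Optional[ForwardMode] = None`, the decoder layer can always thread `forward_batch.forward_mode` through while older single-argument call sites keep working. A tiny stub (not sglang's classes) of that signature compatibility:

```python
from typing import Optional

# Stub classes (not sglang's) illustrating the widened forward signature.
class ForwardModeStub:
    def is_idle(self) -> bool:
        return False

class SparseMoeBlockStub:
    def forward(self, hidden_states, forward_mode: Optional[ForwardModeStub] = None):
        # The real block branches on enable_deepep_moe; this stub only shows
        # that both the old and the new call shapes are accepted.
        return hidden_states, forward_mode

block = SparseMoeBlockStub()
block.forward("h")                      # legacy one-argument call still valid
block.forward("h", ForwardModeStub())   # decoder layer's new two-argument call
```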
```diff
@@ -577,7 +685,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )

         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
```
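The three-way `MoEImpl` choice is repeated in `load_weights` so that checkpoint expert weights are mapped onto whichever fused parameter layout the chosen implementation uses. A rough, hypothetical sketch of the kind of `(param_name, weight_name, expert_id, shard_id)` entries this mapping contains; the real names come from `MoEImpl.make_expert_params_mapping`, not from this sketch:

```python
# Hypothetical approximation of the mapping shape; real entries come from
# MoEImpl.make_expert_params_mapping and depend on the chosen implementation.
def sketch_expert_params_mapping(num_experts: int):
    entries = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in (("gate_proj", "w1"), ("up_proj", "w3"), ("down_proj", "w2")):
            fused_param = "experts.w13_weight" if shard_id in ("w1", "w3") else "experts.w2_weight"
            entries.append((fused_param, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id))
    return entries

for entry in sketch_expert_params_mapping(2)[:3]:
    print(entry)
```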