sglang · Commit e330f2b8 (unverified)
Authored May 01, 2025 by laixin; committed via GitHub Apr 30, 2025.
[qwen3] support qwen3 ep moe (#5917)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>

Parent: 3ddf5b9d
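In short: when the server is launched with expert parallelism enabled (the `--enable-ep-moe` flag, surfaced in code as `global_server_args_dict["enable_ep_moe"]`), the Qwen2/Qwen3 sparse MoE blocks now construct the `EPMoE` layer instead of the Triton `FusedMoE` layer, and weight loading uses the matching expert-parameter mapping. A launch command along these lines would exercise the new path (the model path is illustrative): `python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B --enable-ep-moe`.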
Showing 2 changed files, with 16 additions and 6 deletions:

python/sglang/srt/models/qwen2_moe.py  +8 -3
python/sglang/srt/models/qwen3_moe.py  +8 -3
python/sglang/srt/models/qwen2_moe.py (+8 -3)

```diff
@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
```
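The core of the change in both files is the same two-line dispatch: pick the MoE implementation class once, then construct it with identical keyword arguments. A minimal, self-contained sketch of that pattern follows; the classes and the args dict below are stand-ins for illustration, and only the names `EPMoE`, `FusedMoE`, `MoEImpl`, and `enable_ep_moe` come from the diff. (Note the diff also drops `reduce_results=False` from the constructor call, presumably to keep the argument list compatible with both classes.)

```python
class FusedMoE:
    """Stand-in for sglang.srt.layers.moe.fused_moe_triton.FusedMoE."""

    def __init__(self, num_experts, top_k, hidden_size, intermediate_size, **kwargs):
        self.kind = "tensor-parallel fused experts"


class EPMoE:
    """Stand-in for sglang.srt.layers.moe.ep_moe.layer.EPMoE."""

    def __init__(self, num_experts, top_k, hidden_size, intermediate_size, **kwargs):
        self.kind = "expert-parallel experts"


# In sglang this dict is populated from the launch flags, so passing
# --enable-ep-moe makes global_server_args_dict["enable_ep_moe"] truthy.
global_server_args_dict = {"enable_ep_moe": True}

# The pattern this commit adds: choose the class once, then call it with
# the same keyword arguments either way.
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
experts = MoEImpl(num_experts=128, top_k=8, hidden_size=2048, intermediate_size=768)
print(experts.kind)  # -> "expert-parallel experts"
```

This works precisely because the two layer classes expose a compatible constructor surface, so the call site stays unchanged apart from the class name.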
python/sglang/srt/models/qwen3_moe.py (+8 -3)

```diff
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
```
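The dispatch is repeated inside `load_weights` because the expert-parameter mapping must be produced by whichever class built the layer. As a rough illustration of what a `make_expert_params_mapping`-style helper returns, here is a hypothetical re-implementation assuming the common (param_name, checkpoint_weight_name, expert_id, shard_id) tuple convention; the fused parameter names `w13_weight`/`w2_weight` and the shard ids are modeled on the `FusedMoE` convention and are assumptions, not the exact sglang output:

```python
# Hypothetical sketch of a make_expert_params_mapping-style helper, for
# illustration only. It maps each expert's per-tensor checkpoint weights
# (gate_proj / up_proj / down_proj) onto the model's fused parameters.

def make_expert_params_mapping_sketch(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
):
    mapping = []
    for expert_id in range(num_experts):
        # gate_proj and up_proj both land in the fused "w13" parameter;
        # the shard id tells the loader which half of the tensor to fill.
        mapping.append(
            ("experts.w13_weight", f"experts.{expert_id}.{ckpt_gate_proj_name}.weight", expert_id, "w1")
        )
        mapping.append(
            ("experts.w13_weight", f"experts.{expert_id}.{ckpt_up_proj_name}.weight", expert_id, "w3")
        )
        # down_proj has its own "w2" parameter.
        mapping.append(
            ("experts.w2_weight", f"experts.{expert_id}.{ckpt_down_proj_name}.weight", expert_id, "w2")
        )
    return mapping


for entry in make_expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", num_experts=2):
    print(entry)
```

Routing the call through `MoEImpl` rather than hard-coding `FusedMoE` keeps the checkpoint-name-to-parameter mapping consistent with whichever expert layer was actually constructed.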