Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47e66c24
Commit
47e66c24
authored
Oct 09, 2025
by
bnellnm
Committed by
GitHub
Oct 09, 2025
Browse files
[Model] Apply shared experts overlap optimization to all models with shared experts (#26145)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
3b736e1c
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
271 additions
and
283 deletions
+271
-283
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-0
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+23
-12
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/layers/shared_fused_moe/__init__.py
vllm/model_executor/layers/shared_fused_moe/__init__.py
+0
-5
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+15
-13
vllm/model_executor/models/bailing_moe.py
vllm/model_executor/models/bailing_moe.py
+26
-21
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+24
-45
vllm/model_executor/models/dots1.py
vllm/model_executor/models/dots1.py
+22
-18
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+20
-20
vllm/model_executor/models/ernie45_vl_moe.py
vllm/model_executor/models/ernie45_vl_moe.py
+34
-23
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+25
-43
vllm/model_executor/models/hunyuan_v1.py
vllm/model_executor/models/hunyuan_v1.py
+20
-21
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+4
-5
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+30
-27
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+24
-30
No files found.
vllm/model_executor/layers/fused_moe/__init__.py
View file @
47e66c24
...
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
)
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.utils
import
activation_without_mul
from
vllm.triton_utils
import
HAS_TRITON
...
...
@@ -42,6 +43,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute"
,
"FusedMoEActivationFormat"
,
"FusedMoEPrepareAndFinalize"
,
"SharedFusedMoE"
,
"activation_without_mul"
,
"override_config"
,
"get_config"
,
...
...
vllm/model_executor/layers/
shared_
fused_moe/shared_fused_moe.py
→
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
47e66c24
...
...
@@ -18,13 +18,21 @@ class SharedFusedMoE(FusedMoE):
def
__init__
(
self
,
shared_experts
:
torch
.
nn
.
Module
,
shared_experts
:
Optional
[
torch
.
nn
.
Module
]
,
use_overlapped
:
bool
=
True
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
_shared_experts
=
shared_experts
self
.
use_overlapped
=
use_overlapped
# Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self
.
use_overlapped
=
(
use_overlapped
and
not
(
self
.
use_ep
or
self
.
use_flashinfer_cutlass_kernels
)
and
self
.
_shared_experts
is
not
None
)
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
...
...
@@ -36,16 +44,19 @@ class SharedFusedMoE(FusedMoE):
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
self
.
use_overlapped
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if
(
self
.
reduce_results
and
self
.
tp_size
>
1
and
self
.
must_reduce_shared_expert_outputs
()
):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
if
self
.
_shared_experts
is
not
None
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if
(
self
.
reduce_results
and
self
.
tp_size
>
1
and
self
.
must_reduce_shared_expert_outputs
()
):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
else
:
shared_out
=
None
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
47e66c24
...
...
@@ -741,6 +741,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
self
.
rocm_aiter_moe_enabled
=
False
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Lazy import to avoid importing triton too early.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
...
...
vllm/model_executor/layers/shared_fused_moe/__init__.py
deleted
100644 → 0
View file @
3b736e1c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.model_executor.layers.shared_fused_moe.shared_fused_moe
import
SharedFusedMoE
__all__
=
[
"SharedFusedMoE"
]
vllm/model_executor/models/aria.py
View file @
47e66c24
...
...
@@ -13,7 +13,7 @@ from vllm.config import VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -206,7 +206,7 @@ class AriaProjector(nn.Module):
return
out
class
AriaFusedMoE
(
FusedMoE
):
class
AriaFusedMoE
(
Shared
FusedMoE
):
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
str
)
->
None
:
...
...
@@ -260,7 +260,16 @@ class AriaTextMoELayer(nn.Module):
torch
.
empty
((
self
.
config
.
moe_num_experts
,
self
.
config
.
hidden_size
))
)
self
.
shared_experts
=
LlamaMLP
(
config
.
hidden_size
,
config
.
intermediate_size
*
config
.
moe_num_shared_experts
,
"silu"
,
quant_config
=
quant_config
,
bias
=
config
.
mlp_bias
,
)
self
.
experts
=
AriaFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
moe_num_experts
,
top_k
=
config
.
moe_topk
,
hidden_size
=
config
.
hidden_size
,
...
...
@@ -269,13 +278,6 @@ class AriaTextMoELayer(nn.Module):
reduce_results
=
True
,
prefix
=
f
"
{
prefix
}
.experts"
,
)
self
.
shared_experts
=
LlamaMLP
(
config
.
hidden_size
,
config
.
intermediate_size
*
config
.
moe_num_shared_experts
,
"silu"
,
quant_config
=
quant_config
,
bias
=
config
.
mlp_bias
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
...
...
@@ -291,12 +293,12 @@ class AriaTextMoELayer(nn.Module):
router_output
=
torch
.
nn
.
functional
.
linear
(
hidden_states
,
self
.
router_weight
)
hidden_states_copy
=
hidden_states
.
clone
()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output
=
self
.
experts
(
hidden_states
,
router_output
)
shared_expert_output
=
self
.
shared_experts
(
hidden_states_copy
)
return
sparse_expert_output
+
shared_expert_output
if
self
.
shared_experts
is
not
None
:
return
sparse_expert_output
[
0
]
+
sparse_expert_output
[
1
]
else
:
return
sparse_expert_output
class
AriaTextDecoderLayer
(
LlamaDecoderLayer
):
...
...
vllm/model_executor/models/bailing_moe.py
View file @
47e66c24
...
...
@@ -43,7 +43,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -276,22 +276,6 @@ class BailingMoE(nn.Module):
# default value for scoring_func
self
.
score_function
=
"softmax"
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
num_experts
,
top_k
=
self
.
top_k
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
self
.
norm_expert_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
self
.
score_function
,
e_score_correction_bias
=
self
.
gate
.
expert_bias
,
num_expert_group
=
self
.
n_group
,
topk_group
=
self
.
topk_group
,
use_grouped_topk
=
self
.
use_grouped_topk
,
)
if
self
.
num_shared_experts
>
0
:
if
hasattr
(
config
,
"moe_shared_expert_intermediate_size"
):
intermediate_size
=
config
.
moe_shared_expert_intermediate_size
...
...
@@ -308,11 +292,27 @@ class BailingMoE(nn.Module):
else
:
self
.
shared_experts
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
self
.
num_experts
,
top_k
=
self
.
top_k
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
self
.
norm_expert_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
self
.
score_function
,
e_score_correction_bias
=
self
.
gate
.
expert_bias
,
num_expert_group
=
self
.
n_group
,
topk_group
=
self
.
topk_group
,
use_grouped_topk
=
self
.
use_grouped_topk
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_size
)
if
self
.
shared_experts
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
=
self
.
gate
(
hidden_states
.
to
(
self
.
router_dtype
))
router_logits
=
router_logits
.
to
(
hidden_states
.
dtype
)
...
...
@@ -321,9 +321,14 @@ class BailingMoE(nn.Module):
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
self
.
shared_experts
is
not
None
:
shared_output
,
final_hidden_states
=
final_hidden_states
else
:
shared_output
=
None
final_hidden_states
*=
self
.
routed_scaling_factor
if
self
.
shared_
experts
:
if
shared_
output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
...
...
@@ -475,7 +480,7 @@ class BailingMoeModel(nn.Module):
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
47e66c24
...
...
@@ -49,7 +49,7 @@ from vllm.forward_context import get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
LayerNorm
,
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -64,7 +64,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
...
...
@@ -205,26 +204,6 @@ class DeepseekV2MoE(nn.Module):
)
if
config
.
n_shared_experts
is
None
:
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
self
.
shared_experts
=
None
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
...
...
@@ -239,27 +218,27 @@ class DeepseekV2MoE(nn.Module):
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
...
@@ -1306,7 +1285,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
self
.
num_moe_layers
=
config
.
num_hidden_layers
-
config
.
first_k_dense_replace
self
.
num_expert_groups
=
config
.
n_group
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
self
.
moe_layers
:
list
[
Shared
FusedMoE
]
=
[]
example_moe
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
...
...
@@ -1394,7 +1373,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/dots1.py
View file @
47e66c24
...
...
@@ -42,7 +42,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -145,7 +145,21 @@ class Dots1MoE(nn.Module):
else
:
self
.
gate
.
e_score_correction_bias
=
None
self
.
experts
=
FusedMoE
(
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
Dots1MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
else
:
self
.
shared_experts
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
...
...
@@ -163,29 +177,19 @@ class Dots1MoE(nn.Module):
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
Dots1MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
(
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
)
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
shared_experts
is
not
None
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
...
@@ -426,7 +430,7 @@ class Dots1Model(nn.Module):
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/ernie45_moe.py
View file @
47e66c24
...
...
@@ -37,7 +37,7 @@ from vllm.config import CacheConfig, VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -145,18 +145,6 @@ class Ernie4_5_MoeMoE(nn.Module):
torch
.
empty
(
config
.
moe_num_experts
,
dtype
=
torch
.
float32
)
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
moe_num_experts
,
top_k
=
config
.
moe_k
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
)
if
self
.
has_shared_experts
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
moe_num_shared_experts
...
...
@@ -167,16 +155,28 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
()
,
reduce_results
=
False
,
)
else
:
self
.
shared_experts
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
moe_num_experts
,
top_k
=
config
.
moe_k
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
if
self
.
has_shared_experts
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
...
...
@@ -184,8 +184,8 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
self
.
has_shared_experts
and
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
has_shared_experts
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
tp_size
>
1
:
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
...
...
@@ -460,7 +460,7 @@ class Ernie4_5_MoeModel(nn.Module):
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/ernie45_vl_moe.py
View file @
47e66c24
...
...
@@ -37,7 +37,7 @@ from vllm.attention import Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
...
...
@@ -74,7 +74,15 @@ logger = init_logger(__name__)
class
Ernie4_5_VLMoeMLP
(
Ernie4_5_MoeMLP
):
pass
def
__init__
(
self
,
shared_experts
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
shared_experts
=
shared_experts
def
forward
(
self
,
x
):
if
self
.
shared_experts
is
not
None
:
return
self
.
shared_experts
(
x
)
+
super
().
forward
(
x
)
else
:
return
super
().
forward
(
x
)
class
Ernie4_5_VLMoeAttention
(
nn
.
Module
):
...
...
@@ -223,6 +231,21 @@ class Ernie4_5_VLMoeMoE(nn.Module):
assert
text_moe_layer_start_index
<=
text_moe_layer_end_index
if
self
.
has_shared_experts
:
intermediate_size
=
(
config
.
moe_intermediate_size
[
0
]
*
config
.
moe_num_shared_experts
)
self
.
shared_experts
=
Ernie4_5_VLMoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
reduce_results
=
False
,
)
else
:
self
.
shared_experts
=
None
if
(
layer_idx
>=
text_moe_layer_start_index
and
layer_idx
<=
text_moe_layer_end_index
...
...
@@ -236,7 +259,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix
=
f
"
{
prefix
}
.text_experts_gate"
,
)
self
.
text_experts
=
FusedMoE
(
self
.
text_experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
moe_num_experts
[
0
],
top_k
=
config
.
moe_k
,
hidden_size
=
config
.
hidden_size
,
...
...
@@ -249,6 +273,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
)
else
:
self
.
text_experts
=
Ernie4_5_VLMoeMLP
(
shared_experts
=
self
.
shared_experts
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -271,7 +296,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix
=
f
"
{
prefix
}
.vision_experts_gate"
,
)
self
.
vision_experts
=
FusedMoE
(
self
.
vision_experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
moe_num_experts
[
1
],
top_k
=
config
.
moe_k
,
hidden_size
=
config
.
hidden_size
,
...
...
@@ -284,6 +310,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
)
else
:
self
.
vision_experts
=
Ernie4_5_VLMoeMLP
(
shared_experts
=
self
.
shared_experts
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -292,19 +319,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
prefix
=
f
"
{
prefix
}
.mlp"
,
)
if
self
.
has_shared_experts
:
intermediate_size
=
(
config
.
moe_intermediate_size
[
0
]
*
config
.
moe_num_shared_experts
)
self
.
shared_experts
=
Ernie4_5_VLMoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
reduce_results
=
self
.
text_experts
.
must_reduce_shared_expert_outputs
(),
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -315,9 +329,6 @@ class Ernie4_5_VLMoeMoE(nn.Module):
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
has_shared_experts
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
visual_token_mask
is
not
None
and
visual_token_mask
.
all
():
# only vision modal input
router_logits
,
_
=
self
.
vision_experts_gate
(
...
...
@@ -362,8 +373,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
hidden_states
=
hidden_states
,
router_logits
=
text_router_logits
)
if
self
.
has_shared_experts
and
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
has_shared_experts
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
tp_size
>
1
:
final_hidden_states
=
(
...
...
@@ -649,7 +660,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/glm4_moe.py
View file @
47e66c24
...
...
@@ -42,7 +42,7 @@ from vllm.distributed import (
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -52,7 +52,6 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
...
...
@@ -176,46 +175,29 @@ class Glm4MoE(nn.Module):
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
"sigmoid"
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
)
else
:
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
"sigmoid"
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
)
self
.
shared_experts
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
"sigmoid"
,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor
=
1.0
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
...
@@ -522,7 +504,7 @@ class Glm4MoeModel(nn.Module):
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
@@ -677,7 +659,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self
.
num_moe_layers
=
config
.
num_hidden_layers
-
config
.
first_k_dense_replace
self
.
num_expert_groups
=
config
.
n_group
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
self
.
moe_layers
:
list
[
Shared
FusedMoE
]
=
[]
example_moe
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
...
...
vllm/model_executor/models/hunyuan_v1.py
View file @
47e66c24
...
...
@@ -44,7 +44,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -415,19 +415,6 @@ class HunYuanSparseMoeBlock(nn.Module):
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
n_routed_experts
,
top_k
=
top_k
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
reduce_results
=
False
,
renormalize
=
top_k
>
1
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
...
...
@@ -455,22 +442,34 @@ class HunYuanSparseMoeBlock(nn.Module):
else
:
self
.
shared_mlp
=
None
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_mlp
,
num_experts
=
self
.
n_routed_experts
,
top_k
=
top_k
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
reduce_results
=
False
,
renormalize
=
top_k
>
1
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
if
self
.
shared_mlp
is
not
None
:
shared_output
=
self
.
shared_mlp
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
shared_mlp
is
not
None
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
...
@@ -724,7 +723,7 @@ class HunYuanModel(nn.Module):
if
_is_moe
(
self
.
config
):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
@@ -1008,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
num_expert_groups
=
1
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
self
.
moe_layers
:
list
[
Shared
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
...
...
vllm/model_executor/models/llama4.py
View file @
47e66c24
...
...
@@ -33,7 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
...
...
@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
...
...
@@ -399,7 +398,7 @@ class Llama4Model(LlamaModel):
params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be
generated by FusedMoE.make_expert_params_mapping().
generated by
Shared
FusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of:
...
...
@@ -522,7 +521,7 @@ class Llama4Model(LlamaModel):
fused_experts_params
=
False
# Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor.
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
@@ -530,7 +529,7 @@ class Llama4Model(LlamaModel):
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping_fused
=
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_up_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"gate_up_proj"
,
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
47e66c24
...
...
@@ -40,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -79,6 +79,7 @@ class Qwen2MoeMLP(nn.Module):
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
expert_gate
:
Optional
[
torch
.
nn
.
Linear
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -102,12 +103,17 @@ class Qwen2MoeMLP(nn.Module):
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
self
.
expert_gate
=
expert_gate
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
out
=
self
.
act_fn
(
gate_up
)
out
,
_
=
self
.
down_proj
(
out
)
if
self
.
expert_gate
is
not
None
:
out
=
F
.
sigmoid
(
self
.
expert_gate
(
x
))
*
out
return
out
class
Qwen2MoeSparseMoeBlock
(
nn
.
Module
):
...
...
@@ -126,17 +132,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
f
"the number of experts
{
config
.
num_experts
}
."
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
...
...
@@ -144,39 +139,47 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
self
.
shared_expert_gate
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
if
config
.
shared_expert_intermediate_size
>
0
:
self
.
shared_expert
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
shared_expert_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(),
reduce_results
=
False
,
expert_gate
=
self
.
shared_expert_gate
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
)
else
:
self
.
shared_expert
=
None
self
.
shared_expert_gate
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_expert
,
num_experts
=
config
.
num_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
if
self
.
shared_expert
is
not
None
:
shared_output
=
self
.
shared_expert
(
hidden_states
)
if
self
.
shared_expert_gate
is
not
None
:
shared_output
=
(
F
.
sigmoid
(
self
.
shared_expert_gate
(
hidden_states
))
*
shared_output
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_
outpu
t
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
shared_
exper
t
is
not
None
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
tp_size
>
1
:
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
# noqa E501
final_hidden_states
...
...
@@ -418,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
vllm/model_executor/models/qwen3_next.py
View file @
47e66c24
...
...
@@ -7,7 +7,6 @@ from itertools import islice
from
typing
import
Optional
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
...
...
@@ -36,7 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule
,
fused_recurrent_gated_delta_rule
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
Shared
FusedMoE
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
as
Qwen3NextRMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -136,20 +135,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
...
...
@@ -158,18 +143,35 @@ class Qwen3NextSparseMoeBlock(nn.Module):
prefix
=
f
"
{
prefix
}
.gate"
,
)
self
.
shared_expert_gate
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
if
config
.
shared_expert_intermediate_size
>
0
:
self
.
shared_expert
=
Qwen3NextMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
shared_expert_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(),
reduce_results
=
False
,
expert_gate
=
self
.
shared_expert_gate
,
prefix
=
f
"
{
prefix
}
.shared_expert"
,
)
else
:
self
.
shared_expert
=
None
self
.
shared_expert_gate
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_expert
,
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
...
...
@@ -180,22 +182,14 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if
self
.
is_sequence_parallel
:
hidden_states
=
sequence_parallel_chunk
(
hidden_states
)
shared_output
=
None
if
self
.
shared_expert
is
not
None
:
shared_output
=
self
.
shared_expert
(
hidden_states
)
if
self
.
shared_expert_gate
is
not
None
:
shared_output
=
(
F
.
sigmoid
(
self
.
shared_expert_gate
(
hidden_states
))
*
shared_output
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_
outpu
t
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
shared_
exper
t
is
not
None
:
final_hidden_states
=
final_hidden_states
[
0
]
+
final_hidden_states
[
1
]
if
self
.
is_sequence_parallel
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
...
...
@@ -1008,7 +1002,7 @@ class Qwen3NextModel(nn.Module):
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
return
Shared
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
@@ -1150,7 +1144,7 @@ class Qwen3NextForCausalLM(
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
self
.
moe_layers
:
list
[
Shared
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment