Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f9947e6
"tests/vscode:/vscode.git/clone" did not exist on "fa6ecb9aa7a55a99f87fdec7a75011f87af2176c"
Commit
4f9947e6
authored
Dec 20, 2025
by
zhuwenwen
Browse files
update qwen3_moe.py
parent
1e622f10
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
173 additions
and
240 deletions
+173
-240
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+173
-240
No files found.
vllm/model_executor/models/qwen3_moe.py
View file @
4f9947e6
...
@@ -22,22 +22,19 @@
...
@@ -22,22 +22,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import
typing
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
os
import
os
import
re
import
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
...
@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
parallel_config
=
vllm_config
.
parallel_config
quant_config
=
vllm_config
.
quant_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
ep_group
=
get_ep_group
().
device_group
self
.
ep_rank
=
self
.
ep_group
.
rank
()
self
.
ep_size
=
self
.
ep_group
.
size
()
self
.
n_routed_experts
=
config
.
num_experts
self
.
is_sequence_parallel
=
parallel_config
.
use_sequence_parallel_moe
if
self
.
tp_size
>
config
.
num_experts
:
if
self
.
tp_size
>
config
.
num_experts
:
raise
ValueError
(
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
num_experts
}
."
)
f
"the number of experts
{
config
.
num_experts
}
."
)
# Load balancing settings.
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
num_experts
,
vllm_config
=
get_current_vllm_config
()
eplb_config
=
vllm_config
.
parallel_config
.
eplb_config
self
.
enable_eplb
=
parallel_config
.
enable_eplb
self
.
n_logical_experts
=
self
.
n_routed_experts
self
.
n_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
n_physical_experts
=
(
self
.
n_logical_experts
+
self
.
n_redundant_experts
)
self
.
n_local_physical_experts
=
self
.
n_physical_experts
//
self
.
ep_size
self
.
physical_expert_start
=
(
self
.
ep_rank
*
self
.
n_local_physical_experts
)
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
experts
=
FusedMoE
(
num_experts
=
self
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
Tru
e
,
reduce_results
=
Fals
e
,
renormalize
=
config
.
norm_topk_prob
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
)
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
config
.
num_experts
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
prefix
=
f
"
{
prefix
}
.gate"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
hidden_states
.
dim
(
# NOTE: hidden_states can have either 1D or 2D shape.
)
<=
2
,
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
orig_shape
=
hidden_states
.
shape
is_input_1d
=
hidden_states
.
dim
()
==
1
hidden_dim
=
hidden_states
.
shape
[
-
1
]
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
is_sequence_parallel
:
hidden_states
=
sequence_parallel_chunk
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
self
.
is_sequence_parallel
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
# noqa E501
final_hidden_states
,
0
)
final_hidden_states
)
final_hidden_states
=
final_hidden_states
[:
num_tokens
]
# return to 1d if input is 1d
return
final_hidden_states
.
view
(
orig_shape
)
return
final_hidden_states
.
squeeze
(
0
)
if
is_input_1d
else
\
final_hidden_states
class
Qwen3MoeAttention
(
nn
.
Module
):
class
Qwen3MoeAttention
(
nn
.
Module
):
...
@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
dual_chunk_attention_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module):
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
dual_chunk_attention_config
=
dual_chunk_attention_config
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
**
{
"layer_idx"
:
extract_layer_index
(
prefix
),
"dual_chunk_attention_config"
:
dual_chunk_attention_config
,
}
if
dual_chunk_attention_config
else
{},
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
q_bias
:
Optional
[
torch
.
Tensor
],
k_bias
:
Optional
[
torch
.
Tensor
],
epsilon
:
float
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -278,23 +283,52 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -278,23 +283,52 @@ class Qwen3MoeAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
# Add qk-norm
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
:
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
# Fused RMSNorm + RoPE path through custom op.
self
.
head_dim
)
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
envs
.
VLLM_USE_APEX_RN
:
if
(
cos_sin_cache
.
device
!=
q
.
device
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
else
:
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
dtype
=
q
.
dtype
,
q
=
q_by_head
.
view
(
q
.
shape
)
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
# # q, k 使用 continuous
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
else
:
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
# Add qk-norm then RoPE (original path).
k
=
k_by_head
.
view
(
k
.
shape
)
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
q_by_head
=
self
.
q_norm
.
forward_apex
(
q_by_head
)
else
:
q_by_head
=
self
.
q_norm
.
forward_cuda
(
q_by_head
)
q
=
q_by_head
.
view
(
q
.
shape
)
k_by_head
=
k
.
view
(
*
k
.
shape
[:
-
1
],
k
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
k_by_head
=
self
.
k_norm
.
forward_apex
(
k_by_head
)
else
:
k_by_head
=
self
.
k_norm
.
forward_cuda
(
k_by_head
)
k
=
k_by_head
.
view
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
8192
)
dual_chunk_attention_config
=
getattr
(
config
,
"dual_chunk_attention_config"
,
None
)
self
.
self_attn
=
Qwen3MoeAttention
(
self
.
self_attn
=
Qwen3MoeAttention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
dual_chunk_attention_config
=
dual_chunk_attention_config
,
)
)
# `mlp_only_layers` in the config.
# `mlp_only_layers` in the config.
...
@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
...
@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
if
(
layer_idx
not
in
mlp_only_layers
)
and
(
config
.
num_experts
>
0
and
config
.
num_experts
>
0
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
vllm_config
=
vllm_config
,
self
.
mlp
=
Qwen3MoeSparseMoeBlock
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
...
@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module):
...
@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
parallel_config
=
vllm_config
.
parallel_config
eplb_config
=
parallel_config
.
eplb_config
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module):
...
@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
vllm_config
=
vllm_config
,
lambda
prefix
:
Qwen3MoeDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
...
@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module):
...
@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module):
...
@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
num_redundant_experts
=
self
.
num_redundant_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module):
...
@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module):
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
".v_scale"
,
"_v_scale"
,
".weight_scale"
,
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
"_weight_scale"
,
".input_scale"
,
"_input_scale"
)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
current_count
=
loaded_weight
.
current_count
current_count
=
loaded_weight
.
current_count
...
@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module):
...
@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
name
.
endswith
(
"scale"
):
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
param
.
weight_loader
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
is_expert_weight
=
False
for
mapping
in
expert_params_mapping
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Anyway, this is an expert weight and should not be
# Skip layers on other devices.
# attempted to load as other weights later
if
is_pp_missing_parameter
(
name
,
self
):
is_expert_weight
=
True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name_mapped
,
self
):
continue
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name_mapped
.
endswith
(
if
name
.
endswith
(
ignore_suffixes
ignore_suffixes
)
and
name
not
in
params_dict
:
)
and
name_mapped
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name_mapped
]
weight_loader
=
param
.
weight_loader
# We should ask the weight loader to return success or not
weight_loader
(
param
,
# here since otherwise we may skip experts with other
loaded_weight
,
# available replicas.
name
,
weight_loader
=
typing
.
cast
(
Callable
[...,
bool
],
shard_id
=
shard_id
,
param
.
weight_loader
)
expert_id
=
expert_id
)
success
=
weight_loader
(
param
,
break
loaded_weight
,
name_mapped
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
True
)
if
success
:
name
=
name_mapped
break
else
:
else
:
if
is_expert_weight
:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models.
# Skip loading extra parameters for GPTQ/modelopt models.
if
name
.
endswith
(
if
name
.
endswith
(
ignore_suffixes
)
and
name
not
in
params_dict
:
ignore_suffixes
)
and
name
not
in
params_dict
:
...
@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module):
...
@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module):
return
loaded_params
return
loaded_params
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
,
class
Qwen3MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
MixtureOfExperts
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_
text_
config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
)
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
continue
assert
isinstance
(
layer
,
Qwen3MoeDecoderLayer
)
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
example_layer
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
if
example_layer
is
None
:
raise
RuntimeError
(
"No Qwen3MoE layer found in the model.layers."
)
self
.
num_moe_layers
=
len
(
self
.
moe_layers
)
self
.
num_expert_groups
=
1
self
.
num_shared_experts
=
0
self
.
num_logical_experts
=
example_layer
.
n_logical_experts
self
.
num_physical_experts
=
example_layer
.
n_physical_experts
self
.
num_local_physical_experts
=
example_layer
.
n_local_physical_experts
self
.
num_routed_experts
=
example_layer
.
n_routed_experts
self
.
num_redundant_experts
=
example_layer
.
n_redundant_experts
def
set_eplb_state
(
self
,
expert_load_view
:
torch
.
Tensor
,
logical_to_physical_map
:
torch
.
Tensor
,
logical_replica_count
:
torch
.
Tensor
,
)
->
None
:
for
layer_idx
,
layer
in
enumerate
(
self
.
moe_layers
):
# Register the expert weights.
self
.
expert_weights
.
append
(
layer
.
get_expert_weights
())
layer
.
set_eplb_state
(
moe_layer_idx
=
layer_idx
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
def
update_physical_experts_metadata
(
self
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
)
->
None
:
assert
self
.
num_local_physical_experts
==
num_local_physical_experts
self
.
num_physical_experts
=
num_physical_experts
self
.
num_local_physical_experts
=
num_local_physical_experts
self
.
num_redundant_experts
=
(
num_physical_experts
-
self
.
num_logical_experts
)
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
.
mlp
,
Qwen3MoeSparseMoeBlock
):
moe
=
layer
.
mlp
moe
.
n_local_physical_experts
=
num_local_physical_experts
moe
.
n_physical_experts
=
num_physical_experts
moe
.
n_redundant_experts
=
self
.
num_redundant_experts
moe
.
experts
.
update_expert_map
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment