Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f386ba88
Commit
f386ba88
authored
Oct 25, 2025
by
zhuwenwen
Browse files
[Models] support HunYuanForCausalLM
parent
a9c37628
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1130 additions
and
31 deletions
+1130
-31
vllm/config/model.py
vllm/config/model.py
+5
-0
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+8
-3
vllm/model_executor/models/hunyuan.py
vllm/model_executor/models/hunyuan.py
+1005
-0
vllm/model_executor/models/hunyuan_v1.py
vllm/model_executor/models/hunyuan_v1.py
+111
-28
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
vllm/config/model.py
View file @
f386ba88
...
...
@@ -1869,6 +1869,11 @@ def _get_and_verify_max_len(
if
rope_type
==
"yarn"
:
derived_max_model_len
=
rope_scaling
[
"original_max_position_embeddings"
]
# see DynamicNTKAlphaRotaryEmbedding
if
rope_scaling
[
"type"
]
==
"dynamic"
and
"alpha"
in
rope_scaling
:
scaling_factor
=
1
derived_max_model_len
*=
scaling_factor
if
encoder_config
and
"max_seq_length"
in
encoder_config
:
...
...
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
f386ba88
...
...
@@ -137,6 +137,11 @@ def get_rope(
scaling_alpha
,
dtype
)
elif
"factor"
in
rope_scaling
:
scaling_factor
=
rope_scaling
[
"factor"
]
if
"alpha"
in
rope_scaling
:
rotary_emb
=
DynamicNTKAlphaRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling
[
"alpha"
],
dtype
)
else
:
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
)
...
...
vllm/model_executor/models/hunyuan.py
0 → 100644
View file @
f386ba88
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/hunyuan_v1.py
View file @
f386ba88
...
...
@@ -889,7 +889,7 @@ class HunYuanModel(nn.Module):
return
loaded_params
class
HunYuanV1Base
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
MixtureOfExperts
):
class
HunYuanV1Base
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -931,30 +931,6 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
else
:
self
.
lm_head
=
PPMissingLayer
()
# Set MoE hyperparameters
self
.
expert_weights
=
[]
self
.
num_expert_groups
=
1
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
example_layer
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
continue
assert
isinstance
(
layer
,
HunYuanDecoderLayer
)
if
isinstance
(
layer
.
mlp
,
HunYuanSparseMoeBlock
):
example_layer
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
if
example_layer
is
None
:
raise
RuntimeError
(
"No HunYuanMoE layer found in model.layers."
)
self
.
num_moe_layers
=
len
(
self
.
moe_layers
)
self
.
num_logical_experts
=
example_layer
.
n_logical_experts
self
.
num_physical_experts
=
example_layer
.
n_physical_experts
self
.
num_local_physical_experts
=
example_layer
.
n_local_physical_experts
self
.
num_routed_experts
=
example_layer
.
n_routed_experts
self
.
num_redundant_experts
=
example_layer
.
n_redundant_experts
def
set_eplb_state
(
self
,
expert_load_view
:
torch
.
Tensor
,
...
...
@@ -1030,13 +1006,120 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
)
return
loader
.
load_weights
(
weights
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the input embeddings for ``input_ids``.

    The lookup is delegated unchanged to
    ``self.model.get_input_embeddings``.
    """
    embed = self.model.get_input_embeddings
    return embed(input_ids)
class HunYuanMoEV1Base(HunYuanV1Base, MixtureOfExperts):
    """HunYuan V1 causal-LM base class for checkpoints with sparse MoE layers.

    On top of ``HunYuanV1Base`` this class collects the ``experts`` module of
    every ``HunYuanSparseMoeBlock`` found in ``self.model.layers`` and mirrors
    the expert counts of one such block onto the model, which is what the
    EPLB-related methods below operate on.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model through the parent, then gather MoE metadata.

        Args:
            vllm_config: Engine configuration forwarded to the parent class.
            prefix: Module-name prefix forwarded to the parent class.

        Raises:
            RuntimeError: If no layer of ``self.model.layers`` has a
                ``HunYuanSparseMoeBlock`` as its ``mlp``.
        """
        # The parent constructor is the single place where the decoder stack,
        # LM head and logits processor are created. Re-creating them here (as
        # this class used to do) instantiated the whole model twice and
        # ignored ``prefix``.
        # NOTE(review): assumes HunYuanV1Base.__init__ sets ``self.config``,
        # ``self.quant_config``, ``self.model``, ``self.lm_head`` and
        # ``self.logits_processor`` - confirm against the parent.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Set MoE hyperparameters
        self.expert_weights = []
        self.num_expert_groups = 1

        self.moe_layers: list[FusedMoE] = []
        example_layer = None
        for layer in self.model.layers:
            # Layers that live on another pipeline stage are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, HunYuanDecoderLayer)
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No HunYuanMoE layer found in model.layers.")

        # Expert counts are copied from the last MoE block that was seen.
        self.num_moe_layers = len(self.moe_layers)
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Hand the shared EPLB tensors to every collected MoE layer.

        Each layer's expert weights are appended to ``self.expert_weights``
        and the layer is told its index within ``self.moe_layers``.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            self.expert_weights.append(layer.get_expert_weights())
            # Register the expert weights.
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to the model and MoE blocks.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of physical experts on this
                rank; must equal the currently stored value.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = (num_physical_experts -
                                      self.num_logical_experts)
        for layer in self.model.layers:
            # Same guard as in __init__: pipeline-parallel placeholders have
            # no ``mlp`` attribute to inspect.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert mapping reported by ``self.model``."""
        return self.model.get_expert_mapping()
class
HunYuanDenseV1ForCausalLM
(
HunYuanV1Base
):
class HunYuanDenseV1Base(HunYuanV1Base):
    """Base class for the dense HunYuan V1 variant.

    Adds nothing to ``HunYuanV1Base``; the constructor forwards its
    arguments to the parent unchanged.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Forward ``vllm_config`` and ``prefix`` to the parent constructor."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
    """Dense HunYuan V1 causal LM; adds nothing to ``HunYuanDenseV1Base``."""
class
HunYuanMoEV1ForCausalLM
(
HunYuanV1Base
):
class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
    """MoE HunYuan V1 causal LM; adds nothing to ``HunYuanMoEV1Base``."""
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
f386ba88
...
...
@@ -100,6 +100,7 @@ _TEXT_GENERATION_MODELS = {
"Grok1ModelForCausalLM"
:
(
"grok1"
,
"Grok1ForCausalLM"
),
"HunYuanMoEV1ForCausalLM"
:
(
"hunyuan_v1"
,
"HunYuanMoEV1ForCausalLM"
),
"HunYuanDenseV1ForCausalLM"
:
(
"hunyuan_v1"
,
"HunYuanDenseV1ForCausalLM"
),
"HunYuanForCausalLM"
:
(
"hunyuan"
,
"HunYuanForCausalLM"
),
"HCXVisionForCausalLM"
:
(
"hyperclovax_vision"
,
"HCXVisionForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment