Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2671334d
Unverified
Commit
2671334d
authored
Jul 23, 2025
by
Asher
Committed by
GitHub
Jul 23, 2025
Browse files
[Model] add Hunyuan V1 Dense Model support. (#21368)
Signed-off-by:
Asher Zhang
<
asherszhang@tencent.com
>
parent
2cc5016a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
19 deletions
+57
-19
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/models/hunyuan_v1.py
vllm/model_executor/models/hunyuan_v1.py
+52
-18
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-1
No files found.
docs/models/supported_models.md
View file @
2671334d
...
...
@@ -363,6 +363,7 @@ th {
|
`GraniteMoeSharedForCausalLM`
| Granite MoE Shared |
`ibm-research/moe-7b-1b-active-shared-experts`
(test model) | ✅︎ | ✅︎ | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | |
|
`Grok1ModelForCausalLM`
| Grok1 |
`hpcai-tech/grok-1`
. | ✅︎ | ✅︎ | ✅︎ |
|
`HunYuanDenseV1ForCausalLM`
| Hunyuan-7B-Instruct-0124 |
`tencent/Hunyuan-7B-Instruct-0124`
| ✅︎ | | ✅︎ |
|
`HunYuanMoEV1ForCausalLM`
| Hunyuan-80B-A13B |
`tencent/Hunyuan-A13B-Instruct`
,
`tencent/Hunyuan-A13B-Pretrain`
,
`tencent/Hunyuan-A13B-Instruct-FP8`
, etc. | ✅︎ | | ✅︎ |
|
`InternLMForCausalLM`
| InternLM |
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`InternLM2ForCausalLM`
| InternLM2 |
`internlm/internlm2-7b`
,
`internlm/internlm2-chat-7b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
2671334d
...
...
@@ -199,6 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
"HunYuanMoEV1ForCausalLM"
:
_HfExamplesInfo
(
"tencent/Hunyuan-A13B-Instruct"
,
trust_remote_code
=
True
),
"HunYuanDenseV1ForCausalLM"
:
_HfExamplesInfo
(
"tencent/Hunyuan-7B-Instruct-0124"
,
trust_remote_code
=
True
),
"InternLMForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm-chat-7b"
,
trust_remote_code
=
True
),
"InternLM2ForCausalLM"
:
_HfExamplesInfo
(
"internlm/internlm2-chat-7b"
,
...
...
vllm/model_executor/models/hunyuan_v1
_moe
.py
→
vllm/model_executor/models/hunyuan_v1.py
View file @
2671334d
...
...
@@ -61,6 +61,19 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers
)
def _is_moe(config: PretrainedConfig) -> bool:
    """Return whether ``config`` describes a mixture-of-experts model.

    Only the optional ``num_experts`` attribute of ``config`` is read. It may
    be missing, a single ``int``, or a list of expert counts.

    Args:
        config: Model configuration object.

    Returns:
        ``True`` when ``num_experts`` is an ``int`` greater than 1, or a
        non-empty list made up solely of ``int`` values whose maximum is
        greater than 1. ``False`` in every other case (attribute missing,
        ``None``, empty list, list containing a non-int entry, or any other
        type such as a tuple).
    """
    num_experts = getattr(config, "num_experts", None)

    if isinstance(num_experts, int):
        return num_experts > 1

    if isinstance(num_experts, list) and num_experts:
        # Ensure all elements are integers before calling max(); a list with
        # any non-int entry is treated as "not MoE" rather than raising.
        return (all(isinstance(e, int) for e in num_experts)
                and max(num_experts) > 1)

    return False
def
_get_cla_factor
(
config
:
PretrainedConfig
)
->
int
:
if
not
getattr
(
config
,
"use_cla"
,
False
):
return
1
...
...
@@ -140,8 +153,8 @@ class HunYuanAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
if
hasattr
(
config
,
"head_dim"
):
if
hasattr
(
config
,
"head_dim"
)
and
config
.
head_dim
:
self
.
head_dim
=
config
.
head_dim
elif
hasattr
(
config
,
"attention_head_dim"
):
self
.
head_dim
=
config
.
attention_head_dim
...
...
@@ -490,12 +503,23 @@ class HunYuanDecoderLayer(nn.Module):
else
:
raise
RuntimeError
(
f
"Unsupported attention type:
{
attention_type
}
"
)
if
_is_moe
(
config
):
self
.
mlp
=
HunYuanSparseMoeBlock
(
config
=
config
,
quant_config
=
quant_config
,
layer_id
=
layer_id
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
else
:
self
.
mlp
=
HunYuanMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
self
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
...
@@ -642,7 +666,7 @@ class HunYuanModel(nn.Module):
return
torch
.
concat
((
q
,
k
,
v
))
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
if
_is_moe
(
self
.
config
):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
...
...
@@ -651,6 +675,8 @@ class HunYuanModel(nn.Module):
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
)
else
:
return
[]
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
cla_factor
=
_get_cla_factor
(
self
.
config
)
...
...
@@ -815,7 +841,7 @@ class HunYuanModel(nn.Module):
return
loaded_params
class
HunYuan
MoEV1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
HunYuan
V1Base
(
nn
.
Module
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -901,3 +927,11 @@ class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert parameter mapping of the wrapped model.

    Pure delegation: the result is whatever ``self.model.get_expert_mapping()``
    returns, passed through unchanged.

    Returns:
        A list of 4-tuples typed ``(str, str, int, str)``.
        NOTE(review): presumably ``(param_name, weight_name, expert_id,
        shard_id)`` as produced by the inner model - confirm against
        ``self.model``'s implementation.
    """
    return self.model.get_expert_mapping()
class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
    """Architecture entry point named ``HunYuanDenseV1ForCausalLM``.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``HunYuanV1Base``. The subclass exists so that this
    architecture name resolves to a distinct class.
    """
    pass
class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
    """Architecture entry point named ``HunYuanMoEV1ForCausalLM``.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``HunYuanV1Base``. The subclass exists so that this
    architecture name resolves to a distinct class.
    """
    pass
vllm/model_executor/models/registry.py
View file @
2671334d
...
...
@@ -79,7 +79,8 @@ _TEXT_GENERATION_MODELS = {
"GraniteMoeSharedForCausalLM"
:
(
"granitemoeshared"
,
"GraniteMoeSharedForCausalLM"
),
# noqa: E501
"GritLM"
:
(
"gritlm"
,
"GritLM"
),
"Grok1ModelForCausalLM"
:
(
"grok1"
,
"Grok1ForCausalLM"
),
"HunYuanMoEV1ForCausalLM"
:
(
"hunyuan_v1_moe"
,
"HunYuanMoEV1ForCausalLM"
),
"HunYuanMoEV1ForCausalLM"
:
(
"hunyuan_v1"
,
"HunYuanMoEV1ForCausalLM"
),
"HunYuanDenseV1ForCausalLM"
:
(
"hunyuan_v1"
,
"HunYuanDenseV1ForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternLM2VEForCausalLM"
:
(
"internlm2_ve"
,
"InternLM2VEForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment