Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
466e878f
Unverified
Commit
466e878f
authored
Jul 19, 2025
by
Jee Jee Li
Committed by
GitHub
Jul 18, 2025
Browse files
[Quantization] Enable BNB support for more MoE models (#21100)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
21793722
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
223 additions
and
181 deletions
+223
-181
docs/models/supported_models.md
docs/models/supported_models.md
+4
-4
vllm/model_executor/models/bailing_moe.py
vllm/model_executor/models/bailing_moe.py
+14
-7
vllm/model_executor/models/ernie45_moe.py
vllm/model_executor/models/ernie45_moe.py
+84
-69
vllm/model_executor/models/grok1.py
vllm/model_executor/models/grok1.py
+14
-10
vllm/model_executor/models/hunyuan_v1_moe.py
vllm/model_executor/models/hunyuan_v1_moe.py
+107
-91
No files found.
docs/models/supported_models.md
View file @
466e878f
...
...
@@ -316,7 +316,7 @@ Specified using `--task generate`.
|
`AquilaForCausalLM`
| Aquila, Aquila2 |
`BAAI/Aquila-7B`
,
`BAAI/AquilaChat-7B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`ArcticForCausalLM`
| Arctic |
`Snowflake/snowflake-arctic-base`
,
`Snowflake/snowflake-arctic-instruct`
, etc. | | ✅︎ | ✅︎ |
|
`BaiChuanForCausalLM`
| Baichuan2, Baichuan |
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. | | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. |
✅︎
| ✅︎ | ✅︎ |
|
`BambaForCausalLM`
| Bamba |
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
| ✅︎ | ✅︎ | ✅︎ |
|
`BloomForCausalLM`
| BLOOM, BLOOMZ, BLOOMChat |
`bigscience/bloom`
,
`bigscience/bloomz`
, etc. | | ✅︎ | |
|
`BartForConditionalGeneration`
| BART |
`facebook/bart-base`
,
`facebook/bart-large-cnn`
, etc. | | | |
...
...
@@ -328,8 +328,8 @@ Specified using `--task generate`.
|
`DeepseekV2ForCausalLM`
| DeepSeek-V2 |
`deepseek-ai/DeepSeek-V2`
,
`deepseek-ai/DeepSeek-V2-Chat`
, etc. | | ✅︎ | ✅︎ |
|
`DeepseekV3ForCausalLM`
| DeepSeek-V3 |
`deepseek-ai/DeepSeek-V3-Base`
,
`deepseek-ai/DeepSeek-V3`
, etc. | | ✅︎ | ✅︎ |
|
`Dots1ForCausalLM`
| dots.llm1 |
`rednote-hilab/dots.llm1.base`
,
`rednote-hilab/dots.llm1.inst`
, etc. | | ✅︎ | ✅︎ |
|
`Ernie4_5_ForCausalLM`
| Ernie4.5 |
`baidu/ERNIE-4.5-0.3B-PT`
, etc. | | ✅︎ | ✅︎ |
|
`Ernie4_5_MoeForCausalLM`
| Ernie4.5MoE |
`baidu/ERNIE-4.5-21B-A3B-PT`
,
`baidu/ERNIE-4.5-300B-A47B-PT`
, etc. |
| ✅︎ | ✅︎ |
|
`Ernie4_5_ForCausalLM`
| Ernie4.5 |
`baidu/ERNIE-4.5-0.3B-PT`
, etc. |
✅︎
| ✅︎ | ✅︎ |
|
`Ernie4_5_MoeForCausalLM`
| Ernie4.5MoE |
`baidu/ERNIE-4.5-21B-A3B-PT`
,
`baidu/ERNIE-4.5-300B-A47B-PT`
, etc. |
✅︎
| ✅︎ | ✅︎ |
|
`ExaoneForCausalLM`
| EXAONE-3 |
`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Fairseq2LlamaForCausalLM`
| Llama (fairseq2 format) |
`mgleize/fairseq2-dummy-Llama-3.2-1B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`FalconForCausalLM`
| Falcon |
`tiiuae/falcon-7b`
,
`tiiuae/falcon-40b`
,
`tiiuae/falcon-rw-7b`
, etc. | | ✅︎ | ✅︎ |
...
...
@@ -351,7 +351,7 @@ Specified using `--task generate`.
|
`GraniteMoeSharedForCausalLM`
| Granite MoE Shared |
`ibm-research/moe-7b-1b-active-shared-experts`
(test model) | ✅︎ | ✅︎ | ✅︎ |
|
`GritLM`
| GritLM |
`parasail-ai/GritLM-7B-vllm`
. | ✅︎ | ✅︎ | |
|
`Grok1ModelForCausalLM`
| Grok1 |
`hpcai-tech/grok-1`
. | ✅︎ | ✅︎ | ✅︎ |
|
`HunYuanMoEV1ForCausalLM`
| Hunyuan-80B-A13B |
`tencent/Hunyuan-A13B-Instruct`
,
`tencent/Hunyuan-A13B-Pretrain`
,
`tencent/Hunyuan-A13B-Instruct-FP8`
, etc. | | | ✅︎ |
|
`HunYuanMoEV1ForCausalLM`
| Hunyuan-80B-A13B |
`tencent/Hunyuan-A13B-Instruct`
,
`tencent/Hunyuan-A13B-Pretrain`
,
`tencent/Hunyuan-A13B-Instruct-FP8`
, etc. |
✅︎
| | ✅︎ |
|
`InternLMForCausalLM`
| InternLM |
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`InternLM2ForCausalLM`
| InternLM2 |
`internlm/internlm2-7b`
,
`internlm/internlm2-chat-7b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`InternLM3ForCausalLM`
| InternLM3 |
`internlm/internlm3-8b-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
...
...
vllm/model_executor/models/bailing_moe.py
View file @
466e878f
...
...
@@ -53,7 +53,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -374,6 +374,14 @@ class BailingMoeModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -381,14 +389,10 @@ class BailingMoeModel(nn.Module):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
if
self
.
config
.
norm_head
and
"lm_head.weight"
in
name
:
loaded_weight
=
F
.
normalize
(
loaded_weight
,
...
...
@@ -449,7 +453,7 @@ class BailingMoeModel(nn.Module):
return
loaded_params
class
BailingMoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
BailingMoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
...
...
@@ -518,3 +522,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/ernie45_moe.py
View file @
466e878f
...
...
@@ -51,8 +51,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
extract_layer_index
,
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -427,66 +427,15 @@ class Ernie4_5_MoeModel(nn.Module):
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
class
Ernie4_5_MoeForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Ernie4_5_MoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
moe_num_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
@@ -499,16 +448,9 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
moe_num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
if
self
.
config
.
tie_word_embeddings
and
name
.
endswith
(
"lm_head.weight"
):
...
...
@@ -581,3 +523,76 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP):
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Ernie4_5_MoeForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Ernie4_5_MoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/grok1.py
View file @
466e878f
...
...
@@ -360,6 +360,16 @@ class Grok1Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Map Grok1's unique expert parameter names to standard names
# Grok1 uses "num_experts" in its config
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
8
)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"linear"
,
# Grok1 specific
ckpt_down_proj_name
=
"linear_1"
,
# Grok1 specific
ckpt_up_proj_name
=
"linear_v"
,
# Grok1 specific
num_experts
=
num_experts
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -369,18 +379,9 @@ class Grok1Model(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
# Map Grok1's unique expert parameter names to standard names
# Grok1 uses "num_experts" in its config
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
8
)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"linear"
,
# Grok1 specific
ckpt_down_proj_name
=
"linear_1"
,
# Grok1 specific
ckpt_up_proj_name
=
"linear_v"
,
# Grok1 specific
num_experts
=
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
...
...
@@ -544,3 +545,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
skip_prefixes
=
skip_prefixes
,
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/hunyuan_v1_moe.py
View file @
466e878f
...
...
@@ -56,7 +56,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsLoRA
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
)
def
_get_cla_factor
(
config
:
PretrainedConfig
)
->
int
:
...
...
@@ -617,86 +619,6 @@ class HunYuanModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
HunYuanMoEV1ForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
lora_config
=
lora_config
self
.
model
=
HunYuanModel
(
vllm_config
=
vllm_config
,
prefix
=
"model"
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
else
:
self
.
lm_head
=
PPMissingLayer
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
_split_qkv_weight
(
self
,
qkv
:
torch
.
Tensor
):
num_attention_heads
=
self
.
config
.
num_attention_heads
num_kv_heads
=
getattr
(
self
.
config
,
"num_key_value_heads"
,
...
...
@@ -719,6 +641,17 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
v
=
v
.
reshape
(
-
1
,
hidden_size
)
return
torch
.
concat
((
q
,
k
,
v
))
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
cla_factor
=
_get_cla_factor
(
self
.
config
)
stacked_params_mapping
=
[
...
...
@@ -745,16 +678,9 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -806,7 +732,7 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
is_found
=
True
break
if
is_found
:
...
...
@@ -885,3 +811,93 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
HunYuanMoEV1ForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
HunYuanModel
(
vllm_config
=
vllm_config
,
prefix
=
"model"
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
else
:
self
.
lm_head
=
PPMissingLayer
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment