sglang · Commits · 2390a2bc

Unverified commit 2390a2bc, authored Jun 26, 2025 by Meng, Peng; committed by GitHub Jun 25, 2025

Add Tencent HunYuanMoEV1 model support (#7549)

parent 16d76b9f
Showing 2 changed files with 828 additions and 9 deletions (+828, -9)

python/sglang/srt/layers/rotary_embedding.py  (+57, -9)
python/sglang/srt/models/hunyuan.py  (+771, -0)
python/sglang/srt/layers/rotary_embedding.py (view file @ 2390a2bc)
...
@@ -890,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
...
@@ -1234,15 +1271,26 @@ def get_rope(
             )
     elif scaling_type == "dynamic":
         scaling_factor = rope_scaling["factor"]
-        rotary_emb = DynamicNTKScalingRotaryEmbedding(
-            head_size,
-            rotary_dim,
-            max_position,
-            base,
-            is_neox_style,
-            scaling_factor,
-            dtype,
-        )
+        if "alpha" in rope_scaling:
+            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                rope_scaling["alpha"],
+                dtype,
+            )
+        else:
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                scaling_factor,
+                dtype,
+            )
     elif scaling_type == "yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling["original_max_position_embeddings"]
...
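
For readers comparing the two dynamic-NTK variants: the class added above rescales the RoPE base by alpha ** (rotary_dim / (rotary_dim - 2)) and builds the cos/sin cache for the original max_position_embeddings, whereas the existing factor-based variant derives its base from "factor" and extends the cache length accordingly. Below is a minimal standalone sketch of the selection logic in the get_rope() hunk; the "type" key name and the numeric values are illustrative assumptions, not taken from a HunYuan config.

def effective_rope_base(base: float, rotary_dim: int, rope_scaling: dict) -> float:
    # Mirrors the branch added to get_rope(): a "dynamic" scaling type with an
    # "alpha" key selects the alpha-scaled NTK base; otherwise the base is left
    # to the factor-based DynamicNTKScalingRotaryEmbedding path.
    if rope_scaling.get("type") == "dynamic" and "alpha" in rope_scaling:
        return base * rope_scaling["alpha"] ** (rotary_dim / (rotary_dim - 2))
    return base

# Illustrative values only.
print(effective_rope_base(10000.0, 128, {"type": "dynamic", "alpha": 1000.0}))
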
python/sglang/srt/models/hunyuan.py (new file, 0 → 100644, view file @ 2390a2bc)
# coding=utf-8
# Copyright 2024 The HunYuan team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
import logging
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, is_hip

expert_distribution_recorder = ExpertDistributionRecorder()
def _is_moe(config: PretrainedConfig) -> bool:
    if getattr(config, "num_experts", None) and (
        (isinstance(config.num_experts, int) and config.num_experts > 1)
        or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
    ):
        return True
    else:
        return False


def _get_cla_factor(config: PretrainedConfig) -> int:
    if not getattr(config, "use_cla", False):
        return 1
    return getattr(config, "cla_share_factor", 1)
class HunYuanMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class HunYuanSparseMoeBlock(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id: int = -1,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Get layer_id topk if config.moe_topk is a list
        if isinstance(config.moe_topk, list):
            assert layer_id >= 0
            assert len(config.moe_topk) > layer_id
            top_k = config.moe_topk[layer_id]
        else:
            top_k = config.moe_topk

        # If it is moe, moe_intermediate_size is preferred
        intermediate_size = config.intermediate_size
        if config.moe_intermediate_size is not None:
            intermediate_size = (
                config.moe_intermediate_size
                if isinstance(config.moe_intermediate_size, int)
                else config.moe_intermediate_size[layer_id]
            )

        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=top_k,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=False,
            renormalize=True if top_k > 1 else False,
            quant_config=quant_config,
        )

        self.gate = ReplicatedLinear(
            config.hidden_size, config.num_experts, bias=False, quant_config=None
        )
        if config.use_mixed_mlp_moe > 0:
            # Get layer_id num_shared_expert if config.num_shared_expert is a list
            if isinstance(config.num_shared_expert, list):
                assert layer_id >= 0
                assert len(config.num_shared_expert) > layer_id
                num_shared_expert = config.num_shared_expert[layer_id]
            else:
                num_shared_expert = config.num_shared_expert

            self.shared_mlp = HunYuanMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size * num_shared_expert,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )
        else:
            self.shared_mlp = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_mlp is not None:
            shared_output = self.shared_mlp(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(orig_shape)
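

# Cross-layer attention (CLA): when `use_cla` is set in the config, layers whose
# index is not a multiple of `cla_share_factor` appear to take the "cross" path in
# HunYuanAttention below, reusing the K/V states returned by the most recent "self"
# layer and projecting only Q themselves.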
class HunYuanAttention(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        attention_type: str = "self",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.attention_type = attention_type
        self.layer_id = layer_id

        if attention_type == "self":
            self.qkv_proj = QKVParallelLinear(
                hidden_size=hidden_size,
                head_size=self.head_dim,
                total_num_heads=self.total_num_heads,
                total_num_kv_heads=self.total_num_kv_heads,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        elif attention_type == "cross":
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                hidden_size,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        else:
            raise RuntimeError("Unsupported attention type")
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
        )
        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> torch.Tensor:
        if self.attention_type == "self":
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k = self.rotary_emb(positions, q, k)
            ori_k = k
            if self.use_qk_norm:
                # q = self.query_layernorm(q.view(-1, self.num_heads, self.head_dim).contiguous())
                # k = self.key_layernorm(k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
                q = self.query_layernorm(q.reshape(-1, self.head_dim).contiguous())
                k = self.key_layernorm(k.reshape(-1, self.head_dim).contiguous())
        elif self.attention_type == "cross":
            assert kv_states is not None
            ori_k, v = kv_states
            # use last layer kv
            k = ori_k
            q, _ = self.q_proj(hidden_states)
            k_tmp = torch.empty_like(k)
            # TODO: redundant rotary embedding
            q, _ = self.rotary_emb(positions, q, k_tmp)
            if self.use_qk_norm:
                q = self.query_layernorm(
                    q.view(-1, self.num_heads, self.head_dim).contiguous()
                )
                k = self.key_layernorm(
                    k.view(-1, self.num_kv_heads, self.head_dim).contiguous()
                )
        else:
            raise RuntimeError("Unsupported attention type")
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)
class HunYuanDecoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        assert layer_id >= 0
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size
            if isinstance(config.intermediate_size, int)
            else config.intermediate_size[layer_id]
        )
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        cla_factor = _get_cla_factor(config)
        attention_type = (
            "cross" if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
        )
        self.self_attn = HunYuanAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            prefix=f"{prefix}.self_attn",
            attention_type=attention_type,
            layer_id=layer_id,
        )

        if _is_moe(config):
            self.mlp = HunYuanSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
        else:
            self.mlp = HunYuanMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, ori_kv_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            kv_states=kv_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual, ori_kv_states
class HunYuanModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                HunYuanDecoderLayer(
                    config=config,
                    layer_id=layer_id,
                    quant_config=quant_config,
                    # prefix=prefix
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if input_embeds is not None:
            hidden_states = input_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        cla_factor = _get_cla_factor(self.config)
        prev_kv_states = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual, kv_states = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
                prev_kv_states,
            )
            if False:
                # (i - self.start_layer) % cla_factor == 0:
                prev_kv_states = kv_states
            else:
                prev_kv_states = None

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class HunYuanMoEV1ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = HunYuanModel(config, quant_config, prefix="model")
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def _split_qkv_weight(self, qkv: torch.Tensor):
        num_attention_heads = self.config.num_attention_heads
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
        num_key_value_groups = num_attention_heads // num_kv_heads
        hidden_size = self.config.hidden_size
        attention_head_dim = self.config.hidden_size // num_attention_heads

        qkv = qkv.reshape(
            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
        )
        q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
        q = q.reshape(-1, hidden_size)
        k = k.reshape(-1, hidden_size)
        v = v.reshape(-1, hidden_size)

        return torch.concat((q, k, v))
        # return qkv.reshape((num_kv_heads, num_key_value_groups+2, attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        cla_factor = _get_cla_factor(self.config)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        num_attention_heads = self.config.num_attention_heads
        num_kv_heads = getattr(
            self.config, "num_key_value_heads", self.config.num_attention_heads
        )
        split_params_mapping = [
            (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
            (
                ".qkv_proj",
                ".qkv_proj",
                num_attention_heads + num_kv_heads * 2,
                [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
                self._split_qkv_weight,
            ),
        ]
        if _is_moe(self.config):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.num_experts,
            )
        else:
            expert_params_mapping = {}

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "gate_proj_bias" in name:
                name = name.replace("gate_proj_bias", "gate_proj.bias")
            if "up_proj_bias" in name:
                name = name.replace("up_proj_bias", "up_proj.bias")
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            is_found = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                # Cross layers only have q_proj, so skip the qkv packing for them.
                if weight_name == ".q_proj":
                    match = re.search(r"layers\.\d+", name)
                    if match:
                        layer_id = int(match.group(0).split(".")[-1])
                        if cla_factor > 1 and layer_id % cla_factor != 0:
                            continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                is_found = True
                break
            if is_found:
                continue

            for param_name, weight_name, den, split_param, func in split_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                assert loaded_weight.shape[0] % den == 0
                units = loaded_weight.shape[0] // den
                param = params_dict[name]
                weight_loader = param.weight_loader
                offset = 0
                for shard_id, num in split_param:
                    new_offset = offset + num * units
                    if func:
                        weight_loader(
                            param, func(loaded_weight)[offset:new_offset], shard_id
                        )
                    else:
                        weight_loader(
                            param, loaded_weight[offset:new_offset], shard_id
                        )
                    offset = new_offset
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if "mlp.gate.wg." in name:
                        name = name.replace("wg.", "")
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.model.layers[layer_idx], nn.Identity):
                layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )


EntryClass = HunYuanMoEV1ForCausalLM
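
As a quick sanity check of the registered entry class, one might run a HunYuan MoE checkpoint through sglang's offline Engine API. This is a minimal sketch, assuming that API; the checkpoint path, tp_size, and sampling values are illustrative placeholders, not values taken from this commit.

# Offline generation sketch (placeholders only; adjust to your checkpoint and GPUs).
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="path/to/hunyuan-moe-checkpoint",  # placeholder HF repo id or local path
        trust_remote_code=True,
        tp_size=4,
    )
    prompts = ["Hello, my name is"]
    outputs = llm.generate(prompts, {"temperature": 0.7, "max_new_tokens": 32})
    for prompt, output in zip(prompts, outputs):
        print(prompt, output["text"])
    llm.shutdown()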