"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "413a2e4f915d390354d862cb5c67460329eb78e7"
Unverified commit 42960214, authored by Asher, committed by GitHub

[Hunyuan]: Fix Dense Model Support (#8117)


Signed-off-by: Asher Zhang <asherszhang@tencent.com>
parent 01857fab
@@ -206,6 +206,42 @@ class HunYuanSparseMoeBlock(nn.Module):
         return final_hidden_states.view(orig_shape)
+
+
+def get_head_dim(config):
+    if hasattr(config, "head_dim"):
+        return int(config.head_dim)
+    if hasattr(config, "attention_head_dim"):
+        return int(config.attention_head_dim)
+    # Some Hunyuan models do not follow the hidden_size // num_attention_heads
+    # rule, and a wrong value causes runtime errors, so raise if the field is
+    # missing rather than guessing.
+    raise ValueError("Missing head dim config; try setting head_dim in config.json")
+
+
+def check_head_dim(config):
+    # Some models lack `head_dim` and use `attention_head_dim` instead. That
+    # attribute is also read by flashinfer_backend.py, so check the two for
+    # consistency and raise on a mismatch to avoid silent failures. Although
+    # the HunYuan model could be adapted to use `attention_head_dim`,
+    # flashinfer expects `head_dim`, so its presence is enforced for
+    # correctness.
+    calc_head_dim = config.hidden_size // config.num_attention_heads
+    if hasattr(config, "attention_head_dim"):
+        if calc_head_dim != config.attention_head_dim and not hasattr(
+            config, "head_dim"
+        ):
+            # In this case flashinfer (and other components) may derive a
+            # wrong value.
+            raise ValueError(
+                f"HunYuan model config error: calculated head_dim "
+                f"{calc_head_dim} != attention_head_dim "
+                f"{config.attention_head_dim}.\nPlease add head_dim: "
+                f"{config.attention_head_dim} to config.json to ensure "
+                f"correct inference."
+            )
+        if hasattr(config, "head_dim") and config.head_dim != config.attention_head_dim:
+            raise ValueError(
+                f"HunYuan model config error: head_dim ({config.head_dim}) != "
+                f"attention_head_dim ({config.attention_head_dim}).\nPlease make "
+                f"the two consistent in config.json to ensure correct inference."
+            )
+
+
 class HunYuanAttention(nn.Module):
     def __init__(
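
To make the new helpers concrete, here is a minimal sketch of how they behave on a few toy configs. This is illustrative only, not part of the commit: SimpleNamespace stands in for the Hugging Face config object, and get_head_dim / check_head_dim are the functions added above.

from types import SimpleNamespace

# Case 1: explicit head_dim -- returned directly.
cfg = SimpleNamespace(head_dim=128, hidden_size=4096, num_attention_heads=32)
assert get_head_dim(cfg) == 128

# Case 2: only attention_head_dim, and it matches
# hidden_size // num_attention_heads -- accepted as the fallback.
cfg = SimpleNamespace(attention_head_dim=128, hidden_size=4096, num_attention_heads=32)
assert get_head_dim(cfg) == 128
check_head_dim(cfg)  # passes: 4096 // 32 == 128

# Case 3: attention_head_dim disagrees with the quotient and head_dim is
# absent -- check_head_dim raises instead of letting downstream components
# derive a mis-sized value silently.
cfg = SimpleNamespace(attention_head_dim=96, hidden_size=4096, num_attention_heads=32)
try:
    check_head_dim(cfg)
except ValueError as err:
    print(err)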
@@ -240,9 +276,11 @@ class HunYuanAttention(nn.Module):
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(
-            config, "head_dim", self.hidden_size // self.total_num_heads
-        )
+        # Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models.
+        self.head_dim = get_head_dim(config)
+        check_head_dim(config)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
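
For intuition, here is a worked example of the per-rank sizing arithmetic this hunk feeds into, with invented numbers (4096 hidden size, 32 query heads, 8 KV heads, head_dim 128, tensor-parallel size 4; none of these come from a real Hunyuan checkpoint):

hidden_size, total_num_heads, total_num_kv_heads = 4096, 32, 8
head_dim, tp_size = 128, 4

num_heads = total_num_heads // tp_size                # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)  # 2 KV heads per rank
q_size = num_heads * head_dim                         # 1024-wide query slice
kv_size = num_kv_heads * head_dim                     # 256-wide key/value slice
scaling = head_dim**-0.5                              # attention scale 1/sqrt(128)
print(num_heads, num_kv_heads, q_size, kv_size, round(scaling, 4))

Note that with an explicit head_dim of 96 instead, q_size and kv_size would shrink accordingly, which is exactly why guessing hidden_size // total_num_heads is unsafe for these models.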
@@ -493,7 +531,6 @@ class HunYuanModel(nn.Module):
         hidden_states = self.get_input_embeddings(input_ids)
         residual = None
-        cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -560,6 +597,11 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
 
+        self.hidden_size = config.hidden_size
+        self.head_dim = get_head_dim(config)
+        check_head_dim(config)
+
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
         self.sampler = Sampler()
@@ -582,16 +624,14 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
             self.config, "num_key_value_heads", self.config.num_attention_heads
         )
         num_key_value_groups = num_attention_heads // num_kv_heads
-        hidden_size = self.config.hidden_size
-        attention_head_dim = self.config.hidden_size // num_attention_heads
         qkv = qkv.reshape(
-            num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
+            num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size
         )
         q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
-        q = q.reshape(-1, hidden_size)
-        k = k.reshape(-1, hidden_size)
-        v = v.reshape(-1, hidden_size)
+        q = q.reshape(-1, self.hidden_size)
+        k = k.reshape(-1, self.hidden_size)
+        v = v.reshape(-1, self.hidden_size)
         return torch.concat((q, k, v))
         # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)),
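
The point of this hunk is that the QKV split must use the model's true head_dim, not the locally recomputed hidden_size // num_attention_heads, or the fused weight is sliced at the wrong offsets. The reshape is easiest to see on a toy tensor; a sketch with invented sizes (2 KV heads, 2 query heads per KV head, head_dim 3, hidden size 5), following the same steps as the method:

import torch

num_kv_heads, num_key_value_groups, head_dim, hidden_size = 2, 2, 3, 5
rows = num_kv_heads * (num_key_value_groups + 2) * head_dim  # fused QKV rows
qkv = torch.arange(rows * hidden_size, dtype=torch.float32).reshape(rows, hidden_size)

# Group rows per KV head, peel off the query groups and the single K and V
# blocks, then flatten each back to (-1, hidden_size) -- the same layout
# change as _split_qkv above.
grouped = qkv.reshape(num_kv_heads, num_key_value_groups + 2, head_dim, hidden_size)
q, k, v = torch.split(grouped, (num_key_value_groups, 1, 1), dim=1)
q = q.reshape(-1, hidden_size)
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
out = torch.concat((q, k, v))
print(q.shape, k.shape, v.shape, out.shape)
# torch.Size([12, 5]) torch.Size([6, 5]) torch.Size([6, 5]) torch.Size([24, 5])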
@@ -768,4 +808,8 @@ class HunYuanMoEV1ForCausalLM(nn.Module):
         )
 
 
-EntryClass = HunYuanMoEV1ForCausalLM
+class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM):
+    pass
+
+
+EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM]
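
An empty subclass is enough here because the dense variant inherits the full implementation, whose layer construction is already driven by the config; the subclass only contributes the architecture name that dense checkpoints declare. A hypothetical sketch of how an EntryClass list like this is typically consumed by a loader (the registry dict is invented for illustration; it assumes the module above is in scope):

# Map each exported class to its architecture name, then resolve the name a
# checkpoint config declares.
registry = {cls.__name__: cls for cls in EntryClass}
model_cls = registry["HunYuanDenseV1ForCausalLM"]
assert issubclass(model_cls, HunYuanMoEV1ForCausalLM)
assert model_cls.forward is HunYuanMoEV1ForCausalLM.forward  # nothing overridden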