Unverified Commit 2390a2bc authored by Meng, Peng's avatar Meng, Peng Committed by GitHub
Browse files

Add Tencent HunYuanMoEV1 model support (#7549)

parent 16d76b9f
...@@ -890,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): ...@@ -890,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
return query_out.type_as(query), key_out.type_as(key) return query_out.type_as(query), key_out.type_as(key)
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with Dynamic NTK scaling driven by an ``alpha`` value.

    Rather than rescaling positions, the rotary base is enlarged to
    ``base * alpha ** (rotary_dim / (rotary_dim - 2))`` before computing the
    inverse frequencies.
    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Must be assigned before super().__init__(), which triggers
        # _compute_cos_sin_cache() and therefore reads self.scaling_alpha.
        self.scaling_alpha = scaling_alpha
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin lookup table using the NTK-scaled rotary base.

        Returns a tensor of shape ``(max_position_embeddings, rotary_dim)``
        holding ``cos`` values concatenated with ``sin`` values along the
        last dimension.
        """
        ntk_base = self.base * self.scaling_alpha ** (
            self.rotary_dim / (self.rotary_dim - 2)
        )
        inv_freq = self._compute_inv_freq(ntk_base)
        positions = torch.arange(self.max_position_embeddings, dtype=torch.float)
        # outer(i, j) -> angles[i, j] == positions[i] * inv_freq[j],
        # equivalent to einsum("i,j -> ij", ...).
        angles = torch.outer(positions, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
class MRotaryEmbedding(RotaryEmbedding): class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections.""" """Rotary Embedding with Multimodal Sections."""
...@@ -1234,15 +1271,26 @@ def get_rope( ...@@ -1234,15 +1271,26 @@ def get_rope(
) )
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding( if "alpha" in rope_scaling:
head_size, rotary_emb = DynamicNTKAlphaRotaryEmbedding(
rotary_dim, head_size,
max_position, rotary_dim,
base, max_position,
is_neox_style, base,
scaling_factor, is_neox_style,
dtype, rope_scaling["alpha"],
) dtype,
)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "yarn": elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"] original_max_position = rope_scaling["original_max_position_embeddings"]
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment