Unverified Commit cf3eacfe authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
......@@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
......
......@@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module):
set_default_rope_theta(config, default_theta=1000000)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -167,7 +167,6 @@ class FalconAttention(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -242,14 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
rotary_dim = self.head_dim # default
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -174,7 +174,6 @@ class GemmaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -152,7 +152,6 @@ class Gemma2Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -176,7 +176,6 @@ class Gemma3Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -384,7 +384,6 @@ class Gemma3nAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -81,7 +81,6 @@ class Glm4Attention(nn.Module):
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.rotary_dim = self.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
......@@ -103,7 +102,6 @@ class Glm4Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......
......@@ -678,9 +678,9 @@ class Glm4vVisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.blocks = nn.ModuleList(
[
......
......@@ -285,7 +285,6 @@ class Glm4MoeAttention(nn.Module):
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -95,12 +95,13 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
rope_parameters = getattr(config, "rope_parameters", {})
rope_parameters["partial_rotary_factor"] = config.rotary_dim / self.head_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=config.rotary_dim,
max_position=max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None),
rope_parameters=rope_parameters,
is_neox_style=False,
)
self.attn = Attention(
......
......@@ -92,7 +92,6 @@ class GPTNeoXAttention(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_size,
rotary_dim=self.head_size,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -67,7 +67,6 @@ class OAIAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
dtype=torch.float32,
rope_parameters={
......
......@@ -160,7 +160,6 @@ class GraniteAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -190,7 +190,6 @@ class GraniteMoeAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -271,7 +271,6 @@ class GraniteMoeHybridAttention(nn.Module):
if config.position_embedding_type == "rope":
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -181,7 +181,6 @@ class Grok1Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -199,7 +199,6 @@ class HunYuanAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......@@ -305,7 +304,6 @@ class HunYuanCrossAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment