Unverified Commit cf3eacfe authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
...@@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module): ...@@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module): ...@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module): ...@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module): ...@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None), rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module): ...@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
self.rotary_emb = ( self.rotary_emb = (
get_rope( get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module): ...@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module): ...@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.qk_rope_head_dim, self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module): ...@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
if (
rope_parameters is not None
and "partial_rotary_factor" not in rope_parameters
):
rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module): ...@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
num_heads: int, num_heads: int,
head_dim: int, head_dim: int,
num_kv_heads: int, num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_parameters: dict | None = None, rope_parameters: dict | None = None,
sliding_window: int | None = None, sliding_window: int | None = None,
...@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module): ...@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,
...@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
head_dim = getattr(config, "head_dim", None) head_dim = getattr(config, "head_dim", None)
if head_dim is None: if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = min( max_position_embeddings = min(
config.max_position_embeddings, config.max_model_len config.max_position_embeddings, config.max_model_len
...@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
head_dim=head_dim, head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim")
else head_dim,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
......
...@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module): ...@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module): ...@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
rope_parameters = { rope_parameters = {
"rope_type": "mllama4", "rope_type": "mllama4",
"rope_theta": config.rope_parameters["rope_theta"], "rope_theta": config.rope_parameters["rope_theta"],
"partial_rotary_factor": 0.5,
} }
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches # number of image patches
max_position=(config.image_size // config.patch_size) ** 2, max_position=(config.image_size // config.patch_size) ** 2,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
......
...@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module): ...@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dtype=torch.float16, dtype=torch.float16,
......
...@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module): ...@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module): ...@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention): ...@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module): ...@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
# Rotary embeddings. # Rotary embeddings.
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module): ...@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module):
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module): ...@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -352,7 +352,6 @@ class OpenPanguMLAAttention(nn.Module): ...@@ -352,7 +352,6 @@ class OpenPanguMLAAttention(nn.Module):
} }
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=False, is_neox_style=False,
...@@ -525,7 +524,6 @@ class OpenPanguEmbeddedAttention(nn.Module): ...@@ -525,7 +524,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -135,7 +135,6 @@ class OrionAttention(nn.Module): ...@@ -135,7 +135,6 @@ class OrionAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment