"examples/offline_inference/spec_decode.py" did not exist on "70788bdbdc590e7fbf9bddb3fa9bc92ac3181733"
Unverified Commit cf3eacfe authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
......@@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -259,7 +259,6 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
......
......@@ -243,7 +243,6 @@ class Llama4Attention(nn.Module):
self.rotary_emb = (
get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module):
self.rotary_emb = get_rope(
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
if (
rope_parameters is not None
and "partial_rotary_factor" not in rope_parameters
):
rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
num_heads: int,
head_dim: int,
num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32,
rope_parameters: dict | None = None,
sliding_window: int | None = None,
......@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
max_position_embeddings = min(
config.max_position_embeddings, config.max_model_len
......@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim")
else head_dim,
num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
......
......@@ -206,7 +206,6 @@ class MixtralAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module):
rope_parameters = {
"rope_type": "mllama4",
"rope_theta": config.rope_parameters["rope_theta"],
"partial_rotary_factor": 0.5,
}
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches
max_position=(config.image_size // config.patch_size) ** 2,
rope_parameters=rope_parameters,
......
......@@ -105,7 +105,6 @@ class ModernBertAttention(nn.Module):
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
rope_parameters=rope_parameters,
dtype=torch.float16,
......
......@@ -433,7 +433,6 @@ class MolmoAttention(nn.Module):
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -199,7 +199,6 @@ class NemotronAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -118,7 +118,6 @@ class DeciLMAttention(LlamaAttention):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -102,7 +102,6 @@ class OlmoAttention(nn.Module):
# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
......@@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module):
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -352,7 +352,6 @@ class OpenPanguMLAAttention(nn.Module):
}
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
......@@ -525,7 +524,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -135,7 +135,6 @@ class OrionAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment