Unverified Commit cf3eacfe authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
...@@ -166,7 +166,6 @@ class OuroAttention(nn.Module): ...@@ -166,7 +166,6 @@ class OuroAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,
......
...@@ -134,7 +134,6 @@ class PersimmonAttention(nn.Module): ...@@ -134,7 +134,6 @@ class PersimmonAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -84,19 +84,18 @@ class PhiAttention(nn.Module): ...@@ -84,19 +84,18 @@ class PhiAttention(nn.Module):
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert config.num_attention_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = config.num_attention_heads // tensor_model_parallel_world_size
# pylint: disable=C0103 # pylint: disable=C0103
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.head_size, self.head_size,
self.total_num_heads, config.num_attention_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
...@@ -109,13 +108,10 @@ class PhiAttention(nn.Module): ...@@ -109,13 +108,10 @@ class PhiAttention(nn.Module):
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = config.hidden_size // config.num_attention_heads
assert rotary_dim % 2 == 0
max_position_embeddings = getattr(config, "max_position_embeddings", 2048) max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_size, self.head_size,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -352,7 +352,6 @@ class PhiMoEAttention(nn.Module): ...@@ -352,7 +352,6 @@ class PhiMoEAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -574,7 +574,6 @@ class Plamo2AttentionMixer(nn.Module): ...@@ -574,7 +574,6 @@ class Plamo2AttentionMixer(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -179,7 +179,6 @@ class Plamo3AttentionMixer(nn.Module): ...@@ -179,7 +179,6 @@ class Plamo3AttentionMixer(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -114,7 +114,6 @@ class QWenAttention(nn.Module): ...@@ -114,7 +114,6 @@ class QWenAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -164,7 +164,6 @@ class Qwen2Attention(nn.Module): ...@@ -164,7 +164,6 @@ class Qwen2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,
......
...@@ -624,9 +624,9 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -624,9 +624,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
......
...@@ -244,7 +244,6 @@ class Qwen2MoeAttention(nn.Module): ...@@ -244,7 +244,6 @@ class Qwen2MoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,
......
...@@ -621,9 +621,9 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -621,9 +621,9 @@ class Qwen2VisionTransformer(nn.Module):
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
......
...@@ -111,7 +111,6 @@ class Qwen3Attention(nn.Module): ...@@ -111,7 +111,6 @@ class Qwen3Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,
......
...@@ -269,7 +269,6 @@ class Qwen3MoeAttention(nn.Module): ...@@ -269,7 +269,6 @@ class Qwen3MoeAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config, dual_chunk_attention_config=dual_chunk_attention_config,
......
...@@ -747,7 +747,6 @@ class Qwen3NextAttention(nn.Module): ...@@ -747,7 +747,6 @@ class Qwen3NextAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
dual_chunk_attention_config=self.dual_chunk_attention_config, dual_chunk_attention_config=self.dual_chunk_attention_config,
......
...@@ -333,9 +333,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -333,9 +333,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
......
...@@ -340,9 +340,9 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -340,9 +340,9 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope( self.rotary_pos_emb = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192, max_position=8192,
is_neox_style=True, is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
) )
self.merger = Qwen3_VisionPatchMerger( self.merger = Qwen3_VisionPatchMerger(
......
...@@ -161,7 +161,6 @@ class SeedOssAttention(nn.Module): ...@@ -161,7 +161,6 @@ class SeedOssAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -160,7 +160,6 @@ class SolarAttention(nn.Module): ...@@ -160,7 +160,6 @@ class SolarAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
...@@ -148,7 +148,6 @@ class StablelmAttention(nn.Module): ...@@ -148,7 +148,6 @@ class StablelmAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.config.max_position_embeddings, max_position=self.config.max_position_embeddings,
rope_parameters=self.config.rope_parameters, rope_parameters=self.config.rope_parameters,
) )
......
...@@ -112,7 +112,6 @@ class Starcoder2Attention(nn.Module): ...@@ -112,7 +112,6 @@ class Starcoder2Attention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment