Unverified commit cf3eacfe, authored by Harry Mellor and committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
......@@ -99,7 +99,6 @@ def benchmark_mrope(
# the parameters to compute the q k v size based on tp_size
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=rope_parameters,
......
......@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16
max_position = 8192
base = 10000
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
......
......@@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
def __init__(self, head_dim=64, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base},
)
......@@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base},
)
......
......@@ -116,7 +116,6 @@ def test_mrope(
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,
......@@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,
......
......@@ -83,8 +83,12 @@ def test_rotary_embedding(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
rope_parameters = {
"rope_type": "default",
"rope_theta": rope_theta,
"partial_rotary_factor": rotary_dim / head_size,
}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len))
......@@ -150,9 +154,9 @@ def test_rope_module_cache():
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,
......@@ -177,9 +181,9 @@ def test_rope_module_cache():
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,
......
......@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
)
def getattr_iter(
    object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
    """
    A helper function that retrieves an attribute from an object which may
    have multiple possible names. This is useful when fetching attributes from
    arbitrary `transformers.PretrainedConfig` instances.

    In the case where the first name in `names` is the preferred name, and
    any other names are deprecated aliases, setting `warn=True` will log a
    warning when a deprecated name is used.

    Args:
        object: The object to read the attribute from.
        names: Candidate attribute names, tried in iteration order. Any
            iterable is accepted (list, tuple, generator, ...); it is
            consumed at most once and never indexed.
        default: The value returned when none of `names` exists on `object`.
        warn: If `True`, log a warning (once) when the attribute is found
            under any name other than the first one yielded by `names`.

    Returns:
        The value of the first attribute in `names` present on `object`,
        otherwise `default`.
    """
    # Remember the first (preferred) name while iterating instead of indexing
    # `names[0]`: `names` is only promised to be an `Iterable`, so it may be a
    # generator or another non-subscriptable container.
    preferred_name = None
    for i, name in enumerate(names):
        if i == 0:
            preferred_name = name
        if hasattr(object, name):
            if warn and i > 0:
                logger.warning_once(
                    "%s contains a deprecated attribute name '%s'. "
                    "Please use the preferred attribute name '%s' instead.",
                    type(object).__name__,
                    name,
                    preferred_name,
                )
            return getattr(object, name)
    return default
......
......@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None,
......@@ -54,12 +53,15 @@ def get_rope(
else:
dual_chunk_attention_args = None
partial_rotary_factor = 1.0
if rope_parameters is not None:
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
rope_parameters = rope_parameters or {}
base = rope_parameters.get("rope_theta", 10000)
scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
......@@ -72,7 +74,6 @@ def get_rope(
if key in _ROPE_DICT:
return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
......@@ -88,208 +89,201 @@ def get_rope(
dtype,
**extra_kwargs,
)
elif not rope_parameters:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
elif scaling_type == "default":
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
scaling_alpha,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
scaling_factor,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
"truncate",
)
}
if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
else:
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
"truncate",
)
}
if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
else:
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
......@@ -241,7 +241,6 @@ class AfmoeAttention(nn.Module):
if self.is_local_attention:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config["rope_parameters"],
is_neox_style=True,
......
......@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense",
)
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
rotary_dim = getattr(config, "rotary_dim", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
rotary_dim = self.head_dim # default
rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
)
......
......@@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio}
rope_parameters = {
"rope_type": "default",
"rope_theta": 10000 * rope_ratio,
"partial_rotary_factor": 0.5,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
rope_parameters=rope_parameters,
is_neox_style=is_neox_style,
......
......@@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......
......@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters,
}
......@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if not model_config.enforce_eager:
max_position = round_up(max_position, 8)
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": max_position,
"rope_parameters": config.rope_parameters,
}
......@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"rope_parameters": config.rope_parameters,
}
......@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters,
}
......
......@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
rope_parameters=rope_parameters,
is_neox_style=True,
......
......@@ -156,7 +156,6 @@ class DeepseekAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......@@ -499,7 +498,6 @@ class DeepseekV2Attention(nn.Module):
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......@@ -1018,7 +1016,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=False,
......@@ -1038,7 +1035,6 @@ class DeepseekV2MLAAttention(nn.Module):
if self.is_v32:
self.indexer_rope_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
......
......@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment