Unverified Commit cf3eacfe authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Standardise `get_rope` to use `rope_parameters["partial_rotary_factor"]`, not `rotary_dim` (#30389)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 92fea56f
...@@ -99,7 +99,6 @@ def benchmark_mrope( ...@@ -99,7 +99,6 @@ def benchmark_mrope(
# the parameters to compute the q k v size based on tp_size # the parameters to compute the q k v size based on tp_size
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
......
...@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device): ...@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
def benchmark(batch_size, seq_len, num_heads, provider): def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
max_position = 8192 max_position = 8192
base = 10000 rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=device) rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
......
...@@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module): ...@@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module): class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000): def __init__(self, head_dim=64, max_position=2048, base=10000):
super().__init__() super().__init__()
self.head_dim = head_dim self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position, max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base}, rope_parameters={"rope_type": "default", "rope_theta": base},
) )
...@@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): ...@@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base}, rope_parameters={"rope_type": "default", "rope_theta": base},
) )
......
...@@ -116,7 +116,6 @@ def test_mrope( ...@@ -116,7 +116,6 @@ def test_mrope(
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
...@@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing( ...@@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
mrope_helper_class = get_rope( mrope_helper_class = get_rope(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position, max_position=max_position,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
......
...@@ -83,8 +83,12 @@ def test_rotary_embedding( ...@@ -83,8 +83,12 @@ def test_rotary_embedding(
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} rope_parameters = {
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters) "rope_type": "default",
"rope_theta": rope_theta,
"partial_rotary_factor": rotary_dim / head_size,
}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device()) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
...@@ -150,9 +154,9 @@ def test_rope_module_cache(): ...@@ -150,9 +154,9 @@ def test_rope_module_cache():
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim,
max_position, max_position,
is_neox_style, is_neox_style,
rope_parameters, rope_parameters,
...@@ -177,9 +181,9 @@ def test_rope_module_cache(): ...@@ -177,9 +181,9 @@ def test_rope_module_cache():
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope( rope = get_rope(
head_size, head_size,
rotary_dim,
max_position, max_position,
is_neox_style, is_neox_style,
rope_parameters, rope_parameters,
......
...@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
) )
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
""" """
A helper function that retrieves an attribute from an object which may A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances. arbitrary `transformers.PretrainedConfig` instances.
In the case where the first name in `names` is the preferred name, and
any other names are deprecated aliases, setting `warn=True` will log a
warning when a deprecated name is used.
""" """
for name in names: for i, name in enumerate(names):
if hasattr(object, name): if hasattr(object, name):
if warn and i > 0:
logger.warning_once(
"%s contains a deprecated attribute name '%s'. "
"Please use the preferred attribute name '%s' instead.",
type(object).__name__,
name,
names[0],
)
return getattr(object, name) return getattr(object, name)
return default return default
......
...@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {} ...@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope( def get_rope(
head_size: int, head_size: int,
rotary_dim: int,
max_position: int, max_position: int,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None, rope_parameters: dict[str, Any] | None = None,
...@@ -54,12 +53,15 @@ def get_rope( ...@@ -54,12 +53,15 @@ def get_rope(
else: else:
dual_chunk_attention_args = None dual_chunk_attention_args = None
partial_rotary_factor = 1.0 rope_parameters = rope_parameters or {}
if rope_parameters is not None: base = rope_parameters.get("rope_theta", 10000)
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = ( key = (
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -72,7 +74,6 @@ def get_rope( ...@@ -72,7 +74,6 @@ def get_rope(
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None: if dual_chunk_attention_config is not None:
extra_kwargs = { extra_kwargs = {
k: v k: v
...@@ -88,208 +89,201 @@ def get_rope( ...@@ -88,208 +89,201 @@ def get_rope(
dtype, dtype,
**extra_kwargs, **extra_kwargs,
) )
elif not rope_parameters: elif scaling_type == "default":
rotary_emb = RotaryEmbedding( if "mrope_section" in rope_parameters:
head_size, rotary_dim, max_position, base, is_neox_style, dtype rotary_emb = MRotaryEmbedding(
)
else:
scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
scaling_factor, mrope_section=rope_parameters["mrope_section"],
low_freq_factor, mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
high_freq_factor,
original_max_position,
) )
elif scaling_type == "mllama4": else:
rotary_emb = Llama4VisionRotaryEmbedding( rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, base,
is_neox_style, is_neox_style,
scaling_factor,
dtype, dtype,
) )
elif scaling_type == "ntk": elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b") low_freq_factor = rope_parameters["low_freq_factor"]
rotary_emb = NTKScalingRotaryEmbedding( high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, base,
is_neox_style, is_neox_style,
scaling_factor, scaling_alpha,
dtype, dtype,
mixed_b,
) )
elif scaling_type == "dynamic": elif "factor" in rope_parameters:
if "alpha" in rope_parameters: scaling_factor = rope_parameters["factor"]
scaling_alpha = rope_parameters["alpha"] rotary_emb = DynamicNTKScalingRotaryEmbedding(
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
base, base,
is_neox_style, is_neox_style,
scaling_alpha, scaling_factor,
dtype, dtype,
xdrope_section=rope_parameters["xdrope_section"],
) )
elif scaling_type == "yarn": else:
scaling_factor = rope_parameters["factor"] raise ValueError(
original_max_position = rope_parameters["original_max_position_embeddings"] "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
extra_kwargs = { )
k: v elif scaling_type == "xdrope":
for k, v in rope_parameters.items() scaling_alpha = rope_parameters["alpha"]
if k rotary_emb = XDRotaryEmbedding(
in ( head_size,
"extrapolation_factor", rotary_dim,
"attn_factor", max_position,
"beta_fast", base,
"beta_slow", is_neox_style,
"apply_yarn_scaling", scaling_alpha,
"truncate", dtype,
) xdrope_section=rope_parameters["xdrope_section"],
} )
if "mrope_section" in rope_parameters: elif scaling_type == "yarn":
extra_kwargs.pop("apply_yarn_scaling", None) scaling_factor = rope_parameters["factor"]
rotary_emb = MRotaryEmbedding( original_max_position = rope_parameters["original_max_position_embeddings"]
head_size, extra_kwargs = {
rotary_dim, k: v
original_max_position, for k, v in rope_parameters.items()
base, if k
is_neox_style, in (
dtype, "extrapolation_factor",
mrope_section=rope_parameters["mrope_section"], "attn_factor",
mrope_interleaved=rope_parameters.get("mrope_interleaved", False), "beta_fast",
scaling_factor=scaling_factor, "beta_slow",
**extra_kwargs, "apply_yarn_scaling",
) "truncate",
else: )
rotary_emb = YaRNScalingRotaryEmbedding( }
head_size, if "mrope_section" in rope_parameters:
rotary_dim, extra_kwargs.pop("apply_yarn_scaling", None)
original_max_position, rotary_emb = MRotaryEmbedding(
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
original_max_position, original_max_position,
base, base,
is_neox_style, is_neox_style,
scaling_factor,
dtype, dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "longrope": else:
short_factor = rope_parameters["short_factor"] rotary_emb = YaRNScalingRotaryEmbedding(
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position,
original_max_position, original_max_position,
base, base,
is_neox_style, is_neox_style,
scaling_factor,
dtype, dtype,
short_factor,
long_factor,
**extra_kwargs, **extra_kwargs,
) )
else: elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
return rotary_emb return rotary_emb
...@@ -241,7 +241,6 @@ class AfmoeAttention(nn.Module): ...@@ -241,7 +241,6 @@ class AfmoeAttention(nn.Module):
if self.is_local_attention: if self.is_local_attention:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config["rope_parameters"], rope_parameters=config["rope_parameters"],
is_neox_style=True, is_neox_style=True,
......
...@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module): ...@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module): ...@@ -314,7 +314,6 @@ class ArcticAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module): ...@@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module):
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -127,11 +127,11 @@ class BailingAttention(nn.Module): ...@@ -127,11 +127,11 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) rotary_dim = getattr(config, "rotary_dim", self.head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module): ...@@ -178,14 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
if hasattr(config, "attn_rotary_emb"): rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim)
rotary_dim = config.attn_rotary_emb # for backward compatibility config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim
else:
rotary_dim = self.head_dim # default
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
head_size=self.head_dim, head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module): ...@@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module):
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
) )
......
...@@ -99,13 +99,16 @@ class GLMAttention(nn.Module): ...@@ -99,13 +99,16 @@ class GLMAttention(nn.Module):
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0) rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192) max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio} rope_parameters = {
"rope_type": "default",
"rope_theta": 10000 * rope_ratio,
"partial_rotary_factor": 0.5,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False, # NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True # which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope is_neox_style = not config.original_rope
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions, max_position=max_positions,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
......
...@@ -175,7 +175,6 @@ class CohereAttention(nn.Module): ...@@ -175,7 +175,6 @@ class CohereAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,
......
...@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig): ...@@ -42,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu" config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings, "max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
...@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig): ...@@ -77,9 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
if not model_config.enforce_eager: if not model_config.enforce_eager:
max_position = round_up(max_position, 8) max_position = round_up(max_position, 8)
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": max_position, "max_position": max_position,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
...@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): ...@@ -113,12 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config.num_hidden_layers = config.n_layer config.num_hidden_layers = config.n_layer
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
max_trained_positions = getattr(config, "max_trained_positions", 2048) max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions, "max_position": max_trained_positions,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
...@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): ...@@ -240,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
config.hidden_act = "geglu" config.hidden_act = "geglu"
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
config.rotary_kwargs = { config.rotary_kwargs = {
"head_size": head_dim, "head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings, "max_position": config.max_position_embeddings,
"rope_parameters": config.rope_parameters, "rope_parameters": config.rope_parameters,
} }
......
...@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module): ...@@ -222,7 +222,6 @@ class DbrxAttention(nn.Module):
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position, max_position=self.max_position,
rope_parameters=rope_parameters, rope_parameters=rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -156,7 +156,6 @@ class DeepseekAttention(nn.Module): ...@@ -156,7 +156,6 @@ class DeepseekAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
...@@ -499,7 +498,6 @@ class DeepseekV2Attention(nn.Module): ...@@ -499,7 +498,6 @@ class DeepseekV2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,
...@@ -1018,7 +1016,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -1018,7 +1016,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=False, is_neox_style=False,
...@@ -1038,7 +1035,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -1038,7 +1035,6 @@ class DeepseekV2MLAAttention(nn.Module):
if self.is_v32: if self.is_v32:
self.indexer_rope_emb = get_rope( self.indexer_rope_emb = get_rope(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
is_neox_style=True, is_neox_style=True,
......
...@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module): ...@@ -250,7 +250,6 @@ class Dots1Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment