Unverified Commit 63b2206a authored by Jee Li's avatar Jee Li Committed by GitHub
Browse files

Avoid multiple instantiations of the RoPE class (#1828)

parent 27feead2
...@@ -272,6 +272,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -272,6 +272,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return cache return cache
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope( def get_rope(
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
...@@ -280,6 +283,10 @@ def get_rope( ...@@ -280,6 +283,10 @@ def get_rope(
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style) is_neox_style)
...@@ -312,4 +319,5 @@ def get_rope( ...@@ -312,4 +319,5 @@ def get_rope(
**extra_kwargs) **extra_kwargs)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb return rotary_emb
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment