Unverified Commit d27f4bae authored by Roy's avatar Roy Committed by GitHub
Browse files

Fix rope cache key error (#1867)

parent 8d8c2f6f
...@@ -284,9 +284,10 @@ def get_rope( ...@@ -284,9 +284,10 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
key = (head_size, rotary_dim, max_position, base, is_neox_style, key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling) tuple(rope_scaling.items()) if rope_scaling is not None else None)
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style) is_neox_style)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment