Unverified Commit 1a03dd49 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix dynamic rotary embedding (#20343)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 27b80176
......@@ -1963,16 +1963,19 @@ def get_rope(
scaling_factor, dtype,
mixed_b)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
scaling_alpha = rope_scaling["alpha"]
if scaling_alpha:
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_alpha, dtype)
else:
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
else:
raise ValueError("Dynamic rope scaling must contain either "
"'alpha' or 'factor' field")
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment