Unverified Commit a7f341c3 authored by haosdent's avatar haosdent Committed by GitHub
Browse files

[Bugfix] Fix MRotaryEmbedding missing `truncate` attr with YaRN scaling (#35080)


Signed-off-by: default avatarhaosdent <haosdent@gmail.com>
parent d13ece38
...@@ -218,12 +218,14 @@ class MRotaryEmbedding(RotaryEmbeddingBase): ...@@ -218,12 +218,14 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
attn_factor: float = 1, attn_factor: float = 1,
beta_fast: int = 32, beta_fast: int = 32,
beta_slow: int = 1, beta_slow: int = 1,
truncate: bool = True,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.truncate = truncate
if self.scaling_factor is not None: if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment