Unverified Commit c530e2cf authored by 44670's avatar 44670 Committed by GitHub
Browse files

[FIX] Fix a bug in initializing Yarn RoPE (#2983)

parent fd5dcc5c
...@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int, ...@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,
def _yarn_linear_ramp_mask(low: float, high: float, dim: int, def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype, dtype: torch.dtype) -> torch.Tensor:
device: torch.device) -> torch.Tensor:
if low == high: if low == high:
high += 0.001 # Prevent singularity high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype, device=device) - linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment