Unverified commit 0540fef7, authored by JieXin Liang, committed by GitHub
Browse files

[Fix] fix _yarn_linear_ramp_mask with device parameter (#4337)

parent 481f608b
...@@ -403,12 +403,12 @@ def _yarn_find_correction_range( ...@@ -403,12 +403,12 @@ def _yarn_find_correction_range(
def _yarn_linear_ramp_mask( def _yarn_linear_ramp_mask(
low: float, high: float, dim: int, dtype: torch.dtype low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
) -> torch.Tensor: ) -> torch.Tensor:
if low == high: if low == high:
high += 0.001 # Prevent singularity high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
...@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
# Get n-d rotational scaling corrected for extrapolation # Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = ( inv_freq_mask = (
1 1
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
)
) * self.extrapolation_factor ) * self.extrapolation_factor
inv_freq = ( inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask) inv_freq_interpolation * (1 - inv_freq_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment