Unverified commit 87499420, authored by rui-ren, committed by GitHub

fix RoPE t range issue for fp16 (#26602)

parent ea52ed9d
@@ -108,7 +108,7 @@ class FalconRotaryEmbedding(nn.Module):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device).to(dtype)
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(device)
@@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device).to(dtype)
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
         # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
         t = t / self.scaling_factor
@@ -208,7 +208,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
             inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
             self.register_buffer("inv_freq", inv_freq, persistent=False)
-        t = torch.arange(seq_len, device=device).to(dtype)
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(device)
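For context (not part of the commit): a minimal sketch of why the dtype of the position index matters. In float16 the spacing between representable values reaches 2 at 2048, so integer positions above 2048 cannot all be represented exactly; building t in the runtime dtype therefore collapses neighbouring positions, while self.inv_freq.dtype (float32, since inv_freq is built with .float()) keeps every index exact. The seq_len and head_dim values below are illustrative choices, not taken from the commit.

import torch

seq_len = 4096  # illustrative length beyond the fp16-exact integer range
head_dim = 64   # illustrative head dimension

# Old behaviour: position indices cast to the runtime dtype (fp16 here).
# float16 has a 10-bit significand, so integers above 2048 round to the
# nearest representable value (e.g. 2049 -> 2048) and positions collapse.
t_old = torch.arange(seq_len).to(torch.float16)
print(t_old[2047:2051])  # tensor([2047., 2048., 2048., 2050.], dtype=torch.float16)

# New behaviour: keep t in inv_freq's dtype (float32), which represents
# every integer index up to 2**24 exactly.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t_new = torch.arange(seq_len, dtype=inv_freq.dtype)
print(t_new[2047:2051])  # tensor([2047., 2048., 2049., 2050.])

# Downstream, the rotary angles are formed as in the diff above.
freqs = torch.einsum("i,j->ij", t_new, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)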