Unverified Commit 083e3692 authored by Kevin Hu, committed by GitHub

Fix precision errors from casting rotary parameters to FP16 with AMP (#27700)

* Update modeling_llama.py

* Update modeling_open_llama.py

* Update modeling_gpt_neox.py

* Update modeling_mistral.py

* Update modeling_persimmon.py

* Update modeling_phi.py

* Update modeling_falcon.py

* Update modeling_gpt_neox_japanese.py
parent af8acc47
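
Background: under automatic mixed precision, torch.einsum can be executed in float16, so the rotary cos/sin caches end up built from a half-precision outer product, while torch.outer is not downcast by autocast and keeps the float32 computation. A minimal sketch of the dtype difference (the sizes, base value, and CUDA device are illustrative assumptions, not the models' actual configuration):

import torch

# Hypothetical rotary-style inputs, deliberately created in float32.
dim, seq_len = 64, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim))
t = torch.arange(seq_len, dtype=torch.float32, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    freqs_einsum = torch.einsum("i,j->ij", t, inv_freq)  # may run in float16 under autocast
    freqs_outer = torch.outer(t, inv_freq)               # not on the autocast list, stays float32

print(freqs_einsum.dtype)  # typically torch.float16 inside the autocast region
print(freqs_outer.dtype)   # torch.float32
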
modeling_open_llama.py
@@ -83,7 +83,7 @@ class OpenLlamaRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -113,7 +113,7 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -140,7 +140,7 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_falcon.py
@@ -141,7 +141,7 @@ class FalconRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -198,7 +198,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_gpt_neox.py
@@ -302,7 +302,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -332,7 +332,7 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -359,7 +359,7 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_gpt_neox_japanese.py
@@ -253,7 +253,7 @@ class RotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_llama.py
@@ -132,7 +132,7 @@ class LlamaRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -161,7 +161,7 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -187,7 +187,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_mistral.py
@@ -106,7 +106,7 @@ class MistralRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_persimmon.py
@@ -59,7 +59,7 @@ class PersimmonRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -89,7 +89,7 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -116,7 +116,7 @@ class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
modeling_phi.py
@@ -75,7 +75,7 @@ class PhiRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -105,7 +105,7 @@ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -132,7 +132,7 @@ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
......
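
Outside of autocast, the replacement is numerically a drop-in: torch.outer(t, inv_freq) computes the same outer product as torch.einsum("i,j->ij", t, inv_freq). A quick sanity check with arbitrary illustrative values:

import torch

t = torch.arange(8, dtype=torch.float32)
inv_freq = torch.rand(4, dtype=torch.float32)

# Both forms produce an (8, 4) outer-product matrix with matching values.
assert torch.allclose(torch.einsum("i,j->ij", t, inv_freq), torch.outer(t, inv_freq))
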