Unverified commit 95f96b45 authored by Arthur, committed by GitHub

[`Llama`] remove persistent `inv_freq` tensor (#24998)

Register the `inv_freq` buffer with `persistent=False` so it is no longer stored in the model's state dict.
parent d3ce048c
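
For context, a minimal sketch of what `persistent=False` changes when registering a buffer. The `RotaryStub` module below is a made-up stand-in, not the actual `LlamaRotaryEmbedding` code: a non-persistent buffer stays on the module and still moves with `.to()`/`.cuda()`, but it is excluded from `state_dict()`, so checkpoints no longer carry `inv_freq` and it is always rebuilt from the config at init.

# Hypothetical stand-in (RotaryStub is not part of transformers); it only
# illustrates the PyTorch buffer behaviour the commit relies on.
import torch


class RotaryStub(torch.nn.Module):
    def __init__(self, dim=8, base=10000, persistent=True):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps the buffer usable in forward() and lets it
        # follow device/dtype moves, but drops it from state_dict()/checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=persistent)


print(list(RotaryStub(persistent=True).state_dict()))   # ['inv_freq']
print(list(RotaryStub(persistent=False).state_dict()))  # []
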
@@ -107,7 +107,7 @@ class OpenLlamaRotaryEmbedding(torch.nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)

         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -171,7 +171,7 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)

         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
@@ -97,7 +97,7 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)

         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -159,7 +159,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)

         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)