Commit ec9f74ab authored by Tri Dao
Browse files

[Rotary] Don't store inv_freq in state_dict

parent a157cc8c
...@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module): ...@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
self.pos_idx_in_fp32 = pos_idx_in_fp32 self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable) # Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device) inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved self.interleaved = interleaved
self.scale_base = scale_base self.scale_base = scale_base
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base is not None else None) / (1.4 * dim) if scale_base is not None else None)
self.register_buffer("scale", scale) self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0 self._seq_len_cached = 0
self._cos_cached = None self._cos_cached = None
......
...@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module): ...@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_opt(state_dict, config) state_dict = remap_state_dict_hf_opt(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-j-'): elif model_name.startswith('EleutherAI/gpt-j-'):
state_dict = remap_state_dict_hf_gptj(state_dict, config) state_dict = remap_state_dict_hf_gptj(state_dict, config)
strict = False # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
elif model_name.startswith('EleutherAI/gpt-neox-'): elif model_name.startswith('EleutherAI/gpt-neox-'):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment