Commit 8e9820a5 authored by Tri Dao

[Rotary] Fix tests when loading state dict with rotary inv_freqs

parent b2520724
@@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape
@@ -2,7 +2,7 @@
 # To run the huggingface implementation, we first need to convert the weights:
 # https://github.com/huggingface/transformers/pull/21955
-# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
+# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
 # and repeat for 13B, 30B, 65B
 import os
@@ -32,10 +32,8 @@ def test_llama_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
     model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape
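
Both hunks drop the special-casing of rotary inv_freq keys: the model's state_dict() is now expected to contain exactly the same keys as the pretrained checkpoint. A minimal sketch of the mechanism that makes this possible, assuming the rotary embedding registers inv_freq as a non-persistent buffer (the class and variable names below are illustrative, not taken from this commit):

import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for a rotary embedding module."""
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # persistent=False keeps inv_freq out of state_dict(), so tests can
        # compare keys directly against the pretrained checkpoint.
        self.register_buffer('inv_freq', inv_freq, persistent=False)

module = RotaryEmbeddingSketch(dim=64)
assert 'inv_freq' not in module.state_dict()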