gaoqiong / flash-attention · commit 8e9820a5
Authored Jul 26, 2023 by Tri Dao
Parent: b2520724

[Rotary] Fix tests when loading state dict with rotary inv_freqs

Showing 2 changed files with 5 additions and 9 deletions:
  tests/models/test_gptj.py    +2 -4
  tests/models/test_llama.py   +3 -5
tests/models/test_gptj.py
@@ -20,10 +20,8 @@ def test_gptj_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape
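The removed workaround relied on Python dict key views being set-like: `pretrained_state_dict.keys() | rotary_inv_freq_keys` is a union, and `state_dict.keys() - rotary_inv_freq_keys` is a difference that skips the rotary buffers during the per-key shape check. A minimal sketch of that pattern, using plain tuples in place of tensors and hypothetical single-layer key names standing in for real checkpoint entries:

```python
# Set algebra on dict key views, mirroring the workaround this commit removes.
# Key names and shapes are hypothetical stand-ins for real checkpoint entries.
model_sd = {
    "transformer.layers.0.mixer.rotary_emb.inv_freq": (64,),
    "transformer.layers.0.mixer.Wqkv.weight": (3, 4),
}
pretrained_sd = {"transformer.layers.0.mixer.Wqkv.weight": (3, 4)}
rotary_inv_freq_keys = {
    f"transformer.layers.{l}.mixer.rotary_emb.inv_freq" for l in range(1)
}
# The model's keys equal the pretrained keys plus the inv_freq buffers...
assert model_sd.keys() == pretrained_sd.keys() | rotary_inv_freq_keys
# ...and the per-key comparison skips inv_freq via set difference.
comparable = sorted(model_sd.keys() - rotary_inv_freq_keys)
print(comparable)  # ['transformer.layers.0.mixer.Wqkv.weight']
```

After this commit the union and difference are gone, implying the model's `state_dict()` no longer contains the `inv_freq` buffers at all, so a plain key-for-key comparison suffices.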
tests/models/test_llama.py
@@ -2,7 +2,7 @@
 # To run the huggingface implementation, we first need to convert the weights:
 # https://github.com/huggingface/transformers/pull/21955
-# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR$/llama/7B-hf
+# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
 # and repeat for 13B, 30B, 65B
 import os
@@ -32,10 +32,8 @@ def test_llama_state_dict(model_name):
     pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
     model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
     state_dict = model.state_dict()
-    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
-                            for l in range(config.n_layer)}
-    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
-    for k in state_dict.keys() - rotary_inv_freq_keys:
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
         assert state_dict[k].shape == pretrained_state_dict[k].shape
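The comment in test_llama.py says to run the conversion for 7B and then "repeat for 13B, 30B, 65B". A hedged sketch of that repetition as a loop; the `echo` makes it a dry run that only prints each command (drop it to actually convert, which requires transformers and the original LLaMA checkpoints), `convert_all` is an illustrative name, and `/path/to/checkpoints` is a placeholder default for `$CHECKPOINT_DIR`:

```shell
# Dry-run sketch of "and repeat for 13B, 30B, 65B": print the conversion
# command for each model size. Remove the leading `echo` to run for real.
CHECKPOINT_DIR="${CHECKPOINT_DIR:-/path/to/checkpoints}"
convert_all() {
  for size in 7B 13B 30B 65B; do
    echo python -m transformers.models.llama.convert_llama_weights_to_hf \
      --input_dir "$CHECKPOINT_DIR/llama" --model_size "$size" \
      --output_dir "$CHECKPOINT_DIR/llama/${size}-hf"
  done
}
convert_all
```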