Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
d0032700
Commit
d0032700
authored
Sep 13, 2023
by
Tri Dao
Browse files
Add tests for Pythia, GPT-JT, and RedPajama models
parent
bb9beb36
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
6 deletions
+26
-6
flash_attn/models/gpt.py
flash_attn/models/gpt.py
+9
-2
tests/models/test_gpt_neox.py
tests/models/test_gpt_neox.py
+16
-3
tests/models/test_gptj.py
tests/models/test_gptj.py
+1
-1
No files found.
flash_attn/models/gpt.py
View file @
d0032700
...
...
@@ -352,9 +352,16 @@ class GPTPreTrainedModel(nn.Module):
state_dict
=
remap_state_dict_hf_gpt2
(
state_dict
,
config
)
elif
model_name
.
startswith
(
"facebook/opt"
):
state_dict
=
remap_state_dict_hf_opt
(
state_dict
,
config
)
elif
model_name
.
startswith
(
"EleutherAI/gpt-j-"
):
elif
(
model_name
.
startswith
(
"EleutherAI/gpt-j-"
)
or
model_name
.
startswith
(
"togethercomputer/GPT-JT-"
)
):
state_dict
=
remap_state_dict_hf_gptj
(
state_dict
,
config
)
elif
model_name
.
startswith
(
"EleutherAI/gpt-neox-"
):
elif
(
model_name
.
startswith
(
"EleutherAI/gpt-neox-"
)
or
model_name
.
startswith
(
"EleutherAI/pythia-"
)
or
model_name
.
startswith
(
"togethercomputer/RedPajama-INCITE-"
)
):
state_dict
=
remap_state_dict_hf_gpt_neox
(
state_dict
,
config
)
elif
model_name
.
startswith
(
"tiiuae/falcon-"
):
state_dict
=
remap_state_dict_hf_falcon
(
state_dict
,
config
)
...
...
tests/models/test_gpt_neox.py
View file @
d0032700
...
...
@@ -24,7 +24,15 @@ def test_gptj_state_dict(model_name):
assert
state_dict
[
k
].
shape
==
pretrained_state_dict
[
k
].
shape
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"EleutherAI/gpt-neox-20b"
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"EleutherAI/pythia-1b"
,
"EleutherAI/pythia-2.8b"
,
"EleutherAI/gpt-neox-20b"
,
"togethercomputer/RedPajama-INCITE-7B-Base"
,
],
)
def
test_gpt_neox_optimized
(
model_name
):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
...
...
@@ -35,7 +43,12 @@ def test_gpt_neox_optimized(model_name):
config
=
gpt_neox_config_to_gpt2_config
(
GPTNeoXConfig
.
from_pretrained
(
model_name
))
config
.
use_flash_attn
=
True
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
True
# GPT-NeoX-20B uses "gelu_fast"
config
.
fused_mlp
=
config
.
activation_function
in
[
"gelu_fast"
,
"gelu_new"
,
"gelu_approx"
,
"gelu_pytorch_tanh"
,
]
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
...
...
@@ -54,7 +67,7 @@ def test_gpt_neox_optimized(model_name):
logits
=
model
(
input_ids
).
logits
del
model
# Need at least 2 GPUs, otherwise we'll OOM
# Need at least 2 GPUs, otherwise we'll OOM for the 20B model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref
=
GPTNeoXForCausalLM
.
from_pretrained
(
model_name
,
device_map
=
"auto"
)
model_ref
.
eval
()
...
...
tests/models/test_gptj.py
View file @
d0032700
...
...
@@ -23,7 +23,7 @@ def test_gptj_state_dict(model_name):
assert
state_dict
[
k
].
shape
==
pretrained_state_dict
[
k
].
shape
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"EleutherAI/gpt-j-6B"
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"EleutherAI/gpt-j-6B"
,
"togethercomputer/GPT-JT-6B-v1"
])
def
test_gptj_optimized
(
model_name
):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment