flash-attention (gaoqiong mirror) · Commits

Commit 371e2065, authored Aug 26, 2023 by Tri Dao
[GPT] Test generation when passing in multiple tokens

Parent: c000c3a2
Showing 1 changed file with 43 additions and 0 deletions (+43 −0)
tests/models/test_gpt.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 from einops import rearrange
 from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@@ -335,3 +336,45 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
     logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
+
+
+@pytest.mark.parametrize("optimized", [False, True])
+# @pytest.mark.parametrize("optimized", [False])
+@pytest.mark.parametrize("model_name", ["gpt2"])
+def test_gpt2_multiple_token_generation(model_name, optimized):
+    """Generation when we pass in multiple tokens at a time, not just one."""
+    dtype = torch.float16
+    device = "cuda"
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.residual_in_fp32 = True
+    if optimized:
+        config.use_flash_attn = True
+        config.fused_bias_fc = True
+        config.fused_mlp = True
+        config.fused_dropout_add_ln = True
+    # fused_ft_kernel currently doesn't work with multiple tokens at a time
+    # if not rotary, we load the weight from HF but ignore the position embeddings.
+    # The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
+    # Reference logits
+    logits_ref = model(input_ids).logits
+    # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
+    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
+    inference_params.sequence_len_offset += 10
+    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
+    logits_1014 = model(
+        input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
+    ).logits
+    inference_params.sequence_len_offset += 4
+    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
+    logits_1420 = model(
+        input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
+    ).logits
+    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
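For reference, here is a minimal sketch (not part of the commit) of the incremental-generation pattern the new test exercises: prefill the KV cache with a multi-token chunk, then feed later chunks with explicit position_ids while advancing inference_params.sequence_len_offset. Only GPTLMHeadModel, InferenceParams, and the forward-call signature shown in the diff above are taken from the source; the prompt, chunk sizes, and greedy argmax decoding are illustrative assumptions.

import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import InferenceParams
from transformers import GPT2Config

config = GPT2Config.from_pretrained("gpt2")
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained("gpt2", config, device="cuda", dtype=torch.float16)
model.eval()

# Hypothetical prompt; any (1, L) LongTensor of token ids works.
prompt = torch.randint(0, config.vocab_size, (1, 12), dtype=torch.long, device="cuda")
inference_params = InferenceParams(max_sequence_len=32, max_batch_size=1)

# Prefill: one forward call over the whole prompt populates the KV cache,
# exactly as the test does with its first 10-token chunk.
logits = model(prompt, inference_params=inference_params).logits
inference_params.sequence_len_offset += prompt.shape[1]
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy decoding (assumption)

# Decode one token at a time. Chunks after the first need explicit
# position_ids starting at the current offset so the absolute position
# embeddings line up with the cached prefix (as the test does for its
# 4- and 6-token chunks).
for _ in range(8):
    offset = inference_params.sequence_len_offset
    position_ids = torch.arange(offset, offset + 1, dtype=torch.long, device="cuda")
    logits = model(next_token, position_ids=position_ids, inference_params=inference_params).logits
    inference_params.sequence_len_offset += 1
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)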