chenpangpang / transformers · Commit 502fec77 (unverified)

Authored Mar 23, 2023 by Joao Gante; committed by GitHub on Mar 23, 2023
Generate: add test for left-padding support (#22322)
Parent: ec9b18f6
Showing 2 changed files, with 50 additions and 1 deletion:

- tests/generation/test_utils.py (+48, -0)
- tests/models/gpt_neox/test_modeling_gpt_neox.py (+2, -1)
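Some context on why this commit matters: decoder-only models read the next-token logits from the last position of each row, so a batch of prompts with unequal lengths has to be padded on the left; padding on the right would leave pad tokens between a prompt and its generated continuation. Below is a minimal sketch of the usage pattern being protected here, assuming GPT-2 as a stand-in model; the checkpoint name and prompts are illustrative and not part of the diff.

```python
# Illustrative only: left-padding a batch for decoder-only generation.
# "gpt2" and the prompts are placeholder choices, not taken from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
tokenizer.padding_side = "left"            # pad on the left so real tokens end each row

model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```

The new test in the first file checks exactly the property this pattern relies on: prepending pad tokens (with a matching zeroed attention mask) must not change the logits at the last real position.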
tests/generation/test_utils.py · View file @ 502fec77

```diff
@@ -1497,6 +1497,54 @@ class GenerationTesterMixin:
                 attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
                 self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
 
+    # TODO (joao): this test is actually not slow :) However, it is not passing in some models (e.g. GPTNeoX) and the
+    # fix for some models is quite lengthy. Being slow means it doesn't block our push CI while we fix it.
+    @slow
+    def test_left_padding_compatibility(self):
+        # The check done in this test is fairly difficult -- depending on the model architecture, passing the right
+        # position index for the position embeddings can still result in a different output, due to numerical masking.
+        # On the other hand, for some types of position embeddings, an incorrect position index can have a minimal
+        # impact on the output.
+        # There are two tricks employed to check whether left-padding compatibility is in place:
+        # 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the
+        #     padding size to 1.
+        # 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check
+        #     multiple times with random inputs, and it has to pass with all of them.
+        # NOTE: because of 2), there is some chance of false positives in this test.
+        for model_class in self.all_generative_model_classes:
+            config, _, _, _ = self._get_input_ids_and_config()
+            if config.is_encoder_decoder:
+                continue  # skip for encoder-decoder models -- they don't need left-padding compatibility
+
+            model = model_class(config).to(torch_device).eval()
+            signature = inspect.signature(model.forward).parameters.keys()
+
+            no_failures = True
+            for _ in range(10):  # there may be false positives with 10 runs, we rely on the CI to catch the flakiness
+                _, input_ids, attention_mask, _ = self._get_input_ids_and_config()
+                model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+                if "position_ids" in signature:
+                    position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+                    position_ids.masked_fill_(attention_mask == 0, 1)
+                    model_kwargs["position_ids"] = position_ids
+                next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+                pad_size = (input_ids.shape[0], 1)
+                padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
+                padded_input_ids = torch.cat((padding, input_ids), dim=1)
+                padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+                model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask}
+                if "position_ids" in signature:
+                    position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1
+                    position_ids.masked_fill_(padded_attention_mask == 0, 1)
+                    model_kwargs["position_ids"] = position_ids
+                next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+                if not torch.allclose(next_logits_wo_padding, next_logits_with_padding):
+                    no_failures = False
+                    break
+
+            self.assertTrue(no_failures)
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
```
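The position index recipe in the hunk above deserves a closer look: torch.cumsum over the attention mask gives each real token its position counted from the first non-pad token, so the index is identical with or without left-padding, and masked_fill_ then parks the padded slots on an arbitrary valid index (they are excluded by the attention mask anyway). Here is a standalone sketch of those two lines on a hand-built mask; the tensor values are made up for illustration.

```python
import torch

# One left-pad slot followed by three real tokens.
attention_mask = torch.tensor([[0, 1, 1, 1]])

# Same recipe as in the test above: cumsum makes real tokens count from 0
# no matter how much left-padding precedes them.
position_ids = torch.cumsum(attention_mask, dim=-1) - 1  # tensor([[-1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)        # padded slot -> dummy index 1

print(position_ids)  # tensor([[1, 0, 1, 2]]); real tokens keep positions 0..2
```

Without this correction, a model given position_ids straight from the sequence index would see its prompt shifted by the padding length, which is one of the failure modes the test is designed to expose.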
tests/models/gpt_neox/test_modeling_gpt_neox.py · View file @ 502fec77

```diff
@@ -20,6 +20,7 @@ import unittest
 from transformers import GPTNeoXConfig, is_torch_available
 from transformers.testing_utils import require_torch, torch_device
 
+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -186,7 +187,7 @@ class GPTNeoXModelTester:
 
 
 @require_torch
-class GPTNeoXModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
```
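The one-line base-class change above is all GPT-NeoX needs to opt into the new check: unittest collects test_* methods from every base class, so listing GenerationTesterMixin among the bases makes GPTNeoXModelTest inherit test_left_padding_compatibility (and the rest of the generation suite), which then iterates over all_generative_model_classes. A stripped-down sketch of that mechanism follows; GenerationChecks and DemoModelTest are hypothetical names, not classes from the repository.

```python
import unittest

class GenerationChecks:
    # Stand-in for GenerationTesterMixin: any test_* method defined here is
    # inherited by every test case that lists this class as a base.
    def test_left_padding(self):
        self.assertTrue(True)  # placeholder assertion

class DemoModelTest(GenerationChecks, unittest.TestCase):
    pass  # gains test_left_padding purely through inheritance

if __name__ == "__main__":
    unittest.main()  # discovers and runs DemoModelTest.test_left_padding
```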