chenpangpang / transformers / Commits

Commit 90b4adc1 (unverified), authored Nov 07, 2023 by Joao Gante, committed by GitHub on Nov 07, 2023.

Generate: skip tests on unsupported models instead of passing (#27265)

Parent: 26d8d5f2

Showing 1 changed file with 24 additions and 38 deletions:

tests/generation/test_utils.py (+24, -38)
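The distinction the commit title draws: a bare `return` inside a unittest test method ends the test early and the runner records it as a pass, silently hiding unsupported configurations, whereas `TestCase.skipTest` raises `unittest.SkipTest`, so the same early exit is reported as a skip. A minimal, self-contained sketch of the two behaviours (the `DummyConfig` class and the test names are illustrative, not from the repository):

import unittest


class DummyConfig:
    pass  # deliberately has no `use_cache` attribute


class ExampleTest(unittest.TestCase):
    def test_silently_passes(self):
        config = DummyConfig()
        if not hasattr(config, "use_cache"):
            return  # old pattern: reported as "ok" even though nothing was tested
        self.assertTrue(config.use_cache)

    def test_reported_as_skipped(self):
        config = DummyConfig()
        if not hasattr(config, "use_cache"):
            self.skipTest("This model doesn't support caching")  # new pattern: reported as "skipped"
        self.assertTrue(config.use_cache)


if __name__ == "__main__":
    unittest.main(verbosity=2)  # prints "ok" for the first test and "skipped" for the second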
tests/generation/test_utils.py @ 90b4adc1

@@ -749,8 +749,7 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")

             config.use_cache = True
             config.is_decoder = True
@@ -983,8 +982,7 @@ class GenerationTesterMixin:
             config.forced_eos_token_id = None

             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")

             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
@@ -1420,13 +1418,13 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")

             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
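The guard above selects model families by case-insensitive substring match against the class name. A standalone sketch of the same pattern, with fabricated class names:

# Sketch of the name-matching guard used in the tests; the classes are stand-ins.
class FSMTForConditionalGeneration: ...
class BertLMHeadModel: ...

for model_class in [FSMTForConditionalGeneration, BertLMHeadModel]:
    # case-insensitive substring match against the class name, as in the hunk above
    if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
        print(f"{model_class.__name__}: would be skipped")
    else:
        print(f"{model_class.__name__}: would be tested")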
@@ -1441,14 +1439,14 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")

             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1472,18 +1470,16 @@ class GenerationTesterMixin:
     def test_contrastive_generate_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
-                return
-
-            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
-                self.skipTest("TODO: fix me")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
+            ):
+                self.skipTest("Won't fix: old model with different cache format")

             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1510,8 +1506,6 @@ class GenerationTesterMixin:
                 )
             )
             self.assertListEqual(low_output.tolist(), high_output.tolist())
-            return

     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
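The `@slow` decorator comes from `transformers.testing_utils` and skips the test unless slow tests are enabled through the `RUN_SLOW` environment variable. A rough standard-library approximation of such a gate (a sketch, not the library's exact implementation):

import os
import unittest

# Approximation of a RUN_SLOW-gated decorator; the real `slow` lives in transformers.testing_utils.
_run_slow_tests = os.environ.get("RUN_SLOW", "0") == "1"
slow = unittest.skipUnless(_run_slow_tests, "test is slow; set RUN_SLOW=1 to run it")


class ExampleSlowTest(unittest.TestCase):
    @slow
    def test_expensive_generation(self):
        self.assertEqual(1 + 1, 2)  # placeholder for an expensive check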
@@ -1522,15 +1516,13 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `batch_size > 1`

         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
             # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")

             # This for loop is a naive and temporary effort to make the test less flaky.
             failed = 0
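The `failed = 0` counter initialises the flakiness-tolerance loop mentioned in the comment; the loop body itself falls outside this hunk. As a hedged illustration of that kind of tolerance loop only, assuming a nondeterministic check and a fixed failure budget (none of these names are from the repository):

import random
import unittest


class FlakyToleranceExample(unittest.TestCase):
    def test_mostly_matches(self):
        random.seed(0)  # make the illustration deterministic

        # Illustrative stand-in for a nondeterministic comparison
        # (e.g. assisted decoding vs greedy search).
        def outputs_match():
            return random.random() > 0.05  # pretend a small fraction of runs disagree

        failed = 0
        num_runs = 10
        for _ in range(num_runs):
            if not outputs_match():
                failed += 1
        # Tolerate a bounded number of mismatches instead of failing on the first one.
        self.assertLessEqual(failed, 2)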
@@ -1540,7 +1532,7 @@ class GenerationTesterMixin:
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1587,24 +1579,21 @@ class GenerationTesterMixin:
     def test_assisted_decoding_sample(self):
         # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
         # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
             # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")

             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1716,7 +1705,7 @@ class GenerationTesterMixin:
             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")

             model = model_class(config).to(torch_device)
             if "use_cache" not in inputs:
@@ -1725,7 +1714,7 @@ class GenerationTesterMixin:
             # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")

             num_hidden_layers = (
                 getattr(config, "decoder_layers", None)
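The membership test works because model outputs behave like ordered mappings of their returned fields (transformers' `ModelOutput` subclasses `OrderedDict`), so `in` inspects the output's keys. A toy stand-in with a plain dict:

outputs = {"logits": [0.1, 0.9]}  # toy stand-in; real code gets a ModelOutput (an OrderedDict subclass)
if "past_key_values" not in outputs:
    print("would skip: model did not return `past_key_values`")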
@@ -1832,18 +1821,15 @@ class GenerationTesterMixin:
     def test_generate_continue_from_past_key_values(self):
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
         for model_class in self.all_generative_model_classes:
             # won't fix: old models with unique inputs/caches/others
             if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
-                return
+                self.skipTest("Won't fix: old model with unique inputs/caches/other")
             # may fix in the future: needs modeling or test input preparation fixes for compatibility
             if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
-                return
+                self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")

             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")

             # Let's make it always:
             # 1. use cache (for obvious reasons)
@@ -1862,10 +1848,10 @@ class GenerationTesterMixin:
             model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
             model.generation_config.forced_eos_token_id = None

-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
+            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")

             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
             outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
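One behavioural note on the new pattern: because `skipTest` raises `unittest.SkipTest`, a skip triggered for one `model_class` ends the whole test, just as the old `return` did, but with an accurate status. Where per-model granularity is wanted, the standard library's `subTest` provides it; a small sketch with fabricated model classes, offered as an alternative rather than what the commit does:

import unittest


class PerModelGranularityExample(unittest.TestCase):
    def test_each_model_variant(self):
        # Hypothetical stand-ins for model classes; only the names matter here.
        model_classes = [type("FSMTModel", (), {}), type("BertModel", (), {})]
        for model_class in model_classes:
            with self.subTest(model_class=model_class.__name__):
                if "fsmt" in model_class.__name__.lower():
                    self.skipTest("different cache format")  # skips only this subTest
                self.assertTrue(True)  # placeholder per-model check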