chenpangpang / transformers · Commits

Unverified commit 12b50c61, authored Nov 16, 2023 by Joao Gante, committed by GitHub on Nov 16, 2023
Generate: improve assisted generation tests (#27540)
Parent: 651408a0

Showing 1 changed file with 86 additions and 67 deletions

tests/generation/test_utils.py (+86, -67)
@@ -23,6 +23,7 @@ import numpy as np
 from transformers import is_torch_available, pipeline
 from transformers.testing_utils import (
+    is_flaky,
     require_accelerate,
     require_torch,
     require_torch_multi_accelerator,
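The newly imported `is_flaky` decorator replaces the blanket `@slow` marker used further down. As a hedged, illustrative sketch (the wrapping test class below is hypothetical, not part of the diff), it is applied like any other test decorator and reruns a failing test a few times before reporting it as failed:

# Illustrative only: a hypothetical test class showing how `is_flaky` is applied.
import unittest

from transformers.testing_utils import is_flaky


class ExampleGenerationTest(unittest.TestCase):
    @is_flaky()  # retry the test a few times before reporting a failure
    def test_with_occasional_numerical_noise(self):
        self.assertTrue(True)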
@@ -1506,10 +1507,14 @@ class GenerationTesterMixin:
                 )
                 self.assertListEqual(low_output.tolist(), high_output.tolist())
 
-    @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%
+    @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
-        # It breaks the pattern in the tests above, for multiple reasons:
+        # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
+        # shape differences -- and it may result in a different output. The input shape difference happens in the
+        # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
+        # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
+        # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
         # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
         #   prepare the assistant encoder outputs in the main generate body);
         # - assisted_decoding does not support `use_cache = False`
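NOTE (1) refers to the fact that scoring several candidate tokens in one forward pass uses different matmul shapes than decoding one token at a time, which can perturb the logits by floating-point noise. A minimal sketch of that effect, assuming a small public checkpoint and an arbitrary prompt (neither taken from the test suite), would be:

# Hedged sketch of the effect described in NOTE (1); checkpoint and prompt are
# illustrative assumptions, not part of the tests.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
input_ids = torch.tensor([[464, 3290, 318, 257, 922]])  # arbitrary 5-token prompt

with torch.no_grad():
    # Score all positions in one forward pass (what the main model does with candidates).
    logits_all_at_once = model(input_ids).logits[0, -1]

    # Reach the same position step by step, reusing the cache (plain greedy decoding).
    past = model(input_ids[:, :-1]).past_key_values
    logits_stepwise = model(input_ids[:, -1:], past_key_values=past).logits[0, -1]

# Equal up to floating-point noise; in rare cases the argmax (the greedy token) differs.
print((logits_all_at_once - logits_stepwise).abs().max())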
@@ -1520,77 +1525,82 @@ class GenerationTesterMixin:
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
             ):
                 self.skipTest("May fix in the future: need model-specific fixes")
 
-            # This for loop is a naive and temporary effort to make the test less flaky.
-            failed = 0
-            for i in range(10):
-                # enable cache
-                config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
-
-                # NOTE: assisted generation only works with cache on at the moment.
-                if not hasattr(config, "use_cache"):
-                    self.skipTest("This model doesn't support caching")
-
-                config.use_cache = True
-                config.is_decoder = True
-                model = model_class(config).to(torch_device).eval()
-
-                output_greedy = model.generate(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    max_length=max_length,
-                    num_beams=1,
-                    do_sample=False,
-                    output_scores=True,
-                    output_hidden_states=True,
-                    output_attentions=True,
-                    return_dict_in_generate=True,
-                )
-                # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
-                # be correct
-                output_assisted = model.generate(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    max_length=max_length,
-                    num_beams=1,
-                    do_sample=False,
-                    assistant_model=model,
-                    output_scores=True,
-                    output_hidden_states=True,
-                    output_attentions=True,
-                    return_dict_in_generate=True,
-                )
-
-                try:
-                    self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
-
-                    for output in (output_greedy, output_assisted):
-                        self._check_outputs(output, input_ids, model.config, use_cache=True)
-                except AssertionError:
-                    failed += 1
-                    if failed > 1:
-                        self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
-
-                        for output in (output_greedy, output_assisted):
-                            self._check_outputs(output, input_ids, model.config, use_cache=True)
+            # enable cache
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                self.skipTest("This model doesn't support caching")
+
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
+            #    the assistant model is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": False,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs.update({"assistant_model": assistant_model})
+            output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            # The two outputs must match and their shape must be as expected
+            self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+            for output in (output_greedy, output_assisted):
+                self._check_outputs(output, input_ids, model.config, use_cache=True)
 
-    @unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
     def test_assisted_decoding_sample(self):
-        # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
-        # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
+        # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
+        # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
+        # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
         for model_class in self.all_generative_model_classes:
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
             ):
                 self.skipTest("May fix in the future: need model-specific fixes")
 
             # enable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
 
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1599,18 +1609,27 @@ class GenerationTesterMixin:
             config.use_cache = True
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
-            output_assisted = model.generate(
-                input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_beams=1,
-                do_sample=True,
-                assistant_model=model,  # triggers assisted decoding
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
+            #    the assistant model is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": True,
+                "assistant_model": assistant_model,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+            output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
 
             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
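Both tests exercise the same user-facing API: passing an `assistant_model` to `generate()` triggers assisted decoding, and the assistant's `generation_config` controls how many candidate tokens it proposes per call. A minimal usage sketch outside the test harness, assuming small public checkpoints that share a tokenizer (gpt2-large as the main model and gpt2 as the assistant are illustrative choices, not part of the diff):

# Hedged sketch of assisted generation as a user would call it; checkpoints and
# prompt are assumptions. Mirrors the knobs the tests above pin down.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
main_model = AutoModelForCausalLM.from_pretrained("gpt2-large").eval()
assistant = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Fixed number of candidate tokens per assistant call, as in the tests ("see b)").
assistant.generation_config.num_assistant_tokens = 2
assistant.generation_config.num_assistant_tokens_schedule = "constant"

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    output = main_model.generate(
        **inputs,
        assistant_model=assistant,  # passing an assistant model triggers assisted decoding
        do_sample=False,
        max_new_tokens=20,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))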