chenpangpang / transformers · Commits

Commit d94cc2f9 (unverified)
Authored Jan 26, 2021 by Patrick von Platen; committed via GitHub on Jan 26, 2021
[Flaky Generation Tests] Make sure that no early stopping is happening for beam search (#9794)

* fix ci
* fix ci
* renaming
* fix dup line
Parent: 0fdbf085
Showing 1 changed file with 44 additions and 4 deletions (+44 -4):

tests/test_generation_utils.py (+44 -4)
tests/test_generation_utils.py (view file @ d94cc2f9)
@@ -625,6 +625,12 @@ class GenerationTesterMixin:
     def test_beam_search_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
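Aside (not part of the diff): the comment added above is the core of the fix. With config.eos_token_id = None, no beam hypothesis can finish before max_length, so the length of the returned sequences no longer depends on which beams happen to produce an EOS token. A minimal sketch of the property these tests rely on, assuming model_class, config, input_ids, attention_mask and max_length are set up as in the test above (num_beams and num_return_sequences are illustrative values, not taken from the commit):

    # Sketch only: with eos_token_id disabled, beam search always runs to max_length,
    # so the output length is deterministic instead of depending on early-stopping beams.
    config.eos_token_id = None
    model = model_class(config).to(torch_device).eval()
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=2,
        num_return_sequences=2,
        do_sample=False,
    )
    # Shape is (batch_size * num_return_sequences, max_length); this kind of length
    # check was the flaky part when the top beams could all end early at an eos_token_id.
    assert output.shape[-1] == max_length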
@@ -669,9 +675,16 @@ class GenerationTesterMixin:
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # disable cache
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                 input_ids.shape[-1],
                 config.eos_token_id
@@ -715,11 +728,15 @@ class GenerationTesterMixin:
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             if not hasattr(config, "use_cache"):
                 # only relevant if model has "use_cache"
                 return
 
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
@@ -758,7 +775,12 @@ class GenerationTesterMixin:
     def test_beam_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            print("Return dict", config.return_dict)
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
 
             model = model_class(config).to(torch_device).eval()
@@ -788,9 +810,16 @@ class GenerationTesterMixin:
     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
-            # disable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # disable cache
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
@@ -859,6 +888,11 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
                 input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
             )
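Aside (not part of the diff): the diversity_penalty=2.0 argument above means this test exercises group ("diverse") beam search, which generate() uses when num_beam_groups > 1. A hedged sketch of such a call, with illustrative values (num_beams, num_beam_groups, num_return_sequences are not taken from the commit) and the same config.eos_token_id = None setup so that no group can stop early:

    # Sketch only: group beam search; num_beams must be divisible by num_beam_groups.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=4,
        num_beam_groups=2,
        diversity_penalty=2.0,   # discourages groups from repeating each other's tokens
        num_return_sequences=2,
        do_sample=False,
    )
    assert output.shape[-1] == max_length  # holds because config.eos_token_id is None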
@@ -904,6 +938,12 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False
+
+            # It is important to set the eos_token_id to None to ensure that no sequences
+            # shorter than `max_length` can be generated which could lead to flaky circle ci
+            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+            config.eos_token_id = None
+
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(