chenpangpang / transformers · Commits · 913d03dc

Generate: fix flaky tests (#27543)

Unverified commit 913d03dc, authored Nov 17, 2023 by Joao Gante, committed via GitHub on Nov 17, 2023.
Parent: d903abfc

Showing 3 changed files with 20 additions and 25 deletions (+20, -25).
Files changed:
- src/transformers/generation/logits_process.py (+2, -1)
- tests/generation/test_logits_process.py (+1, -1)
- tests/generation/test_utils.py (+17, -23)
src/transformers/generation/logits_process.py

```diff
@@ -1301,8 +1301,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
         # set all nan values to 0.0
         scores[scores != scores] = 0.0

-        # set all inf values to max possible value
+        # set all +/-inf values to max/min possible value
         scores[scores == float("inf")] = torch.finfo(scores.dtype).max
+        scores[scores == float("-inf")] = torch.finfo(scores.dtype).min

         return scores
```
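For context, here is a minimal sketch of the processor's behaviour after this change. The toy tensors are illustrative, not taken from the repository:

```python
import torch
from transformers import InfNanRemoveLogitsProcessor

processor = InfNanRemoveLogitsProcessor()
# The processor only inspects `scores`; `input_ids` just needs a valid shape.
input_ids = torch.tensor([[0, 1]])
scores = torch.tensor([[0.1, float("nan"), float("inf"), float("-inf")]])

out = processor(input_ids, scores)
# nan -> 0.0, inf -> finfo.max, and (new in this commit) -inf -> finfo.min
print(out)  # tensor([[ 1.0000e-01,  0.0000e+00,  3.4028e+38, -3.4028e+38]])
```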
tests/generation/test_logits_process.py

```diff
@@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase):
             torch.allclose(
                 scores,
                 torch.tensor(
-                    [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
+                    [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]],
                     device=torch_device,
                 ),
                 atol=1e-6,
```
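The expected tensor changes because the processor now returns the finite dtype minimum instead of leaving `-inf` untouched. A quick standalone check (not part of the commit) shows why the old expectation would no longer match:

```python
import torch

# finfo(float32).min is a very large negative *finite* number, not -inf,
# so an expected tensor containing float("-inf") would fail torch.allclose.
dtype_min = torch.finfo(torch.float32).min
print(dtype_min)                   # -3.4028234663852886e+38
print(dtype_min == float("-inf"))  # False
```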
tests/generation/test_utils.py

```diff
@@ -124,9 +124,14 @@ class GenerationTesterMixin:
         process_kwargs = {
             "min_length": input_length + 1 if max_length is None else max_length - 1,
             "bad_words_ids": [[1, 0]],
-            "no_repeat_ngram_size": 2,
             "repetition_penalty": 1.2,
+            "remove_invalid_values": True,
         }
+        # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
+        if forced_bos_token_id is None and forced_eos_token_id is None:
+            process_kwargs["no_repeat_ngram_size"] = 2
+
+        # NOTE: the order of operations here should match `generate` for accurate testing
         logits_processor = LogitsProcessorList(
             (
                 [
```
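The n-gram ban became conditional because, combined with a forced BOS/EOS token, `NoRepeatNGramLogitsProcessor` can ban every valid continuation and make the test flaky. A toy illustration of the banning mechanism (the sequence and vocabulary size are made up, not from the commit):

```python
import torch
from transformers import NoRepeatNGramLogitsProcessor

processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
# The prefix already contains the 2-gram (5, 6) and ends in token 5, so
# emitting token 6 again would repeat that 2-gram and is banned outright.
input_ids = torch.tensor([[5, 6, 5]])
scores = torch.zeros(1, 10)

out = processor(input_ids, scores)
print(out[0, 6])  # -inf: combined with forced tokens, bans like this can
                  # leave no valid candidate at some decoding step
```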
```diff
@@ -154,12 +159,16 @@ class GenerationTesterMixin:
                     if forced_eos_token_id is not None
                     else []
                 )
-                + [
-                    NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
-                    NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
-                    RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
-                ]
+                + [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
+                + (
+                    [NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
+                    if forced_bos_token_id is None and forced_eos_token_id is None
+                    else []
+                )
+                + [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
+                + [InfNanRemoveLogitsProcessor()]  # prevent flaky generation test failures
             )
         return process_kwargs, logits_processor

     @staticmethod
```
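Reduced to a standalone sketch, the new construction pattern appends each optional processor as its own list segment so the order still mirrors `generate` (the flag and numeric values below are placeholders):

```python
from transformers import (
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
)

has_forced_tokens = False  # stands in for the forced_bos/eos_token_id checks
logits_processor = LogitsProcessorList(
    ([NoRepeatNGramLogitsProcessor(2)] if not has_forced_tokens else [])
    + [RepetitionPenaltyLogitsProcessor(1.2)]
    + [InfNanRemoveLogitsProcessor()]  # appended last, as in the commit
)
print(len(logits_processor))  # 3
```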
```diff
@@ -282,7 +291,6 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_process_kwargs,
             **model_kwargs,
         )
```
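The dropped keyword is not lost: `"remove_invalid_values": True` now travels inside the shared `process_kwargs`, which each call site already splats in, so passing it explicitly as well would have supplied the argument twice. For reference, a hedged usage sketch of the flag on its own (the model choice is arbitrary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# remove_invalid_values=True makes generate() sanitize nan/inf logits
# (via InfNanRemoveLogitsProcessor) during decoding.
out = model.generate(**inputs, max_new_tokens=5, remove_invalid_values=True)
print(tokenizer.decode(out[0]))
```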
```diff
@@ -340,7 +348,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_warper_kwargs,
             **process_kwargs,
             **model_kwargs,
```
```diff
@@ -361,9 +368,6 @@ class GenerationTesterMixin:
         elif attention_mask is not None:
             attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)

-        # prevent flaky generation test failures
-        logits_processor.append(InfNanRemoveLogitsProcessor())
-
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_sample = model.sample(
```
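An aside on the surrounding context line: `repeat_interleave` duplicates each mask row once per returned sequence so the batch dimension matches the expanded inputs (toy shapes, not from the commit):

```python
import torch

attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # batch of 2
num_return_sequences = 2

# Rows are duplicated in place: [a, b] -> [a, a, b, b]
expanded = attention_mask.repeat_interleave(num_return_sequences, dim=0)
print(expanded.shape)  # torch.Size([4, 3])
```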
```diff
@@ -405,7 +409,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
             **model_kwargs,
```
```diff
@@ -467,7 +470,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
             **model_kwargs,
```
```diff
@@ -534,7 +536,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
             **model_kwargs,
```
```diff
@@ -596,7 +597,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             constraints=constraints,
             **beam_kwargs,
             **logits_process_kwargs,
```
```diff
@@ -671,7 +671,6 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_process_kwargs,
             **model_kwargs,
             **contrastive_search_kwargs,
```
```diff
@@ -1284,13 +1283,8 @@ class GenerationTesterMixin:
         # check `generate()` and `constrained_beam_search()` are equal
         # Sample constraints
-        if not input_ids.dtype == torch.float32:
-            min_id = torch.min(input_ids) + 3
-            max_id = torch.max(input_ids)
-        else:
-            # otherwise this throws an error for Speech2TextModel since its inputs are floating points
-            min_id = 3
-            max_id = 100
+        min_id = 3
+        max_id = config.vocab_size

         force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
         constraints = [
```
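The simplification works because constraint token ids can always be drawn from the configured vocabulary range, regardless of the model's input dtype. A sketch of the sampling step (the vocabulary size is a stand-in for `config.vocab_size`):

```python
import torch

min_id = 3
max_id = 32000  # stand-in for config.vocab_size

# Draw two random token ids in [min_id, max_id) for the sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
print(force_tokens)  # e.g. [1987, 20455]
```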