Unverified Commit 913d03dc authored by Joao Gante, committed by GitHub

Generate: fix flaky tests (#27543)

parent d903abfc
@@ -1301,8 +1301,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
         # set all nan values to 0.0
         scores[scores != scores] = 0.0
 
-        # set all inf values to max possible value
+        # set all +/-inf values to max/min possible value
         scores[scores == float("inf")] = torch.finfo(scores.dtype).max
+        scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
 
         return scores
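For reference, a minimal sketch of the patched behavior. This is an illustration, not part of the commit; it assumes a transformers version where the processor lives at `transformers.generation.logits_process`, and the printed values are for float32:

```python
import torch

from transformers.generation.logits_process import InfNanRemoveLogitsProcessor

# One score of each kind the processor sanitizes: finite, NaN, +inf, -inf
scores = torch.tensor([[0.1, float("nan"), float("inf"), float("-inf")]])

# input_ids is unused by this processor, so None is fine for a demo
sanitized = InfNanRemoveLogitsProcessor()(input_ids=None, scores=scores)

print(sanitized)
# tensor([[ 1.0000e-01,  0.0000e+00,  3.4028e+38, -3.4028e+38]])
```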
@@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase):
             torch.allclose(
                 scores,
                 torch.tensor(
-                    [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
+                    [[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]],
                     device=torch_device,
                 ),
                 atol=1e-6,
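The expected tensor changes because a `-inf` score is now replaced with the most negative finite value of the dtype rather than passed through. For IEEE float types that value is symmetric with the maximum:

```python
import torch

# torch.finfo(dtype).min is the most negative finite value of the dtype
assert torch.finfo(torch.float32).min == -torch.finfo(torch.float32).max
assert torch.finfo(torch.float16).min == -torch.finfo(torch.float16).max
```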
@@ -124,9 +124,14 @@ class GenerationTesterMixin:
         process_kwargs = {
             "min_length": input_length + 1 if max_length is None else max_length - 1,
             "bad_words_ids": [[1, 0]],
-            "no_repeat_ngram_size": 2,
             "repetition_penalty": 1.2,
+            "remove_invalid_values": True,
         }
+        # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
+        if forced_bos_token_id is None and forced_eos_token_id is None:
+            process_kwargs["no_repeat_ngram_size"] = 2
+
+        # NOTE: the order of operations here should match `generate` for accurate testing
         logits_processor = LogitsProcessorList(
             (
                 [
@@ -154,12 +159,16 @@ class GenerationTesterMixin:
                 if forced_eos_token_id is not None
                 else []
             )
-            + [
-                NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
-                NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
-                RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
-            ]
+            + [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
+            + (
+                [NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
+                if forced_bos_token_id is None and forced_eos_token_id is None
+                else []
+            )
+            + [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
+            + [InfNanRemoveLogitsProcessor()]  # prevent flaky generation test failures
         )
         return process_kwargs, logits_processor
 
     @staticmethod
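To see why the no-repeat-ngram processor is skipped when a BOS/EOS token is forced, consider how the two interact. A minimal sketch (token ids and vocab size are made up for illustration):

```python
import torch

from transformers.generation.logits_process import NoRepeatNGramLogitsProcessor

# The bigram (1, 0) already occurred, and the sequence currently ends in 1,
# so generating token 0 next would repeat that bigram and it gets banned
input_ids = torch.tensor([[1, 0, 1]])
scores = torch.zeros(1, 4)

scores = NoRepeatNGramLogitsProcessor(ngram_size=2)(input_ids, scores)
print(scores)  # tensor([[-inf, 0., 0., 0.]])

# If a forced-token processor then restricts this position to exactly the
# banned token, every score in the row is -inf: softmax turns that into NaNs
# and sampling fails -- the "no valid continuations" flakiness noted above.
```

Keeping the list in the same order `generate` applies processors matters for the same reason: whether a forced token is banned before or after the force is applied changes the final row.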
@@ -282,7 +291,6 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_process_kwargs,
             **model_kwargs,
         )
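Passing `remove_invalid_values=True` to `generate()` is the user-facing way to enable the same sanitization; the tests now put `InfNanRemoveLogitsProcessor` directly into the shared processor list instead, so the flag is dropped here and in the analogous hunks below. Roughly, the two spellings behave alike. A sketch, using a tiny public checkpoint purely for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import (
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
)

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Flag form: generate() appends an InfNanRemoveLogitsProcessor internally
out_flag = model.generate(**inputs, max_length=10, remove_invalid_values=True)

# Explicit form: supply the processor yourself
out_explicit = model.generate(
    **inputs,
    max_length=10,
    logits_processor=LogitsProcessorList([InfNanRemoveLogitsProcessor()]),
)
```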
@@ -340,7 +348,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_warper_kwargs,
             **process_kwargs,
             **model_kwargs,
@@ -361,9 +368,6 @@ class GenerationTesterMixin:
         elif attention_mask is not None:
             attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
 
-        # prevent flaky generation test failures
-        logits_processor.append(InfNanRemoveLogitsProcessor())
-
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_sample = model.sample(
@@ -405,7 +409,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
             **model_kwargs,
@@ -467,7 +470,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
             **model_kwargs,
@@ -534,7 +536,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
             **model_kwargs,
@@ -596,7 +597,6 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             constraints=constraints,
             **beam_kwargs,
             **logits_process_kwargs,
@@ -671,7 +671,6 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             output_scores=output_scores,
             return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
             **logits_process_kwargs,
             **model_kwargs,
             **contrastive_search_kwargs,
@@ -1284,13 +1283,8 @@ class GenerationTesterMixin:
         # check `generate()` and `constrained_beam_search()` are equal
 
         # Sample constraints
-        if not input_ids.dtype == torch.float32:
-            min_id = torch.min(input_ids) + 3
-            max_id = torch.max(input_ids)
-        else:
-            # otherwise this throws an error for Speech2TextModel since its inputs are floating points
-            min_id = 3
-            max_id = 100
+        min_id = 3
+        max_id = config.vocab_size
 
         force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
         constraints = [
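The simplified bounds keep every sampled forced token inside the model vocabulary and away from ids 0-2 (commonly reserved for special tokens in test configs), which also sidesteps the old float-input special case for Speech2Text. A sketch of the sampling, with a made-up vocab size standing in for `config.vocab_size` and `PhrasalConstraint` as a representative constraint type:

```python
import torch

from transformers.generation.beam_constraints import PhrasalConstraint

vocab_size = 100  # stands in for config.vocab_size

# torch.randint samples from [3, vocab_size), so every id is a valid token
force_tokens = torch.randint(3, vocab_size, (1, 2)).tolist()[0]
constraints = [PhrasalConstraint(force_tokens)]
```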