Unverified Commit 913d03dc authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix flaky tests (#27543)

parent d903abfc
......@@ -1301,8 +1301,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
# set all nan values to 0.0
scores[scores != scores] = 0.0
# set all inf values to max possible value
# set all +/-inf values to max/min possible value
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
return scores
......
......@@ -692,7 +692,7 @@ class LogitsProcessorTest(unittest.TestCase):
torch.allclose(
scores,
torch.tensor(
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, torch.finfo(scores.dtype).min]],
device=torch_device,
),
atol=1e-6,
......
......@@ -124,9 +124,14 @@ class GenerationTesterMixin:
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
"bad_words_ids": [[1, 0]],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
"remove_invalid_values": True,
}
# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
if forced_bos_token_id is None and forced_eos_token_id is None:
process_kwargs["no_repeat_ngram_size"] = 2
# NOTE: the order of operations here should match `generate` for accurate testing
logits_processor = LogitsProcessorList(
(
[
......@@ -154,12 +159,16 @@ class GenerationTesterMixin:
if forced_eos_token_id is not None
else []
)
+ [
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
]
+ [NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id)]
+ (
[NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"])]
if forced_bos_token_id is None and forced_eos_token_id is None
else []
)
+ [RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"])]
+ [InfNanRemoveLogitsProcessor()] # prevent flaky generation test failures
)
return process_kwargs, logits_processor
@staticmethod
......@@ -282,7 +291,6 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs,
)
......@@ -340,7 +348,6 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs,
......@@ -361,9 +368,6 @@ class GenerationTesterMixin:
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_sample = model.sample(
......@@ -405,7 +409,6 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
......@@ -467,7 +470,6 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_warper_kwargs,
**model_kwargs,
......@@ -534,7 +536,6 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
......@@ -596,7 +597,6 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
......@@ -671,7 +671,6 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
......@@ -1284,13 +1283,8 @@ class GenerationTesterMixin:
# check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
min_id = 3
max_id = config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment