Unverified Commit ed1924e8 authored by Joao Gante, committed by GitHub

Generate: validate `model_kwargs` (and catch typos in generate arguments) (#18261)

* validate generate model_kwargs

* generate tests -- not all models have an attn mask
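
With this change, a misspelled or unsupported keyword argument fails fast instead of being silently ignored. A minimal sketch of the new behavior (assuming the stock `gpt2` checkpoint, which is purely illustrative — any generative model behaves the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# `gpt2` is an illustrative checkpoint, not one used in this commit's tests.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# Correctly spelled arguments work as before.
model.generate(input_ids, do_sample=True, max_new_tokens=5)

# Before this commit, the typo `do_samples` was silently dropped and greedy
# decoding ran instead. Now it raises:
#   ValueError: The following `model_kwargs` are not used by the model:
#   ['do_samples'] (note: typos in the generate arguments will also show up
#   in this list)
model.generate(input_ids, do_samples=True)
```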
parent 2156619f
@@ -841,6 +841,29 @@ class GenerationMixin:
 
         return transition_scores
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     @torch.no_grad()
     def generate(
         self,
@@ -1120,6 +1143,9 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
...
@@ -75,21 +75,25 @@ class GenerationTesterMixin:
     def _get_input_ids_and_config(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict[self.input_name]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
 
         # cut to half length & take max batch_size 3
         max_batch_size = 2
         sequence_length = input_ids.shape[-1] // 2
         input_ids = input_ids[:max_batch_size, :sequence_length]
-        attention_mask = attention_mask[:max_batch_size, :sequence_length]
 
         # generate max 3 tokens
         max_length = input_ids.shape[-1] + 3
         if config.eos_token_id is not None and config.pad_token_id is None:
             # hack to allow generate for models such as GPT2 as is done in `generate()`
             config.pad_token_id = config.eos_token_id
 
+        # TransfoXL has no attention mask
+        if "transfoxl" in config.__class__.__name__.lower():
+            attention_mask = None
+        else:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
+
         return config, input_ids, attention_mask, max_length
 
     @staticmethod
@@ -252,10 +256,9 @@ class GenerationTesterMixin:
         )
 
         kwargs = {}
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
 
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             num_beams=1,
             max_length=max_length,
@@ -265,6 +268,7 @@ class GenerationTesterMixin:
             return_dict_in_generate=return_dict_in_generate,
             remove_invalid_values=True,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         if model.config.is_encoder_decoder:
@@ -278,16 +282,17 @@ class GenerationTesterMixin:
             kwargs["encoder_outputs"] = encoder_outputs
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_greedy = model.greedy_search(
                 input_ids,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_processor=logits_processor,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 output_scores=output_scores,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_greedy, output_generate
@@ -308,13 +313,13 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
             do_sample=True,
             num_beams=1,
             max_length=max_length,
             num_return_sequences=num_return_sequences,
-            attention_mask=attention_mask,
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -327,7 +332,7 @@ class GenerationTesterMixin:
         torch.manual_seed(0)
 
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -336,18 +341,16 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
 
         # prevent flaky generation test failures
         logits_processor.append(InfNanRemoveLogitsProcessor())
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_sample = model.sample(
-                input_ids_clone,
-                attention_mask=attention_mask_clone,
+                input_ids.repeat_interleave(num_return_sequences, dim=0),
                 max_length=max_length,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
@@ -356,6 +359,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_sample, output_generate
@@ -374,9 +378,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -386,12 +390,13 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -400,23 +405,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_search = model.beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_beam_search
@@ -437,9 +441,9 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=True,
             max_length=max_length,
             output_scores=output_scores,
@@ -449,6 +453,7 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
+            **model_kwargs,
        )
 
         # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
         kwargs = {}
@@ -462,7 +467,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-        else:
+        elif attention_mask is not None:
             attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
 
         # prevent flaky generation test failures
@@ -471,11 +476,11 @@ class GenerationTesterMixin:
         torch.manual_seed(0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
                 input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
@@ -483,6 +488,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_beam_sample
@@ -502,9 +508,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -514,12 +520,13 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -528,23 +535,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.group_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_group_beam_search
@@ -564,9 +570,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -577,12 +583,13 @@ class GenerationTesterMixin:
             constraints=constraints,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )
 
         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -591,23 +598,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
 
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.constrained_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
                 constrained_beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
 
         return output_generate, output_group_beam_search
@@ -1044,12 +1050,7 @@ class GenerationTesterMixin:
             model = model_class(config).to(torch_device)
             model.eval()
 
-            output_ids_generate = model.generate(
-                do_sample=False,
-                max_length=max_length,
-                remove_invalid_values=True,
-            )
-
+            output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
             self.assertIsNotNone(output_ids_generate)
 
     def test_group_beam_search_generate(self):
@@ -2052,7 +2053,7 @@ class GenerationIntegrationTests(unittest.TestCase):
 
         # max_new_tokens and max_length serve the same purpose and must not be used together.
         with self.assertRaises(ValueError):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
 
     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2700,19 @@ class GenerationIntegrationTests(unittest.TestCase):
         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaisesRegex(ValueError, "do_samples"):
+            model.generate(input_ids, do_samples=True)
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaisesRegex(ValueError, "foo"):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
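
The names the validator accepts come from the model's own signatures rather than a hard-coded list, so its view of any model can be reproduced by hand. A short sketch of that lookup, reusing the tiny T5 checkpoint from the test above:

```python
import inspect

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")

# `_validate_model_kwargs` starts from the `prepare_inputs_for_generation` signature...
model_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)

# ...and, only when that signature accepts `**kwargs`, widens the set with the
# forward pass arguments, since `kwargs` is how optional inputs get through.
if "kwargs" in model_args:
    model_args |= set(inspect.signature(model.forward).parameters)

# Any non-None model kwarg outside this set triggers the new ValueError.
print(sorted(model_args))
```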