Unverified Commit ed1924e8 authored by Joao Gante, committed by GitHub

Generate: validate `model_kwargs` (and catch typos in generate arguments) (#18261)

* validate generate model_kwargs

* generate tests -- not all models have an attn mask
parent 2156619f
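
In short, after this change `generate()` inspects the signature of the model's `prepare_inputs_for_generation` (and, when that signature accepts `**kwargs`, the model's `forward`) and raises a `ValueError` for any `model_kwargs` entry the model would silently ignore. A minimal sketch of the new behavior, using the same tiny checkpoint as the test added below:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# Before this commit, the typo below (`do_samples` instead of `do_sample`) was
# silently swallowed into `model_kwargs`; now it raises immediately:
model.generate(input_ids, do_samples=True)
# ValueError: The following `model_kwargs` are not used by the model: ['do_samples'] ...
```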
@@ -841,6 +841,29 @@ class GenerationMixin:
         return transition_scores

+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     @torch.no_grad()
     def generate(
         self,
@@ -1120,6 +1143,9 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
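The check above leans entirely on `inspect.signature`. A self-contained sketch of the same trick, with a hypothetical `ToyModel` standing in for a real model (none of these names are transformers APIs):

```python
import inspect
from typing import Any, Dict

class ToyModel:
    # Permissive signature: `**kwargs` means forward-pass args may also be valid
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}

    def forward(self, input_ids, attention_mask=None):
        ...

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # Bound-method signatures exclude `self`, so only real inputs are collected
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        unused = [k for k, v in model_kwargs.items() if v is not None and k not in model_args]
        if unused:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")

ToyModel()._validate_model_kwargs({"attention_mask": 1})  # ok: in `forward`'s signature
ToyModel()._validate_model_kwargs({"do_samples": True})   # raises ValueError
```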
@@ -75,21 +75,25 @@ class GenerationTesterMixin:
     def _get_input_ids_and_config(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict[self.input_name]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)

         # cut to half length & take max batch_size 3
         max_batch_size = 2
         sequence_length = input_ids.shape[-1] // 2
         input_ids = input_ids[:max_batch_size, :sequence_length]
-        attention_mask = attention_mask[:max_batch_size, :sequence_length]

         # generate max 3 tokens
         max_length = input_ids.shape[-1] + 3
         if config.eos_token_id is not None and config.pad_token_id is None:
             # hack to allow generate for models such as GPT2 as is done in `generate()`
             config.pad_token_id = config.eos_token_id

+        # TransfoXL has no attention mask
+        if "transfoxl" in config.__class__.__name__.lower():
+            attention_mask = None
+        else:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
+
         return config, input_ids, attention_mask, max_length

     @staticmethod
@@ -252,10 +256,9 @@ class GenerationTesterMixin:
         )

         kwargs = {}
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             num_beams=1,
             max_length=max_length,
@@ -265,6 +268,7 @@ class GenerationTesterMixin:
             return_dict_in_generate=return_dict_in_generate,
             remove_invalid_values=True,
             **logits_process_kwargs,
+            **model_kwargs,
         )

         if model.config.is_encoder_decoder:
@@ -278,16 +282,17 @@ class GenerationTesterMixin:
             kwargs["encoder_outputs"] = encoder_outputs

         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_greedy = model.greedy_search(
                 input_ids,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_processor=logits_processor,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 output_scores=output_scores,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
         return output_greedy, output_generate
@@ -308,13 +313,13 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
             do_sample=True,
             num_beams=1,
             max_length=max_length,
             num_return_sequences=num_return_sequences,
-            attention_mask=attention_mask,
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -327,7 +332,7 @@ class GenerationTesterMixin:
         torch.manual_seed(0)
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -336,18 +341,16 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)

         # prevent flaky generation test failures
         logits_processor.append(InfNanRemoveLogitsProcessor())

         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_sample = model.sample(
-                input_ids_clone,
-                attention_mask=attention_mask_clone,
+                input_ids.repeat_interleave(num_return_sequences, dim=0),
                 max_length=max_length,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
@@ -356,6 +359,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
         return output_sample, output_generate
@@ -374,9 +378,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -386,12 +390,13 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )

         # beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -400,23 +405,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)

         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_search = model.beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
         return output_generate, output_beam_search
@@ -437,9 +441,9 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=True,
             max_length=max_length,
             output_scores=output_scores,
@@ -449,6 +453,7 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_warper_kwargs,
+            **model_kwargs,
         )

         # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
         kwargs = {}
@@ -462,7 +467,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-        else:
+        elif attention_mask is not None:
             attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)

         # prevent flaky generation test failures
@@ -471,11 +476,11 @@ class GenerationTesterMixin:
         torch.manual_seed(0)
         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
                 input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
@@ -483,6 +488,7 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )

         return output_generate, output_beam_sample
@@ -502,9 +508,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -514,12 +520,13 @@ class GenerationTesterMixin:
             remove_invalid_values=True,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )

         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -528,23 +535,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)

         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.group_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
         return output_generate, output_group_beam_search
@@ -564,9 +570,9 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
-            attention_mask=attention_mask,
             do_sample=False,
             max_length=max_length,
             output_scores=output_scores,
@@ -577,12 +583,13 @@ class GenerationTesterMixin:
             constraints=constraints,
             **beam_kwargs,
             **logits_process_kwargs,
+            **model_kwargs,
         )

         # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
         kwargs = {}
         if model.config.is_encoder_decoder:
-            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
@@ -591,23 +598,22 @@ class GenerationTesterMixin:
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
-            input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-        else:
-            attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
-            input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
+        elif attention_mask is not None:
+            attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)

         with torch.no_grad():
+            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_group_beam_search = model.constrained_beam_search(
-                input_ids_clone,
+                input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
                 constrained_beam_scorer,
                 max_length=max_length,
-                attention_mask=attention_mask_clone,
                 logits_processor=logits_processor,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 **kwargs,
+                **model_kwargs,
             )
         return output_generate, output_group_beam_search
@@ -1044,12 +1050,7 @@ class GenerationTesterMixin:
         model = model_class(config).to(torch_device)
         model.eval()

-        output_ids_generate = model.generate(
-            do_sample=False,
-            max_length=max_length,
-            remove_invalid_values=True,
-        )
-
+        output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
         self.assertIsNotNone(output_ids_generate)

     def test_group_beam_search_generate(self):
@@ -2052,7 +2053,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         # max_new_tokens and max_length serve the same purpose and must not be used together.
         with self.assertRaises(ValueError):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)

     def test_encoder_decoder_generate_with_inputs_embeds(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2700,19 @@ class GenerationIntegrationTests(unittest.TestCase):
         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaisesRegex(ValueError, "do_samples"):
+            model.generate(input_ids, do_samples=True)
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaisesRegex(ValueError, "foo"):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
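
One detail worth noting: `generate()` passes `model_kwargs.copy()` into the validator because `_validate_model_kwargs` pops entries (e.g. `decoder_input_ids`) that are consumed before any model function runs, so validating a copy keeps the caller's dict intact. Continuing the sketch from the top of this page, valid forward-pass inputs still flow through untouched:

```python
import torch

# `attention_mask` appears in the model's `forward` signature, so it passes the
# new validation and is used exactly as before:
outputs = model.generate(input_ids, attention_mask=torch.ones_like(input_ids))
```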