Unverified Commit 78cda46f authored by Joao Gante, committed by GitHub

Generate: Add assisted generation (#22211)

* working mvp

* remove breakpoint

* fix commit

* standardize outputs

* tmp commit

* tests almost ready

* tmp commit

* skip a few models

* Add streaming; Docs and examples

* document limitations

* PR commits

* Amy PR comments
parent 90247d3e
@@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
### Assisted Generation
Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently, only assisted greedy search is
supported, and batched inputs are not supported.
<!-- TODO: add link to the blog post about assisted generation when it exists -->
To enable assisted generation, pass a model through the `assistant_model` argument.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
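Conceptually, the assistant proposes a short block of candidate tokens with cheap greedy decoding, and the main model checks the whole block with a single forward pass, keeping only the prefix it agrees with plus one token of its own. The snippet below is a minimal sketch of that loop, reusing the checkpoints from the example above; the `assisted_greedy_decode` helper and the `num_candidates` value are illustrative only and are not part of the library, which performs this internally (with caching and other optimizations) when you pass `assistant_model` to `generate()`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()
assistant = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).eval()


@torch.no_grad()
def assisted_greedy_decode(prompt, max_new_tokens=20, num_candidates=5):
    # Batch size must be 1 -- assisted generation does not support batched inputs.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    new_tokens = 0
    while new_tokens < max_new_tokens:  # may overshoot by a few tokens; kept simple for illustration
        # 1) The small assistant proposes up to `num_candidates` tokens with plain greedy search.
        candidates = assistant.generate(
            input_ids, max_new_tokens=num_candidates, do_sample=False, num_beams=1
        )
        # 2) A single forward pass of the large model scores the prompt plus all candidate tokens.
        greedy = model(candidates).logits.argmax(dim=-1)
        # 3) Accept candidates only while they match the large model's own greedy choices.
        accepted = 0
        for pos in range(input_ids.shape[1], candidates.shape[1]):
            if candidates[0, pos] != greedy[0, pos - 1]:
                break
            accepted += 1
        # 4) Keep the accepted prefix and append one "free" token from the large model itself.
        keep = input_ids.shape[1] + accepted
        input_ids = torch.cat([candidates[:, :keep], greedy[:, keep - 1 : keep]], dim=-1)
        new_tokens += accepted + 1
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


print(assisted_greedy_decode("Alice and Bob"))
```

Because rejected candidates are replaced by the main model's own greedy pick, the output matches plain greedy search; the speed-up comes from the main model scoring several tokens per forward pass instead of one.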
@@ -79,14 +79,13 @@ class GenerationTesterMixin:
all_generative_model_classes = ()
input_name = "input_ids"
- def _get_input_ids_and_config(self):
+ def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
- max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
- input_ids = input_ids[:max_batch_size, :sequence_length]
+ input_ids = input_ids[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
@@ -99,7 +98,7 @@ class GenerationTesterMixin:
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
else:
- attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
return config, input_ids, attention_mask, max_length
@@ -1458,6 +1457,66 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_greedy_search_matches_greedy_search(self):
# This test ensures that assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# - assisted_greedy_search, unlike the other methods, can't be called on its own (e.g. it needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_greedy_search does not support `use_cache = False`
# - assisted_greedy_search does not support `batch_size > 1`
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
):
return
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
...
@@ -280,7 +280,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
- def _get_input_ids_and_config(self):
+ def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"
@@ -288,10 +288,9 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
# cut to half length & take max batch_size 3
- max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
- input_ids = input_ids[:max_batch_size, :sequence_length]
+ input_ids = input_ids[:batch_size, :sequence_length]
- attention_mask = attention_mask[:max_batch_size, :sequence_length]
+ attention_mask = attention_mask[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
...
@@ -303,7 +303,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
input_ids = input_ids[:max_batch_size, :, :]
# generate max 3 tokens
- max_length = input_ids.shape[-1] + 3
+ max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
...
@@ -359,16 +359,15 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
- def _get_input_ids_and_config(self):
+ def _get_input_ids_and_config(self, batch_size=3):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
- # cut to half length & take max batch_size 3
+ # cut to half length & take max batch_size=batch_size
- max_batch_size = 3
- input_ids = input_ids[:max_batch_size, :, :]
+ input_ids = input_ids[:batch_size, :, :]
# generate max 3 tokens
- max_length = input_ids.shape[-1] + 3
+ max_length = 4
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
...