Unverified Commit e4a97f82 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: assisted generation with sample (take 2) (#22949)

* temperature controls speed
parent 7701716e
......@@ -333,15 +333,16 @@ This guide illustrates the main parameters that enable various decoding strategi
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
### Assisted Generation
### Assisted Decoding
Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
supported, and doesn't support batched inputs.
Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
and sampling are supported with assisted decoding, and doesn't support batched inputs.
<!-- TODO: add link to the blog post about assisted generation when it exists -->
<!-- TODO: add link to the blog post about assisted decoding when it exists -->
To enable assisted generation, set the `assistant_model` argument with a model.
To enable assisted decoding, set the `assistant_model` argument with a model.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
......@@ -359,3 +360,25 @@ To enable assisted generation, set the `assistant_model` argument with a model.
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
When using assisted decoding with sampling methods, you can use the `temperarure` argument to control the randomness
just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
<!-- TODO: link the blog post again to explain why the tradeoff exists -->
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"]
```
......@@ -54,8 +54,10 @@ class GenerationConfig(PushToHubMixin):
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
`assistant_model` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
Arg:
......
......@@ -492,7 +492,7 @@ class GenerationMixin:
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
)
def _prepare_model_inputs(
......@@ -962,10 +962,10 @@ class GenerationMixin:
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
f" them as arguments to `.generate()` instead of using a custom {object_type}."
)
default_list.extend(custom_list)
return default_list
......@@ -1418,14 +1418,14 @@ class GenerationMixin:
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_assisted_greedy_gen_mode = False
is_assisted_gen_mode = False
if assistant_model is not None:
if not is_greedy_gen_mode:
if not (is_greedy_gen_mode or is_sample_gen_mode):
raise ValueError(
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
"is only supported with Greedy Search."
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
is_assisted_greedy_gen_mode = True
is_assisted_gen_mode = True
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
......@@ -1464,16 +1464,16 @@ class GenerationMixin:
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 10. go into different generation modes
if is_assisted_greedy_gen_mode:
if is_assisted_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted greedy search, "
"num_return_sequences has to be 1 when doing assisted generate, "
f"but is {generation_config.num_return_sequences}."
)
if batch_size > 1:
raise ValueError("Assisted generation is only supported for batch_size = 1")
raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("Assisted generation requires `use_cache=True`")
raise ValueError("assisted generate requires `use_cache=True`")
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
if assistant_model.config.is_encoder_decoder:
......@@ -1486,11 +1486,13 @@ class GenerationMixin:
)
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
# 12. run assisted greedy search
return self.assisted_greedy_search(
# 12. run assisted generate
return self.assisted_decoding(
input_ids,
assistant_model=assistant_model,
do_sample=generation_config.do_sample,
logits_processor=logits_processor,
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
......@@ -4059,11 +4061,13 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def assisted_greedy_search(
def assisted_decoding(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
do_sample: bool = False,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
......@@ -4076,12 +4080,13 @@ class GenerationMixin:
**model_kwargs,
):
r"""
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
**sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text,
speech-to-text, and vision-to-text models.
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
......@@ -4095,9 +4100,15 @@ class GenerationMixin:
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling ; use greedy decoding otherwise.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
......@@ -4157,7 +4168,7 @@ class GenerationMixin:
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.assisted_greedy_search(
>>> outputs = model.assisted_decoding(
... input_ids,
... assistant_model=assistant_model,
... logits_processor=logits_processor,
......@@ -4166,13 +4177,14 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
# NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
......@@ -4285,6 +4297,8 @@ class GenerationMixin:
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
og_model_attn = torch.ones_like(candidate_input_ids)
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
......@@ -4320,17 +4334,28 @@ class GenerationMixin:
output_hidden_states=output_hidden_states,
)
# 3. Obtain the argmax from the original model logits.
# 2.2. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
if len(logits_warper) > 0:
for i in range(candidate_length):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
# sampling, otherwise use argmax.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
next_tokens = sampled_tokens[:, :-1]
else:
next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
......@@ -4360,12 +4385,17 @@ class GenerationMixin:
next_token_scores = new_logits[:, n_matches, :]
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
# because of this step, assisted generation search reduces to a normal greedy search/sample if there is no
# match.
if do_sample:
probs = probs[:, n_matches, :]
next_tokens = sampled_tokens[:, n_matches]
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
# update.
# Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
# removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
# cache update.
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
......@@ -4378,20 +4408,18 @@ class GenerationMixin:
if "past_key_values" not in model_kwargs:
last_matching_idx = new_cur_len - 1
prompt_length = cur_len
else:
last_matching_idx = n_matches
prompt_length = 0
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
prompt_length,
cur_len,
last_matching_idx,
is_decoder_attention=True,
)
......@@ -4399,18 +4427,18 @@ class GenerationMixin:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
prompt_length,
cur_len,
last_matching_idx,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
)
# finished sentences should have their next token be a padding token
......@@ -4503,24 +4531,26 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
return past_key_values
def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token.
"""
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
# prompt.
if prompt_length > 0:
if len(outputs) == 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
outputs += (new_tuple,)
last_matching_idx -= previous_cur_len
previous_cur_len += 1
for i in range(prompt_length, last_matching_idx + 1):
for i in range(last_matching_idx + 1):
new_tuple = ()
for layer in new_outputs:
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
......
......@@ -1457,22 +1457,22 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_greedy_search_matches_greedy_search(self):
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_greedy_search does not support `use_cache = False`
# - assisted_greedy_search does not support `batch_size > 1`
# - assisted_decoding does not support `use_cache = False`
# - assisted_decoding does not support `batch_size > 1`
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
):
return
......@@ -1517,6 +1517,46 @@ class GenerationTesterMixin:
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
# token. As such, this test only checks that the output format is correct.
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
return
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
):
return
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=True,
assistant_model=model, # triggers assisted decoding
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment