Unverified Commit 849367cc authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: prepare assisted generation for release (#23052)

parent dfeb5aa6
...@@ -4177,7 +4177,6 @@ class GenerationMixin: ...@@ -4177,7 +4177,6 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"] ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```""" ```"""
# NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
# Assistant: initialize assistant-related variables # Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"): if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
...@@ -4248,20 +4247,20 @@ class GenerationMixin: ...@@ -4248,20 +4247,20 @@ class GenerationMixin:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len new_token_len = candidate_input_ids.shape[1] - prev_seq_len
tmp_inputs = candidate_input_ids[:, -new_token_len:] assist_inputs = candidate_input_ids[:, -new_token_len:]
tmp_attn = torch.ones_like(candidate_input_ids) assist_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder: if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model( assistant_model_outputs = assistant_model(
decoder_input_ids=tmp_inputs, decoder_input_ids=assist_inputs,
decoder_attention_mask=tmp_attn, decoder_attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"], past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"], encoder_outputs=model_kwargs["assistant_encoder_outputs"],
) )
else: else:
assistant_model_outputs = assistant_model( assistant_model_outputs = assistant_model(
tmp_inputs, assist_inputs,
attention_mask=tmp_attn, attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"], past_key_values=model_kwargs["assistant_past_key_values"],
) )
else: else:
...@@ -4296,16 +4295,17 @@ class GenerationMixin: ...@@ -4296,16 +4295,17 @@ class GenerationMixin:
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1) # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Run a forward pass on the candidate sequence # 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs: if "past_key_values" in model_kwargs:
og_model_attn = torch.ones_like(candidate_input_ids) model_attn = torch.ones_like(candidate_input_ids)
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
outputs = self( outputs = self(
decoder_input_ids=og_model_input_ids, decoder_input_ids=model_input_ids,
decoder_attention_mask=og_model_attn, decoder_attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"], past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"], encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions, output_attentions=output_attentions,
...@@ -4313,8 +4313,8 @@ class GenerationMixin: ...@@ -4313,8 +4313,8 @@ class GenerationMixin:
) )
else: else:
outputs = self( outputs = self(
og_model_input_ids, model_input_ids,
attention_mask=og_model_attn, attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"], past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -4343,59 +4343,51 @@ class GenerationMixin: ...@@ -4343,59 +4343,51 @@ class GenerationMixin:
for i in range(candidate_length): for i in range(candidate_length):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial # 3. Obtain the next tokens from the original model logits.
# sampling, otherwise use argmax.
if do_sample: if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
next_tokens = sampled_tokens[:, :-1]
else: else:
next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1) selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached. # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:] candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum() n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# cost of forecasting incorrect assistant tokens. # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
if n_matches == int(assistant_model.max_assistant_tokens): # is no match.
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# 6. Update variables according to the number of matching assistant tokens. # 5.1. Ensure we don't generate beyond max_len or an EOS token
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
n_matches = min(n_matches, max_len - cur_len - 1)
if last_assistant_token_is_eos and n_matches == candidate_length: if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1 n_matches -= 1
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] n_matches = min(n_matches, max_len - cur_len - 1)
new_cur_len = input_ids.shape[-1]
# 5.2. Get the valid continuation, after the matching tokens
valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None: if streamer is not None:
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches]) streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]
# 6.2. Discard past key values relative to unused assistant tokens # 5.3. Discard past key values relative to unused assistant tokens
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values( model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
) ) # the assistant does not have the token after the last match, hence the -1
# 6.3. Extract the logits for the next token # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
next_token_scores = new_logits[:, n_matches, :] # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, if n_matches == int(assistant_model.max_assistant_tokens):
# because of this step, assisted generation search reduces to a normal greedy search/sample if there is no assistant_model.max_assistant_tokens += 2.0
# match.
if do_sample:
probs = probs[:, n_matches, :]
next_tokens = sampled_tokens[:, n_matches]
else: else:
next_tokens = torch.argmax(next_token_scores, dim=-1) assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were # Assistant: main logic end
# removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
# cache update.
if synced_gpus and this_peer_finished: if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need continue # don't waste resources running the code we don't need
...@@ -4407,20 +4399,20 @@ class GenerationMixin: ...@@ -4407,20 +4399,20 @@ class GenerationMixin:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
if "past_key_values" not in model_kwargs: if "past_key_values" not in model_kwargs:
last_matching_idx = new_cur_len - 1 added_len = new_cur_len
else: else:
last_matching_idx = n_matches added_len = n_matches + 1
if output_attentions: if output_attentions:
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs( cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx cross_attentions, outputs.cross_attentions, cur_len, added_len
) )
decoder_attentions = _split_model_outputs( decoder_attentions = _split_model_outputs(
decoder_attentions, decoder_attentions,
outputs.decoder_attentions, outputs.decoder_attentions,
cur_len, cur_len,
last_matching_idx, added_len,
is_decoder_attention=True, is_decoder_attention=True,
) )
else: else:
...@@ -4428,28 +4420,19 @@ class GenerationMixin: ...@@ -4428,28 +4420,19 @@ class GenerationMixin:
decoder_attentions, decoder_attentions,
outputs.attentions, outputs.attentions,
cur_len, cur_len,
last_matching_idx, added_len,
is_decoder_attention=True, is_decoder_attention=True,
) )
if output_hidden_states: if output_hidden_states:
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs( decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
) )
else: else:
decoder_hidden_states = _split_model_outputs( decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx decoder_hidden_states, outputs.hidden_states, cur_len, added_len
) )
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation( model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
) )
...@@ -4457,7 +4440,10 @@ class GenerationMixin: ...@@ -4457,7 +4440,10 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished # if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None: if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul( unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) input_ids[:, -1]
.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
) )
# stop when each sentence is finished # stop when each sentence is finished
...@@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ...@@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
return past_key_values return past_key_values
def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False): def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
""" """
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token. where each member corresponds to a single generated token.
...@@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_i ...@@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_i
if len(outputs) == 0: if len(outputs) == 0:
new_tuple = () new_tuple = ()
for layer in new_outputs: for layer in new_outputs:
last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1] last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :previous_cur_len, :last_dim_size],) new_tuple += (layer[..., :cur_len, :last_dim_size],)
outputs += (new_tuple,) outputs += (new_tuple,)
last_matching_idx -= previous_cur_len # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
previous_cur_len += 1 cur_len += 1
added_len -= cur_len
for i in range(last_matching_idx + 1): for i in range(added_len):
new_tuple = () new_tuple = ()
for layer in new_outputs: for layer in new_outputs:
last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1] last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],) new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,) outputs += (new_tuple,)
return outputs return outputs
......
...@@ -1518,8 +1518,8 @@ class GenerationTesterMixin: ...@@ -1518,8 +1518,8 @@ class GenerationTesterMixin:
self._check_outputs(output, input_ids, model.config, use_cache=True) self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_decoding_sample(self): def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# token. As such, this test only checks that the output format is correct. # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format). # won't fix: FSMT and Reformer have a different cache variable type (and format).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment