Unverified Commit 849367cc authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: prepare assisted generation for release (#23052)

parent dfeb5aa6
......@@ -4177,7 +4177,6 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
# Assistant: initialize assistant-related variables
if not hasattr(assistant_model, "max_assistant_tokens"):
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
......@@ -4248,20 +4247,20 @@ class GenerationMixin:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
tmp_inputs = candidate_input_ids[:, -new_token_len:]
tmp_attn = torch.ones_like(candidate_input_ids)
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=tmp_inputs,
decoder_attention_mask=tmp_attn,
decoder_input_ids=assist_inputs,
decoder_attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
tmp_inputs,
attention_mask=tmp_attn,
assist_inputs,
attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
......@@ -4296,16 +4295,17 @@ class GenerationMixin:
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Run a forward pass on the candidate sequence
if "past_key_values" in model_kwargs:
og_model_attn = torch.ones_like(candidate_input_ids)
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
model_attn = torch.ones_like(candidate_input_ids)
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
if self.config.is_encoder_decoder:
outputs = self(
decoder_input_ids=og_model_input_ids,
decoder_attention_mask=og_model_attn,
decoder_input_ids=model_input_ids,
decoder_attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
encoder_outputs=model_kwargs["encoder_outputs"],
output_attentions=output_attentions,
......@@ -4313,8 +4313,8 @@ class GenerationMixin:
)
else:
outputs = self(
og_model_input_ids,
attention_mask=og_model_attn,
model_input_ids,
attention_mask=model_attn,
past_key_values=model_kwargs["past_key_values"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -4343,59 +4343,51 @@ class GenerationMixin:
for i in range(candidate_length):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
# sampling, otherwise use argmax.
# 3. Obtain the next tokens from the original model logits.
if do_sample:
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
next_tokens = sampled_tokens[:, :-1]
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.
# 6. Update variables according to the number of matching assistant tokens.
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
n_matches = min(n_matches, max_len - cur_len - 1)
# 5.1. Ensure we don't generate beyond max_len or an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
n_matches -= 1
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
new_cur_len = input_ids.shape[-1]
n_matches = min(n_matches, max_len - cur_len - 1)
# 5.2. Get the valid continuation, after the matching tokens
valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]
# 6.2. Discard past key values relative to unused assistant tokens
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
# 5.3. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
)
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1
# 6.3. Extract the logits for the next token
next_token_scores = new_logits[:, n_matches, :]
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
# because of this step, assisted generation search reduces to a normal greedy search/sample if there is no
# match.
if do_sample:
probs = probs[:, n_matches, :]
next_tokens = sampled_tokens[:, n_matches]
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if n_matches == int(assistant_model.max_assistant_tokens):
assistant_model.max_assistant_tokens += 2.0
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
# removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
# cache update.
# Assistant: main logic end
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
......@@ -4407,20 +4399,20 @@ class GenerationMixin:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
if "past_key_values" not in model_kwargs:
last_matching_idx = new_cur_len - 1
added_len = new_cur_len
else:
last_matching_idx = n_matches
added_len = n_matches + 1
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
cross_attentions, outputs.cross_attentions, cur_len, added_len
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
cur_len,
last_matching_idx,
added_len,
is_decoder_attention=True,
)
else:
......@@ -4428,28 +4420,19 @@ class GenerationMixin:
decoder_attentions,
outputs.attentions,
cur_len,
last_matching_idx,
added_len,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
decoder_hidden_states, outputs.hidden_states, cur_len, added_len
)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
......@@ -4457,7 +4440,10 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
input_ids[:, -1]
.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
......@@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
return past_key_values
def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
"""
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
where each member corresponds to a single generated token.
......@@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_i
if len(outputs) == 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :cur_len, :last_dim_size],)
outputs += (new_tuple,)
last_matching_idx -= previous_cur_len
previous_cur_len += 1
# The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
cur_len += 1
added_len -= cur_len
for i in range(last_matching_idx + 1):
for i in range(added_len):
new_tuple = ()
for layer in new_outputs:
last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
......
......@@ -1518,8 +1518,8 @@ class GenerationTesterMixin:
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
# token. As such, this test only checks that the output format is correct.
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment