Unverified commit ac974199 authored by Joao Gante, committed by GitHub

Generate: speculative decoding (#27979)



* speculative decoding

* fix test

* space

* better comments

* remove redundant test

* test nit

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* PR comments

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent bd7a3561
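For context, this feature is reached through the existing assisted generation entry point: passing `assistant_model` to `generate` together with `do_sample=True` now routes token selection through the speculative sampling path added in this diff. A minimal usage sketch (the OPT checkpoint pair is only an illustration; any main/assistant pair sharing a tokenizer works):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoints: a larger main model and a small assistant with the same tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    inputs = tokenizer("The quick brown fox", return_tensors="pt")

    # With an assistant model and do_sample=True, generation uses the speculative sampling
    # acceptance rule (Case 1 in the diff below); without sampling it keeps the exact-match
    # check of assisted generation (Case 2).
    outputs = model.generate(
        **inputs,
        assistant_model=assistant,
        do_sample=True,
        temperature=0.7,
        max_new_tokens=40,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])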
@@ -4624,40 +4624,57 @@ class GenerationMixin:
             for i in range(candidate_length + 1):
                 new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

-            # 3. Obtain the next tokens from the original model logits.
-            if do_sample:
-                probs = new_logits.softmax(dim=-1)
-                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+            # 3. Select the accepted tokens. There are two possible cases:
+            # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
+            # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
+            max_matches = max_len - cur_len - 1
+            if do_sample and candidate_logits is not None:
+                next_sampled_tokens, n_matches = _speculative_sampling(
+                    candidate_input_ids,
+                    candidate_logits,
+                    candidate_length,
+                    new_logits,
+                    last_assistant_token_is_eos,
+                    max_matches,
+                )
+                # The selected tokens include the matches plus the next sampled tokens
+                selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)
+
+            # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
+            # original model logits with the candidate tokens. We can keep the candidate tokens until the first
+            # mismatch, or until the max length is reached.
             else:
-                selected_tokens = new_logits.argmax(dim=-1)
-
-            # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
-            # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
-
-            # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
+                if do_sample:
+                    probs = new_logits.softmax(dim=-1)
+                    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+                else:
+                    selected_tokens = new_logits.argmax(dim=-1)
+
+                candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+                n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+
+                # Ensure we don't generate beyond max_len or an EOS token
+                if last_assistant_token_is_eos and n_matches == candidate_length:
+                    n_matches -= 1
+                n_matches = min(n_matches, max_matches)
+
+            # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
             # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
             # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
             # is no match.
-            # 5.1. Ensure we don't generate beyond max_len or an EOS token
-            if last_assistant_token_is_eos and n_matches == candidate_length:
-                n_matches -= 1
-            n_matches = min(n_matches, max_len - cur_len - 1)
-
-            # 5.2. Get the valid continuation, after the matching tokens
+            # 4.1. Get the valid continuation, after the matching tokens
             valid_tokens = selected_tokens[:, : n_matches + 1]
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
             new_cur_len = input_ids.shape[-1]

-            # 5.3. Discard past key values relative to unused assistant tokens
+            # 4.2. Discard past key values relative to unused assistant tokens
             new_cache_size = new_cur_len - 1
             outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

-            # 6. Update the candidate generation strategy if needed
+            # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

             if synced_gpus and this_peer_finished:
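The `n_matches` expression in Case 2 counts the length of the matching prefix: the cumulative sum of mismatch flags stays below 1 only up to the first disagreement between the candidate tokens and the model's own selections. A small self-contained illustration with toy tensors (not part of the diff):

    import torch

    candidate_new_tokens = torch.tensor([[5, 8, 2, 9]])  # tokens proposed by the assistant
    selected_tokens = torch.tensor([[5, 8, 7, 1, 3]])    # model's picks, including one bonus token

    # Mismatch flags per position; the cumulative sum stays at 0 only before the first mismatch.
    mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])
    n_matches = (mismatch.cumsum(dim=-1) < 1).sum()
    print(n_matches)  # tensor(2): the first two candidates are kept

    # The accepted continuation also includes the model's token right after the last match.
    valid_tokens = selected_tokens[:, : n_matches + 1]
    print(valid_tokens)  # tensor([[5, 8, 7]])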
@@ -4755,6 +4772,61 @@ class GenerationMixin:
         return input_ids


+def _speculative_sampling(
+    candidate_input_ids,
+    candidate_logits,
+    candidate_length,
+    new_logits,
+    last_assistant_token_is_eos,
+    max_matches,
+):
+    """
+    Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
+    the next selected token, as well as the number of candidate matches.
+
+    NOTE: Unless otherwise stated, the variable names match those in the paper.
+    """
+    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
+    # selected by the assistant, respectively.
+    q = candidate_logits.softmax(dim=-1)
+    q_i = q[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    p = new_logits.softmax(dim=-1)
+    p_i = p[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    probability_ratio = p_i / q_i
+
+    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
+    # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
+    # (= keep with p = probability_ratio). Keep all the tokens until the first rejection.
+    r_i = torch.rand_like(probability_ratio)
+    is_accepted = r_i <= probability_ratio
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+
+    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
+    if last_assistant_token_is_eos and n_matches == candidate_length:
+        n_matches -= 1
+    n_matches = min(n_matches, max_matches)
+
+    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+    gamma = candidate_logits.shape[1]
+    p_n_plus_1 = p[:, n_matches, :]
+    if n_matches < gamma:
+        q_n_plus_1 = q[:, n_matches, :]
+        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
+    else:
+        p_prime = p_n_plus_1
+    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+    return t, n_matches
+
+
 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
     """
     Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
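For reference, the acceptance rule implemented by `_speculative_sampling` above can be written as a short standalone sketch over plain probability tensors, following algorithm 1 of the paper: each candidate token is accepted with probability min(1, p/q), and on the first rejection that position is resampled from the residual distribution max(0, p - q) normalized by its sum. Names and shapes below are illustrative, not taken from the library:

    import torch

    def toy_speculative_step(q_probs, p_probs, candidate_tokens):
        """Toy, single-sequence version of the acceptance loop.

        q_probs: [candidate_length, vocab_size] assistant probabilities.
        p_probs: [candidate_length + 1, vocab_size] main-model probabilities, with one
            extra row for the position right after the last candidate.
        candidate_tokens: [candidate_length] tokens drawn from the assistant.
        Returns the accepted candidate prefix plus one token sampled from the main model.
        """
        for i, token in enumerate(candidate_tokens.tolist()):
            ratio = p_probs[i, token] / q_probs[i, token]
            if torch.rand(()) <= ratio:  # accept with probability min(1, p/q)
                continue
            # First rejection: resample this position from norm(max(0, p - q)).
            residual = torch.clamp(p_probs[i] - q_probs[i], min=0)
            residual = residual / residual.sum()
            return candidate_tokens[:i].tolist() + [torch.multinomial(residual, num_samples=1).item()]
        # Every candidate accepted: sample the bonus token from the next-position distribution.
        return candidate_tokens.tolist() + [torch.multinomial(p_probs[-1], num_samples=1).item()]

    # Example: 3 candidate positions over a 5-token vocabulary.
    q = torch.softmax(torch.randn(3, 5), dim=-1)
    p = torch.softmax(torch.randn(4, 5), dim=-1)
    print(toy_speculative_step(q, p, torch.tensor([0, 3, 1])))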