"vscode:/vscode.git/clone" did not exist on "bfd81766369a7e8650692350e436d83b1f625dcc"
Unverified commit 4b759da8 authored by Joao Gante, committed by GitHub

Generate: `assisted_decoding` now accepts arbitrary candidate generators (#27750)


Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent e6604247

src/transformers/generation/candidate_generator.py (new file)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from .logits_process import LogitsProcessorList


class CandidateGenerator:
    """Abstract base class for all candidate generators that can be applied during assisted generation."""

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
            the model.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )


class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
    a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        logits_processor: "LogitsProcessorList",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
    ):
        self.assistant_model = assistant_model

        # Prepare the number of candidate tokens
        if hasattr(assistant_model, "num_assistant_tokens"):
            warnings.warn(
                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
                "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
                FutureWarning,
            )
            self.num_assistant_tokens = assistant_model.num_assistant_tokens
        else:
            self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
                assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value)

        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
            self.attention_key = "decoder_attention_mask"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.attention_key = "attention_mask"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"
            self.attention_key = "attention_mask"

        # Prepare other attributes
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        )
        self.logits_processor = logits_processor

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cur_len = input_ids.shape[-1]

            new_cache_size = new_cur_len - 1
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
        # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
        # assistant cache to secure strong speedups.
        candidate_input_ids = input_ids
        for _ in range(int(self.num_assistant_tokens)):
            # 2.1 prepare assistant model inputs
            assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
                candidate_input_ids,
                **self.assistant_kwargs,
            )

            # 2.2. check if the input ids length is correct
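            # (the cache was deliberately cropped one token short of the last match in step 1, so the
            # first forward pass of a round sees 2 uncached tokens and subsequent passes see 1)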
            has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
            if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
                raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")

            # 2.3. use the assistant model to obtain the next candidate logits
            assistant_model_outputs = self.assistant_model(**assistant_inputs)

            # 2.4. greedily select the next candidate token
            if len(self.logits_processor) > 0:
                assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
                    candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
                )
            new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
            candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)

            # 2.5. update assistant model inputs
            if self.assistant_kwargs.get(self.attention_key, None) is not None:
                mask = self.assistant_kwargs[self.attention_key]
                self.assistant_kwargs[self.attention_key] = torch.cat(
                    [mask, mask.new_ones((mask.shape[0], 1))], dim=-1
                )
            self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values

            # 2.6. stop assistant generation on EOS
            if self.eos_token_id_tensor is not None:
                last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
                last_assistant_token_is_eos = (
                    ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
                )
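                # `.prod(dim=0)` is 1 only if the new token differs from every EOS id, so the
                # negation is True exactly when the token matches at least one EOS id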
                if last_assistant_token_is_eos:
                    break

        return candidate_input_ids

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
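        # e.g. starting from the default budget of 5 assistant tokens: a fully-matched round
        # raises the budget to 7, while any partial match steps it down towards the floor of 1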


def _crop_past_key_values(model, past_key_values, maximum_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    if model.config.is_encoder_decoder:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # bloom is special
    elif "bloom" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
    ):
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length],
                    past_key_values[idx][1][:, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    # gptbigcode is too
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
    ):
        if model.config.multi_query:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
        else:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    else:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values


def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""
    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
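    # e.g. a (1, 10) mask with new_length=8 is cropped via mask[:, :-2], while
    # new_length=12 right-pads it with two ones instead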
    return model_kwargs


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        token_type_ids = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs
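
With the file above in place, any object implementing `get_candidates` and `update_candidate_strategy` can drive assisted decoding. Below is a minimal, hypothetical sketch of a custom generator and how it would be passed to `assisted_decoding` directly. The `NgramCopyCandidateGenerator` name and its crude prompt-lookup heuristic are illustrative only, not part of this commit; batch size 1 is assumed, and `assisted_decoding` is called directly because, at this commit, `generate()` still builds an `AssistedCandidateGenerator` internally:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, MaxLengthCriteria, StoppingCriteriaList
from transformers.generation.candidate_generator import CandidateGenerator


class NgramCopyCandidateGenerator(CandidateGenerator):
    """Hypothetical generator: re-proposes the tokens that followed the most recent
    earlier occurrence of the current last token (a crude prompt-lookup scheme)."""

    def __init__(self, num_candidate_tokens: int = 5):
        self.num_candidate_tokens = num_candidate_tokens

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        # find earlier occurrences of the last token and copy what followed them
        matches = (input_ids[0, :-1] == input_ids[0, -1]).nonzero(as_tuple=True)[0]
        if len(matches) == 0:
            continuation = input_ids[:, -1:]  # always propose at least one token
        else:
            start = matches[-1].item() + 1
            continuation = input_ids[:, start : start + self.num_candidate_tokens]
        return torch.cat((input_ids, continuation), dim=-1)

    def update_candidate_strategy(self, input_ids, scores, num_matches):
        pass  # static strategy: nothing to adapt between rounds


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox jumps over the quick brown", return_tensors="pt")

# exactly one of `assistant_model` / `candidate_generator` may be passed
outputs = model.assisted_decoding(
    inputs.input_ids,
    candidate_generator=NgramCopyCandidateGenerator(),
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=30)]),
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))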
src/transformers/generation/utils.py
@@ -37,6 +37,13 @@ from ..models.auto import (
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+from .candidate_generator import (
+    AssistedCandidateGenerator,
+    CandidateGenerator,
+    _crop_past_key_values,
+    _prepare_attention_mask,
+    _prepare_token_type_ids,
+)
from .configuration_utils import GenerationConfig
from .logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
@@ -889,6 +896,28 @@ class GenerationMixin:
                f" enable beam search for {self.__class__}"
            )

+    def _get_candidate_generator(
+        self,
+        generation_config: GenerationConfig,
+        input_ids: torch.LongTensor,
+        inputs_tensor: torch.Tensor,
+        assistant_model: "PreTrainedModel",
+        logits_processor: LogitsProcessorList,
+        model_kwargs: Dict,
+    ) -> CandidateGenerator:
+        """
+        Returns the candidate generator to be used in `assisted_generation`
+        """
+        candidate_generator = AssistedCandidateGenerator(
+            input_ids=input_ids,
+            assistant_model=assistant_model,
+            logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            inputs_tensor=inputs_tensor,
+            eos_token_id=generation_config.eos_token_id,
+        )
+        return candidate_generator
+
    def _get_logits_warper(
        self,
        generation_config: GenerationConfig,
@@ -1671,36 +1700,20 @@ class GenerationMixin:
        if not model_kwargs["use_cache"]:
            raise ValueError("assisted generate requires `use_cache=True`")

-        assistant_accepts_encoder_outputs = "encoder_outputs" in set(
-            inspect.signature(assistant_model.forward).parameters.keys()
-        )
-
-        # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
-        if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
-            assistant_model_kwargs = copy.deepcopy(model_kwargs)
-            inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
-                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
-            )
-            assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
-                inputs_tensor, assistant_model_kwargs, model_input_name
-            )
-            model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
-
-        if (
-            not assistant_model.config.is_encoder_decoder
-            and assistant_accepts_encoder_outputs
-            and "encoder_outputs" in model_kwargs
-        ):
-            # some assistants might be assymetric (many more enc layers than dec layers)
-            # encoder-decoder models that share the exact same encoder as the teacher
-            # in this case the assistant only needs to load the light-weight decoder,
-            # but still requires `encoder_outputs` to be passed
-            model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"]
+        # 11. Get the candidate generator, given the parameterization
+        candidate_generator = self._get_candidate_generator(
+            generation_config=generation_config,
+            input_ids=input_ids,
+            inputs_tensor=inputs_tensor,
+            assistant_model=assistant_model,
+            logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+        )

        # 12. run assisted generate
        return self.assisted_decoding(
            input_ids,
-            assistant_model=assistant_model,
+            candidate_generator=candidate_generator,
            do_sample=generation_config.do_sample,
            logits_processor=logits_processor,
            logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
@@ -4378,7 +4391,8 @@
    def assisted_decoding(
        self,
        input_ids: torch.LongTensor,
-        assistant_model: "PreTrainedModel",
+        assistant_model: Optional["PreTrainedModel"] = None,
+        candidate_generator: Optional["CandidateGenerator"] = None,
        do_sample: bool = False,
        logits_processor: Optional[LogitsProcessorList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
@@ -4395,12 +4409,13 @@
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
-        **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text,
-        speech-to-text, and vision-to-text models.
+        **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
+        candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
+        models.

        <Tip warning={true}>

-        In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
        generate() instead. For an overview of generation strategies and code examples, check the [following
        guide](../generation_strategies).
@@ -4409,6 +4424,9 @@
        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
+            candidate_generator (`CandidateGenerator`, *optional*):
+                A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
+                more information, read the [`CandidateGenerator`] documentation. Only one of `assistant_model` and
+                `candidate_generator` should be passed as input to this function.
            assistant_model (`PreTrainedModel`, *optional*):
                An assistant model that can be used to accelerate generation. The assistant model must have the exact
                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
@@ -4491,15 +4509,23 @@
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
-        # Assistant: initialize assistant-related variables
-        if hasattr(assistant_model, "num_assistant_tokens"):
+        # handling deprecated arguments
+        if (assistant_model is None) == (candidate_generator is None):
+            raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.")
+        if assistant_model is not None:
+            candidate_generator = AssistedCandidateGenerator(
+                input_ids=input_ids,
+                assistant_model=assistant_model,
+                logits_processor=logits_processor,
+                model_kwargs=model_kwargs,
+                eos_token_id=eos_token_id,
+            )
            warnings.warn(
-                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
+                "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
+                "Pass the `candidate_generator` argument instead.",
                FutureWarning,
            )
-            num_assistant_tokens = assistant_model.num_assistant_tokens
-        else:
-            num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -4538,27 +4564,6 @@
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

-        # prepare assistant model's keys of inputs
-        assistant_kwargs = copy.copy(model_kwargs)
-        if assistant_model.config.is_encoder_decoder:
-            # both are encoder-decoder
-            input_ids_key = "decoder_input_ids"
-            attention_key = "decoder_attention_mask"
-            assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
-        elif "assistant_encoder_outputs" in assistant_kwargs:
-            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
-            input_ids_key = "input_ids"
-            attention_key = "attention_mask"
-            assistant_kwargs["attention_mask"] = assistant_kwargs.get(
-                "decoder_attention_mask",
-                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
-            )
-            assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs")
-        else:
-            # both are decoder-only
-            input_ids_key = "input_ids"
-            attention_key = "attention_mask"
-
        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@@ -4577,54 +4582,18 @@
                if this_peer_finished_flag.item() == 0.0:
                    break

-            # Assistant: main logic start
            cur_len = input_ids.shape[-1]

-            # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
-            # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
-            # need access to the assistant cache to secure strong speedups.
-            candidate_input_ids = input_ids
-            for _ in range(int(num_assistant_tokens)):
-                # 1.1 prepare assistant model inputs
-                assistant_inputs = assistant_model.prepare_inputs_for_generation(
-                    candidate_input_ids,
-                    **assistant_kwargs,
-                )
-
-                # 1.2. check if the input ids length is correct
-                has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
-                if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2):
-                    raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
-
-                # 1.3. use the assistant model to obtain the next candidate logits
-                assistant_model_outputs = assistant_model(**assistant_inputs)
-
-                # 1.4. greedily select the next candidate token
-                if len(logits_processor) > 0:
-                    assistant_model_outputs.logits[:, -1, :] = logits_processor(
-                        candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
-                    )
-                new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
-                candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-
-                # 1.5. update assistant model inputs
-                if assistant_kwargs.get(attention_key, None) is not None:
-                    mask = assistant_kwargs[attention_key]
-                    assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
-                assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
-
-                # 1.6. stop assistant generation on EOS
-                if eos_token_id_tensor is not None:
-                    last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
-                    last_assistant_token_is_eos = (
-                        ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
-                    )
-                    if last_assistant_token_is_eos:
-                        break
-                else:
-                    last_assistant_token_is_eos = False
-
-            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+            # 1. Fetch candidate sequences from a `CandidateGenerator`
+            candidate_input_ids = candidate_generator.get_candidates(input_ids)
+            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+            last_assistant_token_is_eos = (
+                ~candidate_input_ids[:, -1]
+                .tile(eos_token_id_tensor.shape[0], 1)
+                .ne(eos_token_id_tensor.unsqueeze(1))
+                .prod(dim=0)
+                .bool()
+            )

            # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
            # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
@@ -4687,20 +4656,10 @@
            # 5.3. Discard past key values relative to unused assistant tokens
            new_cache_size = new_cur_len - 1
            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
-            assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1
-            )  # the assistant does not have the token after the last match, hence the -1

-            # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
-            # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
-            # cost of forecasting incorrect assistant tokens.
-            if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
-                if n_matches == int(num_assistant_tokens):
-                    num_assistant_tokens += 2.0
-                else:
-                    num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
-
-            # Assistant: main logic end
+            # 6. Update the candidate generation strategy if needed
+            candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
@@ -4749,12 +4708,6 @@
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

-            # Update assistant_kwargs for the assistant's next round of generations
-            assistant_kwargs = _prepare_attention_mask(
-                assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder
-            )
-            assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len)
-
            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
@@ -4802,54 +4755,6 @@
            return input_ids


-def _crop_past_key_values(model, past_key_values, maximum_length):
-    """Crops the past key values up to a certain maximum length."""
-    new_past = []
-    if model.config.is_encoder_decoder:
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
-                    past_key_values[idx][2],
-                    past_key_values[idx][3],
-                )
-            )
-        past_key_values = tuple(new_past)
-    # bloom is special
-    elif "bloom" in model.__class__.__name__.lower() or (
-        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
-    ):
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :maximum_length],
-                    past_key_values[idx][1][:, :maximum_length, :],
-                )
-            )
-        past_key_values = tuple(new_past)
-    # gptbigcode is too
-    elif "gptbigcode" in model.__class__.__name__.lower() or (
-        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
-    ):
-        if model.config.multi_query:
-            for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
-        else:
-            for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
-    else:
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
-                )
-            )
-        past_key_values = tuple(new_past)
-    return past_key_values


def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
    """
    Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
@@ -4932,37 +4837,3 @@
    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
    return selected_idx
-
-
-def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
-    """Expands or crops the model's mask for decoding purposes, to the defined length"""
-    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
-    if mask_key not in model_kwargs:
-        return model_kwargs
-
-    mask = model_kwargs[mask_key]
-    mask_length_diff = new_length - mask.shape[1]
-
-    if mask_length_diff < 0:
-        model_kwargs[mask_key] = mask[:, :mask_length_diff]
-    elif mask_length_diff > 0:
-        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
-    return model_kwargs
-
-
-def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
-    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
-    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
-        return model_kwargs
-
-    token_type_ids = model_kwargs["token_type_ids"]
-    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
-    type_length_diff = new_length - token_type_ids.shape[1]
-
-    if type_length_diff < 0:
-        token_type_ids = token_type_ids[:, :type_length_diff]
-    elif type_length_diff > 0:
-        token_type_copies = final_token_type.repeat(1, type_length_diff)
-        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
-    return model_kwargs