Unverified Commit 23db187d authored by Joao Gante, committed by GitHub

Generate: handle `cache_position` update in `generate` (#29467)

parent 7b87ecb0
......@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
......@@ -57,6 +61,17 @@ class Cache:
return max_length - new_seq_length
return previous_seq_length
@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
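A note on the deprecation shim above: external code that still reads `seen_tokens` keeps working, but is nudged toward the `cache_position` model input. A minimal sketch of the observable behavior, assuming a transformers build at this commit:

```python
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Emits the deprecation warning once, then returns the renamed internal
# counter `_seen_tokens` (0 for a fresh cache).
print(cache.seen_tokens)  # 0
```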
class DynamicCache(Cache):
"""
......@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
......@@ -121,7 +136,7 @@ class DynamicCache(Cache):
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]
# Update the cache
if len(self.key_cache) <= layer_idx:
......@@ -191,7 +206,7 @@ class SinkCache(Cache):
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
@staticmethod
def _rotate_half(x):
......@@ -272,7 +287,7 @@ class SinkCache(Cache):
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]
# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
......@@ -398,16 +413,11 @@ class StaticCache(Cache):
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
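The replacement implementation infers the sequence length from buffer occupancy instead of a stateful Python `int`, which trips up `torch.compile` (see the linked PyTorch issue). The trick in isolation, with toy shapes standing in for `self.key_cache[0, 0]`:

```python
import torch

# Toy stand-in for `self.key_cache[0, 0]`: [max_seq_len, head_dim], zero-initialized.
key_cache_slice = torch.zeros(8, 4)
key_cache_slice[:3] = torch.randn(3, 4)  # pretend three positions were written

# A slot counts as occupied if any value along head_dim is non-zero.
seq_length = key_cache_slice.any(dim=-1).sum()
print(seq_length)  # tensor(3)
```

As the TODO warns, a key vector that is legitimately all zeros would be miscounted; the stateless-integer fix is deferred until pytorch/pytorch#120248 lands.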
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
......
......@@ -633,7 +633,6 @@ class GenerationMixin:
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
......@@ -663,7 +662,8 @@ class GenerationMixin:
dim=-1,
)
model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
return model_kwargs
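With this change `generate` advances the positions itself instead of reading them back from `model_inputs`: after each forward pass only the next slot is kept. The arithmetic, with illustrative values:

```python
import torch

cache_position = torch.arange(5)          # prefill over a 5-token prompt: [0, 1, 2, 3, 4]
cache_position = cache_position[-1:] + 1  # after step 1: tensor([5])
cache_position = cache_position[-1:] + 1  # after step 2: tensor([6])
```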
......@@ -1931,10 +1931,15 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
batch_size = input_ids.shape[0]
while True:
if synced_gpus:
......@@ -1975,7 +1980,6 @@ class GenerationMixin:
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
......@@ -2170,7 +2174,9 @@ class GenerationMixin:
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# if eos_token was found in one sentence, set sentence to finished
......@@ -2389,7 +2395,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
while True:
......@@ -2459,7 +2471,6 @@ class GenerationMixin:
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
)
# if eos_token was found in one sentence, set sentence to finished
......@@ -2688,7 +2699,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
......@@ -2758,7 +2775,9 @@ class GenerationMixin:
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# if eos_token was found in one sentence, set sentence to finished
......@@ -3003,6 +3022,7 @@ class GenerationMixin:
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
......@@ -3156,7 +3176,9 @@ class GenerationMixin:
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
......@@ -3397,6 +3419,7 @@ class GenerationMixin:
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
......@@ -3510,7 +3533,9 @@ class GenerationMixin:
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
......@@ -3747,6 +3772,7 @@ class GenerationMixin:
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
......@@ -3916,7 +3942,9 @@ class GenerationMixin:
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
......@@ -4155,6 +4183,7 @@ class GenerationMixin:
num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
......@@ -4275,7 +4304,9 @@ class GenerationMixin:
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
......@@ -4511,7 +4542,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
# other auxiliary variables
max_len = stopping_criteria[0].max_length
......@@ -4555,6 +4592,14 @@ class GenerationMixin:
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)
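Assisted generation scores `candidate_length` draft tokens in one forward pass, so the positions are extended in a block rather than advanced one at a time. With illustrative values (the single pre-existing position is an assumption for the example):

```python
import torch

cur_len, candidate_length = 10, 3
cache_position = torch.tensor([9])  # next position coming out of the main loop
cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([ 9, 10, 11, 12])
```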
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
......@@ -4673,7 +4718,9 @@ class GenerationMixin:
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# if eos_token was found in one sentence, set sentence to finished
......
......@@ -256,7 +256,7 @@ class GemmaAttention(nn.Module):
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
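The corrected comment matches what the attention layers actually forward: `sin`/`cos` for RoPE-aware caches, and `cache_position` so a static cache knows which preallocated slots to write. A rough, self-contained sketch of that write (StaticCache internals paraphrased, not quoted):

```python
import torch

# Preallocated static buffer: [batch, heads, max_seq_len, head_dim].
k_out = torch.zeros(1, 2, 8, 4)
key_states = torch.randn(1, 2, 3, 4)  # three new tokens
cache_position = torch.arange(5, 8)   # write into slots 5..7

# Scatter the new keys into fixed slots instead of concatenating along seq_len.
k_out[:, :, cache_position] = key_states
```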
......@@ -343,7 +343,7 @@ class GemmaFlashAttention2(GemmaAttention):
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
......@@ -542,7 +542,7 @@ class GemmaSdpaAttention(GemmaAttention):
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
......@@ -791,6 +791,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
......@@ -1128,14 +1132,26 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
# With a static cache, `past_key_values` is None
# TODO joao: standardize the interface for the different Cache classes and remove this `if`
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
......@@ -1168,22 +1184,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if self.generation_config.cache_implementation == "static":
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
......@@ -1193,6 +1193,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
else:
cache_position = cache_position[-input_length:]
if has_static_cache:
past_key_values = None
model_inputs.update(
{
"position_ids": position_ids,
......
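Net effect of the `prepare_inputs_for_generation` rewrite: the first call builds the full position range, later calls just slice the positions `generate` already advanced, and the removed static-cache special case becomes unnecessary. A hypothetical helper condensing the two branches:

```python
import torch

def next_cache_position(cache_position, past_length, input_length):
    """Hypothetical helper condensing the two branches added in the diff."""
    if cache_position is None:
        # first call: positions for the whole (possibly multi-token) input
        return torch.arange(past_length, past_length + input_length)
    # later calls: keep only the positions for the tokens actually fed in
    return cache_position[-input_length:]

print(next_cache_position(None, 0, 5))               # prefill: tensor([0, 1, 2, 3, 4])
print(next_cache_position(torch.tensor([5]), 5, 1))  # decoding: tensor([5])
```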
......@@ -1557,10 +1557,12 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
outputs,
model_kwargs,
is_encoder_decoder,
standardize_cache_format,
)
if "image_attention_mask" in model_kwargs:
......
......@@ -361,7 +361,7 @@ class LlamaAttention(nn.Module):
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
......@@ -451,7 +451,7 @@ class LlamaFlashAttention2(LlamaAttention):
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
......@@ -650,7 +650,7 @@ class LlamaSdpaAttention(LlamaAttention):
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
......@@ -903,6 +903,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
......@@ -1240,14 +1244,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
# With a static cache, `past_key_values` is None
# TODO joao: standardize the interface for the different Cache classes and remove this `if`
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
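The Llama changes mirror Gemma's; the one subtlety worth isolating is the clamp, which keeps `cache_length` honest for bounded caches such as `SinkCache` even though `past_length` keeps growing. With toy numbers:

```python
import torch

past_length = torch.tensor(4096)       # tokens seen so far
max_cache_length = torch.tensor(1024)  # e.g. a SinkCache window

cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
print(cache_length)  # tensor(1024): only `max_cache_length` positions are actually held
```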
......@@ -1280,22 +1296,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if self.generation_config.cache_implementation == "static":
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
......@@ -1305,6 +1305,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
else:
cache_position = cache_position[-input_length:]
if has_static_cache:
past_key_values = None
model_inputs.update(
{
"position_ids": position_ids,
......