"vscode:/vscode.git/clone" did not exist on "8a5e8a9c2a4dc54b4485994311de6b839976828c"
Unverified Commit 633215ba authored by Tom Aarsen, committed by GitHub

Generate: New `Cache` abstraction and Attention Sinks support (#26681)

* Draft version of new KV Caching

This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party or in transformers directly

* Address numerous PR suggestions

1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value.
5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
6. Separate key_cache and value_cache.

Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.

* Implement the SinkCache through backward+forward rotations

* Integrate (Sink)Cache with Llama FA2

* Set use_legacy_cache=True as default, allows for test passes

* Move from/to_legacy_cache to ...Model class

* Undo unnecessary newline change

* Remove copy utility from deprecated OpenLlama

* Match import style

* manual rebase with main

* Cache class working with generate (#1)

* Draft version of new KV Caching

This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
/ StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
in a third-party or in transformers directly

* Address numerous PR suggestions

1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
3. Remove __bool__ and __getitem__ magic as they're confusing.
4. past_key_values.update(key, value, idx) now returns key, value.
5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
6. Separate key_cache and value_cache.

Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.

* Integrate (Sink)Cache with Llama FA2

* Move from/to_legacy_cache to ...Model class

* Undo unnecessary newline change

* Match import style

* working generate

* Add tests; Simplify code; Apply changes to Mistral and Persimmon

* fix rebase mess

* a few more manual fixes

* last manual fix

* propagate changes to phi

* upgrade test

* add use_legacy_cache docstring; beef up tests

* reintroduce unwanted deletes

---------
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>

* move import

* add default to model_kwargs.get('use_legacy_cache')

* correct failing test

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply PR suggestions

* fix failing test

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>

* PR comments

* tmp commit

* add docstrings

* more tests, more docstrings, add to docs

* derp

* tmp commit

* tmp dbg

* more dbg

* fix beam search bug

* cache can be a list of tuples in some models

* fix group beam search

* all but sinkcache integration tests

* fix sink cache and add hard integration test

* now also compatible with input_embeds input

* PR comments

* add Cache support to Phi+FA2

* make fixup

---------
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0ea42ef0
......@@ -368,3 +368,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] TextStreamer
[[autodoc]] TextIteratorStreamer
## Caches
[[autodoc]] Cache
- update
[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache
[[autodoc]] SinkCache
- update
- get_seq_length
- reorder_cache
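The caches above plug directly into text generation. A minimal usage sketch, assuming a Llama-family checkpoint with `_supports_cache_class = True` (the model name and generation settings below are illustrative, not part of this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Attention sinks let the model keep generating because", return_tensors="pt")

# Keep the first 4 ("sink") tokens plus a sliding window of the most recent tokens in the KV cache
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))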
......@@ -1303,6 +1303,7 @@ else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
......@@ -5945,6 +5946,7 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
......
from typing import Any, Dict, List, Optional, Tuple
import torch
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.
Return:
A tuple containing the updated key and value states.
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
"""
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
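A short sketch of the legacy-format round trip, using arbitrary shapes and the `torch` import at the top of this file:

# Legacy format: tuple over layers of (key, value) pairs shaped [batch_size, num_heads, seq_len, head_dim]
legacy = tuple((torch.rand(1, 4, 5, 8), torch.rand(1, 4, 5, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
print(len(cache), cache.get_seq_length())  # 2 layers, 5 cached tokens

# Appending one token per layer grows the cache along the sequence dimension
for layer_idx in range(len(cache)):
    cache.update(torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8), layer_idx)
print(cache.get_seq_length())  # 6

# Convert back for code paths that still expect the legacy tuple format
assert cache.to_legacy_cache()[0][0].shape == (1, 4, 6, 8)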
class SinkCache(Cache):
"""
A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
Parameters:
window_length (`int`):
The length of the context window.
num_sink_tokens (`int`):
The number of sink tokens. See the original paper for more information.
"""
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
@staticmethod
def _rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_key_rotary_pos_emb(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
return rotated_key_states
def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_cache:
# Upcast to float32 temporarily for better accuracy
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
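# Derivation note: a cached key originally rotated by angle theta_original must be re-rotated as if it sat at
# the earlier, shifted position theta_shifted. The angle-difference identities
#   cos(theta_shifted - theta_original) = cos(theta_original) * cos(theta_shifted) + sin(theta_original) * sin(theta_shifted)
#   sin(theta_shifted - theta_original) = -sin(theta_original) * cos(theta_shifted) + cos(theta_original) * sin(theta_shifted)
# are exactly what the two lines below compute, giving a single rotation that undoes the original RoPE
# rotation and applies the shifted one.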
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
self.cos_sin_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_cache[key_states.shape[-2]]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
cache_length = self.key_cache[layer_idx].shape[-2]
return min(cache_length, self.window_length - 1)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
rotation as the tokens are shifted.
Return:
A tuple containing the updated key and value states.
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
using_rope = cos is not None and sin is not None
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
else:
# Shifting cache
keys_to_keep = self.key_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
]
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
keys_to_keep[..., :partial_rotation_size],
keys_to_keep[..., partial_rotation_size:],
)
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
if partial_rotation_size is not None:
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
values_to_keep = self.value_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
]
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
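A minimal sketch of the eviction behaviour above, using dummy tensors and arbitrary sizes; passing `cache_kwargs={}` skips the RoPE re-rotation path, which real RoPE models rely on:

cache = SinkCache(window_length=4, num_sink_tokens=1)

# Prefill one layer with 3 tokens: [batch_size=1, num_heads=2, seq_len=3, head_dim=8]
cache.update(torch.rand(1, 2, 3, 8), torch.rand(1, 2, 3, 8), layer_idx=0, cache_kwargs={})
print(cache.get_seq_length(0))  # 3

# Two single-token decoding steps: the physical cache is capped at `window_length`, the sink token
# (position 0) is always kept, and the oldest non-sink tokens are evicted
for _ in range(2):
    cache.update(torch.rand(1, 2, 1, 8), torch.rand(1, 2, 1, 8), layer_idx=0, cache_kwargs={})
print(cache.key_cache[0].shape)  # torch.Size([1, 2, 4, 8])
print(cache.get_seq_length(0))  # 3, i.e. window_length - 1, leaving room for the next incoming token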
......@@ -24,6 +24,7 @@ import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import Cache, DynamicCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
......@@ -1287,6 +1288,13 @@ class GenerationMixin:
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# If a `Cache` instance is passed, checks whether the model is compatible with it
if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
raise ValueError(
f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
"check the model documentation for supported cache formats."
)
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
......@@ -2945,6 +2953,32 @@ class GenerationMixin:
else:
return input_ids
def _temporary_reorder_cache(self, past_key_values, beam_idx):
"""
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
for this function, with `Cache.reorder_cache` being the sole remaining code path
"""
model_class = self.__class__.__name__.lower()
# Exception 1: code path for models using the legacy cache format
if isinstance(past_key_values, (tuple, list)):
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# Standard code path: use the `Cache.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
return past_key_values
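For intuition on the standard code path, a small sketch of `Cache.reorder_cache` on a `DynamicCache` (tensor values chosen so the gather is visible; `torch` and `DynamicCache` are already imported above):

cache = DynamicCache()
cache.update(torch.arange(3.0).view(3, 1, 1, 1), torch.arange(3.0).view(3, 1, 1, 1), layer_idx=0)

# Suppose beam search selected beam 2 twice and beam 0 once: every cached tensor is gathered along the batch axis
cache.reorder_cache(torch.tensor([2, 2, 0]))
print(cache.key_cache[0].flatten().tolist())  # [2.0, 2.0, 0.0]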
def beam_search(
self,
input_ids: torch.LongTensor,
......@@ -3218,7 +3252,9 @@ class GenerationMixin:
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
......@@ -3553,7 +3589,9 @@ class GenerationMixin:
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
......@@ -3938,7 +3976,7 @@ class GenerationMixin:
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
......@@ -4280,7 +4318,9 @@ class GenerationMixin:
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
......
......@@ -1128,6 +1128,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
# Has support for a `Cache` instance as `past_key_values`
_supports_cache_class = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
......
......@@ -860,7 +860,6 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
""",
OPEN_LLAMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->OPEN_LLAMA,Llama->OpenLlama
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
......
......@@ -29,6 +29,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
......@@ -283,9 +284,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -343,7 +352,7 @@ class LlamaAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
......@@ -383,16 +392,19 @@ class LlamaAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
......@@ -460,7 +472,7 @@ class LlamaFlashAttention2(LlamaAttention):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
......@@ -491,18 +503,14 @@ class LlamaFlashAttention2(LlamaAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
......@@ -647,13 +655,13 @@ class LlamaFlashAttention2(LlamaAttention):
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
LlamaAttention(config=config)
LlamaAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config)
else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -749,6 +757,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
......@@ -797,13 +806,19 @@ LLAMA_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
......@@ -844,7 +859,9 @@ class LlamaModel(LlamaPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
......@@ -889,8 +906,11 @@ class LlamaModel(LlamaPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
......@@ -924,21 +944,19 @@ class LlamaModel(LlamaPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
past_key_values,
output_attentions,
use_cache,
)
......@@ -947,7 +965,7 @@ class LlamaModel(LlamaPreTrainedModel):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
......@@ -955,7 +973,7 @@ class LlamaModel(LlamaPreTrainedModel):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
......@@ -966,7 +984,9 @@ class LlamaModel(LlamaPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
......@@ -1047,7 +1067,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1105,16 +1124,28 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
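# Worked example with illustrative numbers: with `SinkCache(window_length=256, num_sink_tokens=4)`, once the
# cache has seen 300 tokens we get `past_length == 300` but `cache_length == 255` (window_length - 1). Branch 2
# above keeps only the single newest input id, and the check above then trims the attention mask to its last
# 255 + 1 == 256 columns, so it matches the bounded cache plus the incoming token.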
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
......
......@@ -30,6 +30,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
......@@ -195,9 +196,17 @@ class MistralAttention(nn.Module):
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: MistralConfig):
def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
......@@ -232,7 +241,7 @@ class MistralAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
......@@ -253,16 +262,19 @@ class MistralAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
......@@ -327,7 +339,7 @@ class MistralFlashAttention2(MistralAttention):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
......@@ -351,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
......@@ -394,10 +406,8 @@ class MistralFlashAttention2(MistralAttention):
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
......@@ -592,13 +602,13 @@ class MistralFlashAttention2(MistralAttention):
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
MistralAttention(config=config)
MistralAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else MistralFlashAttention2(config)
else MistralFlashAttention2(config, layer_idx=layer_idx)
)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -692,6 +702,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
......@@ -740,17 +751,23 @@ MISTRAL_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
......@@ -787,7 +804,9 @@ class MistralModel(MistralPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
......@@ -834,8 +853,11 @@ class MistralModel(MistralPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
......@@ -889,21 +911,19 @@ class MistralModel(MistralPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
past_key_values,
output_attentions,
use_cache,
)
......@@ -912,7 +932,7 @@ class MistralModel(MistralPreTrainedModel):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
......@@ -920,7 +940,7 @@ class MistralModel(MistralPreTrainedModel):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
......@@ -931,7 +951,10 @@ class MistralModel(MistralPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
......@@ -1065,17 +1088,29 @@ class MistralForCausalLM(MistralPreTrainedModel):
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
......
......@@ -27,6 +27,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
......@@ -178,9 +179,17 @@ class PersimmonMLP(nn.Module):
class PersimmonAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PersimmonConfig):
def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
......@@ -257,7 +266,7 @@ class PersimmonAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
......@@ -280,7 +289,13 @@ class PersimmonAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# Partial rotary embedding
......@@ -300,11 +315,9 @@ class PersimmonAttention(nn.Module):
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
......@@ -345,10 +358,10 @@ class PersimmonAttention(nn.Module):
class PersimmonDecoderLayer(nn.Module):
def __init__(self, config: PersimmonConfig):
def __init__(self, config: PersimmonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config)
self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
self.mlp = PersimmonMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
......@@ -444,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
......@@ -492,17 +506,23 @@ PERSIMMON_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
......@@ -539,7 +559,9 @@ class PersimmonModel(PersimmonPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([PersimmonDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
......@@ -586,8 +608,11 @@ class PersimmonModel(PersimmonPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
......@@ -620,21 +645,19 @@ class PersimmonModel(PersimmonPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
past_key_values,
output_attentions,
)
else:
......@@ -642,7 +665,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
......@@ -650,7 +673,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
......@@ -661,7 +684,10 @@ class PersimmonModel(PersimmonPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
......@@ -802,16 +828,28 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
......
@@ -26,6 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -217,9 +218,17 @@ class PhiMLP(nn.Module):
class PhiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PhiConfig):
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -296,7 +305,7 @@ class PhiAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -319,7 +328,13 @@ class PhiAttention(nn.Module):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# Partial rotary embedding
@@ -339,11 +354,9 @@ class PhiAttention(nn.Module):
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
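As a rough illustration of the `past_key_value.update(...)` call above, here is a simplified sketch of a DynamicCache-style update. This is not the actual `cache_utils` implementation: the real class also tracks `seen_tokens` and consumes `cache_kwargs` (e.g. the sin/cos tensors and `partial_rotation_size` used by `SinkCache`); it only shows the append-and-return behaviour the attention layers rely on.

import torch

class ToyDynamicCache:
    """Minimal sketch of a per-layer key/value cache whose `update` returns the full states."""

    def __init__(self):
        self.key_cache = []    # one tensor per layer, shape (batch, num_heads, seq_len, head_dim)
        self.value_cache = []

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        if len(self.key_cache) <= layer_idx:
            # First call for this layer: store the incoming states as-is.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Later calls: append along the sequence dimension and return the concatenated states.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx=0):
        return 0 if len(self.key_cache) <= layer_idx else self.key_cache[layer_idx].shape[-2]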
@@ -404,7 +417,7 @@ class PhiFlashAttention2(PhiAttention):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -431,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# Partial rotary embedding
@@ -451,11 +464,8 @@ class PhiFlashAttention2(PhiAttention):
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
tgt_len = key_states.shape[2]
@@ -603,12 +613,12 @@ class PhiFlashAttention2(PhiAttention):
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
self.self_attn = (
PhiAttention(config=config)
PhiAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else PhiFlashAttention2(config=config)
else PhiFlashAttention2(config=config, layer_idx=layer_idx)
)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -696,6 +706,7 @@ class PhiPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -744,17 +755,23 @@ PHI_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
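A usage sketch of the two accepted cache formats described in the `past_key_values` docstring above. The checkpoint path is a placeholder (not a real model name); the calls to `DynamicCache`, `from_legacy_cache`, and `to_legacy_cache` follow the cache tests further down in this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("path/to/phi-checkpoint")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("path/to/phi-checkpoint")
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")

# Legacy format: when no cache object is passed, a tuple of per-layer (key, value) tuples is returned.
out = model(**inputs, use_cache=True)
legacy_cache = out.past_key_values

# New format: pass a Cache instance; the model (and generate) returns the same format it was fed.
gen_out = model.generate(
    **inputs, max_new_tokens=8, past_key_values=DynamicCache(), return_dict_in_generate=True
)
assert isinstance(gen_out.past_key_values, DynamicCache)

# The two formats are interconvertible.
as_cache = DynamicCache.from_legacy_cache(legacy_cache)
back_to_legacy = as_cache.to_legacy_cache()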
@@ -792,7 +809,9 @@ class PhiModel(PhiPreTrainedModel):
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = nn.ModuleList([PhiDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
@@ -839,8 +858,11 @@ class PhiModel(PhiPreTrainedModel):
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
@@ -877,21 +899,19 @@ class PhiModel(PhiPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
past_key_values,
output_attentions,
)
else:
@@ -899,7 +919,7 @@ class PhiModel(PhiPreTrainedModel):
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
@@ -907,7 +927,7 @@ class PhiModel(PhiPreTrainedModel):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -918,7 +938,9 @@ class PhiModel(PhiPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@@ -1060,16 +1082,28 @@ class PhiForCausalLM(PhiPreTrainedModel):
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids.shape[1], then input_ids holds all input tokens. We can
# discard input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
......
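A toy walk-through of the three token-slicing cases handled in `prepare_inputs_for_generation` above. The lengths are made up and no model is involved; it only mirrors the branching logic shown in the diff.

import torch

past_length = 5                      # tokens the cache has already processed
input_ids = torch.arange(7)[None]    # (1, 7): 7 tokens passed to prepare_inputs_for_generation
attention_mask = torch.ones(1, 9)    # mask longer than input_ids, e.g. a cached inputs_embeds prefix

if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
    # Case 1: keep only the tokens not already covered by the cache -> last 9 - 5 = 4 tokens
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
elif past_length < input_ids.shape[1]:
    # Case 2: input_ids holds the full sequence -> drop the cached prefix
    input_ids = input_ids[:, past_length:]
# Case 3 (otherwise): input_ids is assumed to already contain only unprocessed tokens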
@@ -16,6 +16,27 @@ class PyTorchBenchmarkArguments(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Cache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DynamicCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SinkCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GlueDataset(metaclass=DummyObject):
_backends = ["torch"]
......
@@ -20,8 +20,9 @@ import unittest
import warnings
import numpy as np
from parameterized import parameterized
from transformers import is_torch_available, pipeline
from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
@@ -53,6 +54,7 @@ if is_torch_available():
SpeechEncoderDecoderModel,
top_k_top_p_filtering,
)
from transformers.cache_utils import DynamicCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1904,6 +1906,66 @@ class GenerationTesterMixin:
)
)
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
# 👉 tests with and without beam search so that we can test with and without cache reordering.
# 👉 tests with and without sampling so we can cover the most common use cases.
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest("This model does not support the new cache format")
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"do_sample": do_sample,
"num_beams": num_beams,
"num_return_sequences": num_beams,
"return_dict_in_generate": True, # Required to return `past_key_values`
}
# Sets seed before calling `generate` for the case with do_sample=True
seed = torch.randint(0, 1000000, (1,)).item()
set_seed(seed)
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
set_seed(seed)
new_results = model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
new_cache_converted = new_results.past_key_values.to_legacy_cache()
for layer_idx in range(len(legacy_cache)):
for kv_idx in range(len(legacy_cache[layer_idx])):
self.assertTrue(
torch.allclose(
legacy_cache[layer_idx][kv_idx],
new_cache_converted[layer_idx][kv_idx],
)
)
new_cache = new_results.past_key_values
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(
torch.allclose(
new_cache[layer_idx][kv_idx],
legacy_cache_converted[layer_idx][kv_idx],
)
)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import set_seed
from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
@require_torch
class CacheTest(unittest.TestCase):
def test_cache_equivalence(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = ()
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
new_key = torch.rand((2, 4, 8, 16))
new_value = torch.rand((2, 4, 8, 16))
new_cache.update(new_key, new_value, layer_idx)
legacy_cache += ((new_key, new_value),)
# Sanity check 1: they must have the same shapes
self.assertEqual(len(legacy_cache), len(new_cache))
for layer_idx in range(10):
self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
for key_value_idx in range(2):
self.assertTrue(
legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
)
# Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
# expected value
self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)
# Sanity check 3: they must be equal, and both support indexing
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
)
# Test 1: We can convert from legacy to new with no changes
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
)
# Test 2: We can convert from new to legacy with no changes
to_legacy = new_cache.to_legacy_cache()
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
)
def test_reorder_cache_retrocompatibility(self):
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
new_key = torch.rand((4, 4, 8, 16))
new_value = torch.rand((4, 4, 8, 16))
new_cache.update(new_key, new_value, layer_idx)
legacy_cache += ((new_key, new_value),)
# Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
# and batch_size=1
beam_idx = torch.randint(low=0, high=4, size=(4,))
legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
new_cache.reorder_cache(beam_idx)
# Let's check that the results are the same
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
torch.allclose(
new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
)
)
@require_torch_gpu
@slow
class CacheIntegrationTest(unittest.TestCase):
def test_dynamic_cache_hard(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
# DynamicCache and the legacy cache format should be equivalent
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = (
"Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
"to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
"Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
"what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
"are also very curious. They like to explore, and they like to play. They are also very fast. They can "
"run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
"can solve problems. They are also very playful. They like to play with toys, and they like to play with "
"other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
"also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
"clean their litter box.\nCats are also very independent. They don't"
)
self.assertEqual(decoded[0], expected_text)
def test_dynamic_cache_batched(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to(
model.device
)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
self.assertListEqual(decoded, expected_text)
def test_dynamic_cache_beam_search(self):
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
gen_out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=20,
num_beams=2,
num_return_sequences=2,
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = [
"The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
"The best color is the one that suits you.\nThe best color is the one that suits you. The",
]
self.assertListEqual(decoded, expected_text)
@require_auto_gptq
def test_sink_cache_hard(self):
tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")
inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
# Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
# without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
cache = SinkCache(window_length=508, num_sink_tokens=4)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
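For intuition on the `SinkCache` used above: following the StreamingLLM scheme, once the cache is full it keeps the first `num_sink_tokens` positions plus the most recent ones, up to `window_length` entries in total. The helper below is a toy sketch under that assumption and is not part of transformers.

def retained_positions(seq_len, window_length, num_sink_tokens):
    """Which token positions a sink-style cache keeps once seq_len exceeds window_length."""
    if seq_len <= window_length:
        return list(range(seq_len))
    num_recent = window_length - num_sink_tokens
    return list(range(num_sink_tokens)) + list(range(seq_len - num_recent, seq_len))

# With window_length=508 and num_sink_tokens=4 as in the test above, after ~3000 generated
# tokens the cache holds the 4 sink tokens plus the 504 most recent ones (508 total).
print(len(retained_positions(3000, 508, 4)))  # 508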
@@ -557,10 +557,6 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
if (
model_class.__name__
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
@@ -569,6 +565,8 @@ class ModelTesterMixin:
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
......