"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "8116b04338716b97f8496a8213022e26c21b8f07"
Unverified commit a30c865f authored by Raushan Turganbay, committed by GitHub

Cache: new Cache format in decoder-only models (#31421)



* draft bart with new cache

* add cache for decoder-only models

* revert utils

* modify docstring

* revert bart

* minor fixes

* fix copies (not related)

* revert tests

* remove enc-dec related code

* remove bloom

* remove opt (enc-dec)

* update docstring

* git, codegen, gpt_neo, gpt_neox, gpj

* clean up

* copied from statements

* revert

* tmp

* update warning msg

* forgot git

* add more flags

* run-slow git,codegen,gpt_neo,gpt_neox,gpj

* add cache flag to VLMs

* remove files

* style

* video LLMs also need a flag

* style

* llava will go in another PR

* style

* [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics

* Update src/transformers/models/gpt_neo/modeling_gpt_neo.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copy from

* deprecate until v4.45 and warn if not training

* nit

* fix test

* test static cache

* add more tests and fix models

* fix copies

* return sliding window mask

* run slow tests & fix + codestyle

* one more falcon fix for alibi

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 6af0854e
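For context on what the diff below enables, here is a minimal usage sketch (not part of the commit): a decoder-only model converted by this PR can now take a `Cache` object instead of the legacy tuple-of-tuples and returns the same format it was given. The checkpoint name is illustrative.

```python
# Sketch only: pass a new-style DynamicCache to a decoder-only model and get
# the same cache format back. Any model updated by this PR should behave this way.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "EleutherAI/gpt-neo-125m"  # illustrative; gpt_neo is one of the converted models
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The new cache format", return_tensors="pt")
past_key_values = DynamicCache()  # new Cache class instead of tuple(tuple(tensor))

with torch.no_grad():
    out = model(**inputs, past_key_values=past_key_values, use_cache=True)

print(type(out.past_key_values).__name__)    # DynamicCache: same format in, same format out
print(out.past_key_values.get_seq_length())  # number of tokens cached so far
```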
......@@ -1016,7 +1016,9 @@ class StaticCache(Cache):
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
- config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ config.num_attention_heads
+ if getattr(config, "num_key_value_heads", None) is None
+ else config.num_key_value_heads
)
self.key_cache: List[torch.Tensor] = []
......
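The hunk above makes `StaticCache` usable with configs that have no `num_key_value_heads` attribute (pure multi-head-attention models such as GPT-Neo): it now falls back to `num_attention_heads` instead of raising. A hedged sketch of the effect, with an illustrative checkpoint:

```python
# Sketch: building a StaticCache from a config without `num_key_value_heads`.
import torch
from transformers import AutoConfig
from transformers.cache_utils import StaticCache

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")  # MHA model, no num_key_value_heads
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=128, device="cpu", dtype=torch.float32)

# One pre-allocated (batch, num_heads, max_cache_len, head_dim) buffer per layer
print(cache.key_cache[0].shape)  # torch.Size([1, 12, 128, 64]) for gpt-neo-125m
```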
......@@ -1473,7 +1473,7 @@ class GenerationMixin:
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
# models. May cause troubles with non-text modalities.
- cache_dtype = self.lm_head.weight.dtype
+ cache_dtype = self.get_output_embeddings().weight.dtype
cache_kwargs = {
"config": self.config,
......
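The dtype lookup now goes through `get_output_embeddings()` rather than assuming an `lm_head` attribute, which models like GIT (whose head is called `output`) do not have. A small sketch of the assumption, with an illustrative checkpoint:

```python
# Sketch: get_output_embeddings() resolves the output projection regardless of its
# attribute name, so its weight dtype is a safe proxy for the model dtype without
# calling self.parameters() (which is incompatible with torch.compile).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")  # head is `output`, not `lm_head`
cache_dtype = model.get_output_embeddings().weight.dtype
print(cache_dtype)  # torch.float32 for a default fp32 checkpoint
```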
......@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...file_utils import ModelOutput
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
......@@ -124,13 +125,20 @@ class GitEmbeddings(nn.Module):
class GitSelfAttention(nn.Module):
- def __init__(self, config, position_embedding_type=None):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
......@@ -161,46 +169,31 @@ class GitSelfAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
cutoff = self.image_patch_tokens if pixel_values_present else 0
- if past_key_value is not None:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
- value_layer = torch.cat(
- [value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
- )
- else:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ if past_key_value is not None:
+ # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
+ key_layer_past, value_layer_past = past_key_value.update(
+ key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
+ )
+ key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
+ value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)
query_layer = self.transpose_for_scores(mixed_query_layer)
- use_cache = past_key_value is not None
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
- past_key_value = (
- key_layer[:, :, cutoff:, :],
- value_layer[:, :, cutoff:, :],
- )
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
- if use_cache:
+ if past_key_value is not None:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
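The `past_key_value.update(...)` call above follows the standard `Cache.update` contract: it appends the new key/value states for `self.layer_idx` and returns the accumulated tensors for that layer. A standalone toy illustration (not GIT-specific):

```python
# Toy example of the Cache.update contract used in the attention forward above.
import torch
from transformers import DynamicCache

cache = DynamicCache()
batch, heads, head_dim = 1, 2, 4

# Prefill: 3 prompt tokens for layer 0
k, v = torch.randn(batch, heads, 3, head_dim), torch.randn(batch, heads, 3, head_dim)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape)  # torch.Size([1, 2, 3, 4])

# Decode step: 1 new token; update() returns past + new along the sequence dim
k, v = torch.randn(batch, heads, 1, head_dim), torch.randn(batch, heads, 1, head_dim)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape)              # torch.Size([1, 2, 4, 4])
print(cache.get_seq_length(0))  # 4
```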
......@@ -269,11 +262,10 @@ GIT_SELF_ATTENTION_CLASSES = {
class GitAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT
- def __init__(self, config, position_embedding_type=None):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
super().__init__()
self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
- config, position_embedding_type=position_embedding_type
+ config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
)
self.output = GitSelfOutput(config)
self.pruned_heads = set()
......@@ -302,7 +294,7 @@ class GitAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
......@@ -351,11 +343,11 @@ class GitOutput(nn.Module):
class GitLayer(nn.Module):
- def __init__(self, config):
+ def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
- self.attention = GitAttention(config)
+ self.attention = GitAttention(config, layer_idx=layer_idx)
self.intermediate = GitIntermediate(config)
self.output = GitOutput(config)
......@@ -364,18 +356,17 @@ class GitLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
pixel_values_present: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
- past_key_value=self_attn_past_key_value,
+ past_key_value=past_key_value,
pixel_values_present=pixel_values_present,
)
attention_output = self_attention_outputs[0]
......@@ -401,11 +392,10 @@ class GitLayer(nn.Module):
class GitEncoder(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
def __init__(self, config):
super().__init__()
self.config = config
- self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
+ self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
......@@ -413,7 +403,7 @@ class GitEncoder(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
......@@ -427,16 +417,23 @@ class GitEncoder(nn.Module):
)
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
+ next_decoder_cache = None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
- past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
......@@ -444,7 +441,7 @@ class GitEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
- past_key_value,
+ past_key_values,
output_attentions,
)
else:
......@@ -452,26 +449,30 @@ class GitEncoder(nn.Module):
hidden_states,
attention_mask,
layer_head_mask,
- past_key_value,
+ past_key_values,
output_attentions,
pixel_values_present,
)
hidden_states = layer_outputs[0]
if use_cache:
- next_decoder_cache += (layer_outputs[-1],)
+ next_decoder_cache = layer_outputs[-1]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(
v
for v in [
hidden_states,
- next_decoder_cache,
+ next_cache,
all_hidden_states,
all_self_attentions,
]
......@@ -479,7 +480,7 @@ class GitEncoder(nn.Module):
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=next_decoder_cache,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
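GitEncoder now converts a legacy tuple cache to `DynamicCache` on entry (with a deprecation warning) and converts back with `to_legacy_cache()` before returning, so callers keep receiving the format they passed in. The round trip, sketched in isolation:

```python
# Standalone sketch of the legacy round trip performed by GitEncoder.forward.
import torch
from transformers import DynamicCache

batch, heads, seq, head_dim, n_layers = 1, 2, 5, 4, 3
legacy = tuple(
    (torch.randn(batch, heads, seq, head_dim), torch.randn(batch, heads, seq, head_dim))
    for _ in range(n_layers)
)

cache = DynamicCache.from_legacy_cache(legacy)  # tuple(tuple(tensor)) -> Cache
print(cache.get_seq_length())                   # 5

round_trip = cache.to_legacy_cache()            # Cache -> tuple(tuple(tensor))
print(len(round_trip), round_trip[0][0].shape)  # 3 torch.Size([1, 2, 5, 4])
```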
......@@ -494,6 +495,8 @@ class GitPreTrainedModel(PreTrainedModel):
config_class = GitConfig
base_model_prefix = "git"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -569,6 +572,23 @@ GIT_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
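A hedged sketch of the incremental decoding pattern the new docstring describes: once the prompt has been cached, only the newest token needs to be passed, and the returned cache keeps the format it was given. The checkpoint name is illustrative and `pixel_values` is omitted for brevity:

```python
# Sketch: manual two-step decoding with the new Cache format on GIT (text-only).
import torch
from transformers import AutoTokenizer, GitForCausalLM, DynamicCache

model = GitForCausalLM.from_pretrained("microsoft/git-base")
tokenizer = AutoTokenizer.from_pretrained("microsoft/git-base")
input_ids = tokenizer("a photo of", return_tensors="pt").input_ids

# Prefill: run the full prompt once and keep the cache
out = model(input_ids=input_ids, past_key_values=DynamicCache(), use_cache=True)
next_token = out.logits[:, -1:].argmax(-1)

# Decode step: pass only the new token, the cache supplies the past
out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
print(type(out.past_key_values).__name__)  # DynamicCache, same format that was fed in
```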
......@@ -1136,19 +1156,13 @@ class GitModel(GitPreTrainedModel):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
r"""
- past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
......@@ -1195,7 +1209,13 @@ class GitModel(GitPreTrainedModel):
seq_length = input_shape[1]
# past_key_values_length
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = (
+ past_key_values[0][0].shape[2]
+ if not isinstance(past_key_values, Cache)
+ else past_key_values.get_seq_length()
+ )
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......@@ -1327,7 +1347,7 @@ class GitForCausalLM(GitPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.Tensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
......@@ -1338,12 +1358,6 @@ class GitForCausalLM(GitPreTrainedModel):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
......@@ -1522,7 +1536,16 @@ class GitForCausalLM(GitPreTrainedModel):
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values.get_seq_length()
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+ input_ids = input_ids[:, remove_prefix_length:]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
input_shape = input_ids.shape
......
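The trimming rule added above, taken in isolation: drop every token the cache has already seen, and fall back to keeping only the last token when the input is not longer than the cache. A standalone worked example:

```python
# Worked example of the new input-trimming rule in prepare_inputs_for_generation.
import torch

def trim_input_ids(input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    if input_ids.shape[1] > past_length:
        remove_prefix_length = past_length  # drop what the cache already holds
    else:
        remove_prefix_length = input_ids.shape[1] - 1  # old behavior: keep only the final ID
    return input_ids[:, remove_prefix_length:]

prompt = torch.tensor([[5, 6, 7, 8, 9]])
print(trim_input_ids(prompt, past_length=4))  # tensor([[9]]): one unseen token remains
print(trim_input_ids(prompt, past_length=5))  # tensor([[9]]): fallback keeps the last token
```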
......@@ -59,7 +59,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
- from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
+ from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
......@@ -1769,6 +1769,53 @@ class GenerationTesterMixin:
)
)
def test_generate_with_static_cache(self):
"""
Tests that generation works when `cache_implementation="static"` is set.
This doesn't check generation quality; it only checks that models whose class sets
`_supports_static_cache` don't throw an error when generating and that they return
a StaticCache object at the end.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_length": None,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
max_cache_len = seq_length + max_new_tokens
head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
......
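An end-to-end version of what the new test asserts, sketched outside the test harness (checkpoint name illustrative; any class with `_supports_static_cache` should work):

```python
# Sketch: generate() with cache_implementation="static" should hand back a StaticCache.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Static cache test", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="static",
    return_dict_in_generate=True,  # required so past_key_values is returned
)

print(isinstance(out.past_key_values, StaticCache))  # True
print(out.past_key_values.key_cache[0].shape)        # (batch, kv_heads, max_cache_len, head_dim)
```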
......@@ -4587,6 +4587,44 @@ class ModelTesterMixin:
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_static_cache_matches_dynamic(self):
"""
Tests that generating with a static cache gives almost the same results as with a dynamic cache.
This test does not compile the model; it only checks logits similarity, allowing for numerical
precision errors.
"""
if len(self.all_generative_model_classes) == 0:
self.skipTest(
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
)
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} does not support static cache")
if not model_class._supports_cache_class:
self.skipTest(f"{model_class.__name__} does not support cache class")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32)
model.eval()
dynamic_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
)
static_out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=10,
cache_implementation="static",
output_logits=True,
return_dict_in_generate=True,
)
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
# For now, Let's focus only on GPU for `torch.compile`
@slow
@require_torch_gpu
......