Unverified Commit a9701953 authored by Sanchit Gandhi, committed by GitHub

[whisper] static kv cache (#31166)



* make work with cache abstraction

* correct for static cache

* hacks for compile

* make fast

* fix

* fix pos ids

* generate

* fix sdpa

* fix sdpa cache pos

* fix fa2

* clean fa2

* integrate cache into generate

* make style

* copies

* more copies

* update eager

* update sdpa

* update fa2

* simplify

* use cache pos

* always compute cross-cache for debug

* avoid recompiles
Co-authored-by: Arthur Zucker <arthur@huggingface.co>

* fix fix

* fix fix fix

* more fix

* try encoder-decoder cache (too messy)

* revert encoder-decoder cache

* check cross-attn cache

* use enc-dec dataclass

* use richer enc-dec dataclass

* clean-up

* revert static cache changes

* small fixes

* revert to cpu flag

* fix copies

* add static slow test

* past k/v docstring

* more docstrings

* cache_position docstrings

* add to docs

* add enc-dec cache to docs

* make style

* fix after rebase

* fix beam

* style

* fix generation strategies

* fix most decoder-only tests

* style

* skip test

* more clean up

* small docstrings

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add todo

* only crop self-attn

* check cache in mixin

* style

* fix re-compile after rebase

* move `is_updated` logic to enc-dec wrapper

* revert back

* revert cache back

* finalise design

* fix

* fix fix

* style

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* deprecate

* updates

* final updates

* style

* style

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 57d7594a
@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
     - get_seq_length
     - reset
 
+[[autodoc]] EncoderDecoderCache
+    - get_seq_length
+    - to_legacy_cache
+    - from_legacy_cache
+    - reset
+    - reorder_cache
 
 ## Watermark Utils
......
@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 >>> # Select an audio file and read it:
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 >>> audio_sample = ds[0]["audio"]
->>> waveform = audio_sample["array"]
->>> sampling_rate = audio_sample["sampling_rate"]
 
 >>> # Load the Whisper model in Hugging Face format:
 >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 >>> # Use the model and processor to transcribe the audio:
 >>> input_features = processor(
-...     waveform, sampling_rate=sampling_rate, return_tensors="pt"
+...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
 ... ).input_features
 
 >>> # Generate token ids
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
 ```
+
+Whisper is compatible with the following optimisations:
+- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
+- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning (see the loading sketch below).
+- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compiles the forward pass to dispatch to efficient fused kernels.
+
+As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:
+
+```python
+>>> import torch
+>>> from datasets import load_dataset
+>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+>>> # Select an audio file and read it:
+>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+>>> audio_sample = ds[0]["audio"]
+
+>>> # Load the Whisper model with SDPA attention
+>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
+
+>>> # Enable static cache and compile the forward pass
+>>> model.generation_config.cache_implementation = "static"
+>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+>>> # Use the model and processor to transcribe the audio:
+>>> input_features = processor(
+...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
+... ).input_features
+
+>>> # Compile the forward pass
+>>> _ = model.generate(input_features)
+
+>>> # Generate token ids using the compiled graph (fast!)
+>>> predicted_ids = model.generate(input_features)
+
+>>> # Decode token ids to text
+>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+>>> transcription[0]
+' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+```
+
+For more details on each optimisation, refer to the documentation linked above.
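For Flash Attention 2 on its own, a minimal loading sketch (assuming a CUDA GPU and the `flash-attn` package are installed; the checkpoint is illustrative and half precision is required by the kernel):

```python
>>> import torch
>>> from transformers import WhisperForConditionalGeneration

>>> # Load Whisper with Flash Attention 2 in half precision
>>> model = WhisperForConditionalGeneration.from_pretrained(
...     "openai/whisper-tiny.en",
...     torch_dtype=torch.float16,
...     attn_implementation="flash_attention_2",
... ).to("cuda")
```

The processed `input_features` then need to be moved to the same device and dtype (e.g. `input_features.to("cuda", torch.float16)`) before calling `generate`.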
 ## Resources
 
 A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
......
@@ -1212,6 +1212,7 @@ else:
         "Cache",
         "CacheConfig",
         "DynamicCache",
+        "EncoderDecoderCache",
         "HQQQuantizedCache",
         "QuantizedCache",
         "QuantizedCacheConfig",
@@ -5895,6 +5896,7 @@ if TYPE_CHECKING:
         Cache,
         CacheConfig,
         DynamicCache,
+        EncoderDecoderCache,
         HQQQuantizedCache,
         QuantizedCache,
         QuantizedCacheConfig,
......
@@ -858,6 +858,10 @@ class StaticCache(Cache):
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        if cache_position is None:
+            k_out.copy_(key_states)
+            v_out.copy_(value_states)
+        else:
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
@@ -971,6 +975,158 @@ class SlidingWindowCache(StaticCache):
         # no matter how long the sentence is
         return None
 
+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
+
+
+class EncoderDecoderCache(Cache):
+    """
+    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+    cross-attention caches.
+    """
+
+    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        self.self_attention_cache = self_attention_cache
+        self.cross_attention_cache = cross_attention_cache
+
+        self.is_updated = {}
+        for layer_idx in range(len(cross_attention_cache.key_cache)):
+            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (
+                self.self_attention_cache.key_cache[layer_idx],
+                self.self_attention_cache.value_cache[layer_idx],
+                self.cross_attention_cache.key_cache[layer_idx],
+                self.cross_attention_cache.value_cache[layer_idx],
+            )
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.self_attention_cache)
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        if len(self.cross_attention_cache) > 0:
+            for self_attn, cross_attn in zip(
+                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+            ):
+                legacy_cache += (self_attn + cross_attn,)
+        else:
+            legacy_cache = self.self_attention_cache.to_legacy_cache()
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "EncoderDecoderCache":
+        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx][:2]
+                cache.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(past_key_values[layer_idx]) > 2:
+                    key_states, value_states = past_key_values[layer_idx][2:]
+                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+                    cache.is_updated[layer_idx] = True
+        return cache
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.self_attention_cache.key_cache) <= layer_idx:
+            return 0
+        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+    def reset(self):
+        if hasattr(self.self_attention_cache, "reset"):
+            self.self_attention_cache.reset()
+        if hasattr(self.cross_attention_cache, "reset"):
+            self.cross_attention_cache.reset()
+        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
+            raise ValueError(
+                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
+                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
+                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
+                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+        for layer_idx in self.is_updated:
+            self.is_updated[layer_idx] = False
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        self.self_attention_cache.reorder_cache(beam_idx)
+        self.cross_attention_cache.reorder_cache(beam_idx)
+    def check_dynamic_cache(self, method: str):
+        if not (
+            isinstance(self.self_attention_cache, DynamicCache)
+            and isinstance(self.cross_attention_cache, DynamicCache)
+        ):
+            raise ValueError(
+                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+
+    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+    def crop(self, maximum_length: int):
+        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+        self.check_dynamic_cache(self.crop.__name__)
+        self.self_attention_cache.crop(maximum_length)
+
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`"""
+        self.check_dynamic_cache(self.batch_split.__name__)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+        out = []
+        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+            out.append(EncoderDecoderCache(self_attn, cross_attn))
+        return out
+
+    @classmethod
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        self_attention_cache = DynamicCache()
+        cross_attention_cache = DynamicCache()
+        for idx in range(len(splits[0])):
+            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
+            self_attention_cache.update(layer_keys, layer_values, idx)
+
+            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
+            cross_attention_cache.update(layer_keys, layer_values, idx)
+        return cls(self_attention_cache, cross_attention_cache)
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+        self.self_attention_cache.batch_repeat_interleave(repeats)
+        self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_select_indices.__name__)
+        self.self_attention_cache.batch_select_indices(indices)
+        self.cross_attention_cache.batch_select_indices(indices)
+
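A minimal sketch of the legacy-format round trip implemented above, using random tensors purely for illustration (relying on the top-level `EncoderDecoderCache` export added in this commit):

```python
import torch
from transformers import EncoderDecoderCache

# one layer of a legacy cache: (self-attn key, self-attn value, cross-attn key, cross-attn value)
layer = tuple(torch.randn(1, 2, 5, 4) for _ in range(4))
legacy = (layer,)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(len(cache))           # 1 layer
print(cache.is_updated[0])  # True: cross-attention states were provided for layer 0

roundtrip = cache.to_legacy_cache()
print(torch.equal(roundtrip[0][2], legacy[0][2]))  # True: cross-attention key survives the round trip
```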
 
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
......
@@ -27,6 +27,7 @@ from torch import nn
 from ..cache_utils import (
     Cache,
     DynamicCache,
+    EncoderDecoderCache,
     HQQQuantizedCache,
     HybridCache,
     QuantizedCacheConfig,
@@ -1409,7 +1410,7 @@ class GenerationMixin:
             model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
         return model_kwargs
 
-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
+    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
         new `generate` call requires a larger cache.
@@ -1417,14 +1418,27 @@ class GenerationMixin:
         Returns the resulting cache object.
         """
         cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        requires_cross_attention_cache = (
+            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
+        )
+
+        if hasattr(self, "_cache"):
+            cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
+
         if cache_implementation == "sliding_window":
             max_cache_len = min(self.config.sliding_window, max_cache_len)
+
         need_new_cache = (
             not hasattr(self, "_cache")
-            or (not isinstance(self._cache, cache_cls))
-            or self._cache.max_batch_size != max_batch_size
-            or self._cache.max_cache_len < max_cache_len
+            or (not isinstance(cache_to_check, cache_cls))
+            or cache_to_check.max_batch_size != max_batch_size
+            or cache_to_check.max_cache_len < max_cache_len
         )
+
+        if requires_cross_attention_cache and hasattr(self, "_cache"):
+            need_new_cache = (
+                need_new_cache
+                or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
+            )
 
         if need_new_cache:
@@ -1432,13 +1446,18 @@ class GenerationMixin:
                 cache_dtype = self.config._pre_quantization_dtype
             else:
                 cache_dtype = self.dtype
-            self._cache = cache_cls(
-                config=self.config,
-                max_batch_size=max_batch_size,
-                max_cache_len=max_cache_len,
-                device=self.device,
-                dtype=cache_dtype,
-            )
+            cache_kwargs = {
+                "config": self.config,
+                "max_batch_size": max_batch_size,
+                "max_cache_len": max_cache_len,
+                "device": self.device,
+                "dtype": cache_dtype,
+            }
+            self._cache = cache_cls(**cache_kwargs)
+            if requires_cross_attention_cache:
+                encoder_kwargs = cache_kwargs.copy()
+                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
+                self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
         else:
             self._cache.reset()
         return self._cache
@@ -1745,6 +1764,7 @@ class GenerationMixin:
                     generation_config.cache_implementation,
                     getattr(generation_config, "num_beams", 1) * batch_size,
                     generation_config.max_length,
+                    model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
@@ -1776,11 +1796,22 @@ class GenerationMixin:
         # keeps copying the cache thus using much more memory
         elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
             past = model_kwargs.get("past_key_values", None)
+            requires_cross_attention_cache = (
+                self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
+            )
             if past is None:
-                model_kwargs["past_key_values"] = DynamicCache()
+                model_kwargs["past_key_values"] = (
+                    DynamicCache()
+                    if not requires_cross_attention_cache
+                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
+                )
                 use_dynamic_cache_by_default = True
             elif isinstance(past, tuple):
-                model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
+                model_kwargs["past_key_values"] = (
+                    DynamicCache.from_legacy_cache(past)
+                    if not requires_cross_attention_cache
+                    else EncoderDecoderCache.from_legacy_cache(past)
+                )
                 use_dynamic_cache_by_default = True
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
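For reference, a minimal sketch of what this branch enables on the user side: an encoder-decoder model can now be handed an `EncoderDecoderCache` explicitly, or receive one by default when `past_key_values` is not supplied (checkpoint and dummy audio are illustrative):

```python
import torch
from transformers import AutoProcessor, DynamicCache, EncoderDecoderCache, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# one second of silence at 16 kHz as a dummy input
inputs = processor(torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt")

# self-attention and cross-attention caches held side by side
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
generated_ids = model.generate(inputs.input_features, past_key_values=past_key_values, max_new_tokens=16)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```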
@@ -2064,7 +2095,7 @@ class GenerationMixin:
         # Convert to legacy cache if needed
         if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
             if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
-                if isinstance(result.past_key_values, DynamicCache):
+                if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
                     result.past_key_values = result.past_key_values.to_legacy_cache()
 
         return result
@@ -2234,7 +2265,7 @@ class GenerationMixin:
         # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         if model_kwargs.get("past_key_values") is None or (
-            isinstance(model_kwargs["past_key_values"], Cache)
+            isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
             and model_kwargs["past_key_values"].get_seq_length() == 0
         ):
             # prepare inputs
@@ -2323,7 +2354,9 @@ class GenerationMixin:
             # Replicates the new past_key_values to match the `top_k` candidates
             past = model_kwargs["past_key_values"]
             # If it is a static cache, modify it in-place layer after layer to save memory
-            if isinstance(past, DynamicCache):
+            if isinstance(past, DynamicCache) or (
+                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
+            ):
                 past.batch_repeat_interleave(top_k)
             else:
                 new_key_values = []
@@ -2350,7 +2383,10 @@ class GenerationMixin:
                 output_hidden_states=True,
                 output_attentions=output_attentions,
             )
-            if isinstance(outputs["past_key_values"], DynamicCache):
+            if isinstance(outputs["past_key_values"], DynamicCache) or (
+                isinstance(outputs["past_key_values"], EncoderDecoderCache)
+                and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
+            ):
                 # Remove past K-V from output since we don't need to stack later
                 outputs["past_key_values"] = None
                 # Remove last token from past K-V since we don't want to append it at this point
@@ -2425,7 +2461,10 @@ class GenerationMixin:
             else:
                 _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
                 # Do it in-place layer per layer to save memory
-                if isinstance(next_past_key_values, DynamicCache):
+                if isinstance(next_past_key_values, DynamicCache) or (
+                    isinstance(next_past_key_values, EncoderDecoderCache)
+                    and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
+                ):
                     next_past_key_values.batch_select_indices(augmented_idx)
                 else:
                     new_key_values = []
@@ -2498,7 +2537,10 @@ class GenerationMixin:
         # Contrastive search works by forward looking at the next token, so we need to exclude it from
         # `past_key_values` to be consistent with the other decoding methods
         if model_kwargs.get("past_key_values") is not None:
-            if isinstance(model_kwargs["past_key_values"], DynamicCache):
+            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
+                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
+                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
+            ):
                 model_kwargs["past_key_values"].crop(-1)
             else:
                 past_key_values = []
@@ -2757,7 +2799,7 @@ class GenerationMixin:
         # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
         # cache format is standardized, to avoid adding complexity to the codebase.
         elif "bloom" in model_class or "gptbigcode" in model_class:
-            if not isinstance(past_key_values, DynamicCache):
+            if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
                 raise ValueError(
                     f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
                     "legacy tuple format or `DynamicCache`"
@@ -3703,8 +3745,12 @@ class GenerationMixin:
         # This is needed if return_dict_in_generate is True
         start_from_empty_dynamic_cache = False
-        if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
-            if len(model_kwargs["past_key_values"]) == 0:
+        past_key_values = model_kwargs.get("past_key_values", None)
+        if isinstance(past_key_values, DynamicCache) or (
+            isinstance(past_key_values, EncoderDecoderCache)
+            and isinstance(past_key_values.self_attention_cache, DynamicCache)
+        ):
+            if len(past_key_values) == 0:
                 start_from_empty_dynamic_cache = True
 
         this_peer_finished = False
@@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
     if isinstance(data, torch.Tensor):
         return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
     # New cache format
-    elif isinstance(data, DynamicCache):
+    elif isinstance(data, DynamicCache) or (
+        isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
+    ):
         return data.batch_split(full_batch_size, split_size)
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
@@ -4130,6 +4178,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
         # New cache format
         elif isinstance(data[0], DynamicCache):
             return DynamicCache.from_batch_splits(data)
+        elif isinstance(data[0], EncoderDecoderCache):
+            return EncoderDecoderCache.from_batch_splits(data)
         elif isinstance(data[0], tuple):
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
......
@@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig):
     model_type = "whisper"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+    attribute_map = {
+        "num_key_value_heads": "encoder_attention_heads",
+        "num_attention_heads": "encoder_attention_heads",
+        "hidden_size": "d_model",
+    }
 
     def __init__(
         self,
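The extra `num_key_value_heads` entry lets cache utilities that look this attribute up on the config (such as the static cache setup above) work with Whisper; a quick illustration:

```python
from transformers import WhisperConfig

config = WhisperConfig()
# resolves through the attribute map to `encoder_attention_heads`
print(config.num_key_value_heads == config.encoder_attention_heads)  # True
```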
......
@@ -37,6 +37,13 @@ class DynamicCache(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class EncoderDecoderCache(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class HQQQuantizedCache(metaclass=DummyObject):
     _backends = ["torch"]
......
@@ -57,7 +57,7 @@ if is_torch_available():
         ImageGPTForCausalImageModeling,
         SpeechEncoderDecoderModel,
     )
-    from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
+    from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
     from transformers.generation import (
         BeamSampleDecoderOnlyOutput,
         BeamSampleEncoderDecoderOutput,
@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
         config, input_ids, attention_mask = self._get_input_ids_and_config()
 
         config.use_cache = True
-        config.is_decoder = True
         model = model_class(config).to(torch_device).eval()
         generation_kwargs = {
@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
         set_seed(seed)
         legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
         set_seed(seed)
+        if config.is_encoder_decoder:
+            cache_cls = EncoderDecoderCache
+            past_key_values = cache_cls(DynamicCache(), DynamicCache())
+        else:
+            cache_cls = DynamicCache
+            past_key_values = cache_cls()
         new_results = model.generate(
-            input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
+            input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
         )
 
         # The two sets of generated sequences must match, despite the cache format between forward passes being
         # different
         self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
         self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
-        self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
+        self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
 
         # The contents of the two caches, when converted to the same format (in both directions!), must match
         legacy_cache = legacy_results.past_key_values
@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
         )
 
         new_cache = new_results.past_key_values
-        legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
+        legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
         for layer_idx in range(len(new_cache)):
             for kv_idx in range(len(new_cache[layer_idx])):
                 self.assertTrue(
......
@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_longform_generate_multi_batch_cond_prev(self):
         self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
 
+    def test_custom_4d_attention_mask(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
+        model.eval()
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self._get_custom_4d_mask_test_data()
+
+        with torch.no_grad():
+            logits = model.forward(
+                decoder_input_ids=input_ids,
+                input_features=input_dict["input_features"],
+                decoder_position_ids=position_ids,
+            ).logits
+            # logits.shape == torch.Size([3, 4, ...])
+
+            logits_shared_prefix = model(
+                decoder_input_ids=input_ids_shared_prefix,
+                input_features=input_dict["input_features"],
+                decoder_attention_mask=mask_shared_prefix,
+                decoder_position_ids=position_ids_shared_prefix,
+            )[0]
+            # logits_shared_prefix.shape == torch.Size([1, 6, ...])
+
+        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
+        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
+
+        # comparing greedily-chosen tokens:
+        assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
+
+        # comparing softmax-normalized logits:
+        normalized_0 = torch.nn.functional.softmax(out_last_tokens)
+        normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
+        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+
 
 @require_torch
 @require_torchaudio
@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         torch.manual_seed(0)
         model.generate(**inputs, **gen_kwargs)
 
+    @slow
+    def test_tiny_static_generation(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model.to(torch_device)
+
+        input_speech = self._load_datasamples(4)
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
+        eager_generated_ids = model.generate(input_features, max_new_tokens=64)
+
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+        # compile the forward pass and assert equivalence
+        static_generated_ids = model.generate(input_features, max_new_tokens=64)
+        assert (eager_generated_ids == static_generated_ids).all()
+
+        # check the compiled graph can be re-used and that the cache is correctly reset
+        # reverse the ordering of the input features
+        permutation_idx = (
+            torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
+        )
+        input_features = input_features[permutation_idx, ...]
+        static_generated_ids = model.generate(input_features, max_new_tokens=64)
+        # assert re-ordered generations match those from eager
+        assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
             config=config, input_ids=inputs_dict["input_ids"]
         )
 
+    @unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
+    def test_custom_4d_attention_mask(self):
+        pass
+
     @unittest.skip(reason="Generate needs input ids")
     def test_generate_without_input_ids(self):
         # generate only works with input ids for whisper
......