Unverified Commit a9701953 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[whisper] static kv cache (#31166)



* make work with cache abstraction

* correct for static cache

* hacks for compile

* make fast

* fix

* fix pos ids

* generate

* fix sdpa

* fix sdpa cache pos

* fix fa2

* clean fa2

* integrate cache into generate

* make style

* copies

* more copies

* update eager

* update sdpa

* update fa2

* simplify

* use cache pos

* always compute cross-cache for debug

* avoid recompiles
Co-authored-by: default avatarArthur Zucker <arthur@huggingface.co>

* fix fix

* fix fix fix

* more fix

* try encoder-decoder cache (too messy)

* revert encoder-decoder cache

* check cross-attn cache

* use enc-dec dataclass

* use richer enc-dec dataclass

* clean-up

* revert static cache changes

* small fixes

* revert to cpu flag

* fix copies

* add static slow test

* past k/v docstring

* more docstrings

* cache_position docstrings

* add to docs

* add enc-dec cache to docs

* make style

* fix after rebase

* fix beam

* style

* fix generation strategies

* fix most decoder-only tests

* style

* skip test

* more clean up

* small docstrings

* Apply suggestions from code review
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* add todo

* only crop self-attn

* check cache in mixin

* style

* fix re-compile after rebase

* move `is_updated` logic to enc-dec wrapper

* revert back

* revert cache back

* finalise design

* fix

* fix fix

* style

* Update src/transformers/cache_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* deprecate

* updates

* final updates

* style

* style

---------
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 57d7594a
......@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset
[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
- from_legacy_cache
- reset
- reorder_cache
## Watermark Utils
......
......@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> waveform = audio_sample["array"]
>>> sampling_rate = audio_sample["sampling_rate"]
>>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
......@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features
>>> # Generate token ids
......@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
Whisper is compatible with the following optimisations:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
As an example, the following codesnippet enables SDPA and `torch.compile` for up to 5x faster inference:
```python
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> # Load the Whisper model with SDPA attention
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
>>> # Enable static cache and compile the forward pass
>>> model.generation_config.cache_implementation = "static"
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features
>>> # Compile the forward pass
>>> _ = model.generate(input_features)
>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
For more details on each optimisation, refer to the documentation linked above.
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
......
......@@ -1212,6 +1212,7 @@ else:
"Cache",
"CacheConfig",
"DynamicCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"QuantizedCache",
"QuantizedCacheConfig",
......@@ -5895,6 +5896,7 @@ if TYPE_CHECKING:
Cache,
CacheConfig,
DynamicCache,
EncoderDecoderCache,
HQQQuantizedCache,
QuantizedCache,
QuantizedCacheConfig,
......
......@@ -858,8 +858,12 @@ class StaticCache(Cache):
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
......@@ -971,6 +975,158 @@ class SlidingWindowCache(StaticCache):
# no matter how long the sentence is
return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
class EncoderDecoderCache(Cache):
"""
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
cross-attention caches.
"""
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache
self.is_updated = {}
for layer_idx in range(len(cross_attention_cache.key_cache)):
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (
self.self_attention_cache.key_cache[layer_idx],
self.self_attention_cache.value_cache[layer_idx],
self.cross_attention_cache.key_cache[layer_idx],
self.cross_attention_cache.key_cache[layer_idx],
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.self_attention_cache)
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
legacy_cache = ()
if len(self.cross_attention_cache) > 0:
for self_attn, cross_attn in zip(
self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
):
legacy_cache += (self_attn + cross_attn,)
else:
legacy_cache = self.self_attention_cache.to_legacy_cache()
return legacy_cache
@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
cache.self_attention_cache.update(key_states, value_states, layer_idx)
if len(past_key_values[layer_idx]) > 2:
key_states, value_states = past_key_values[layer_idx][2:]
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
cache.is_updated[layer_idx] = True
return cache
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.self_attention_cache.key_cache) <= layer_idx:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
if hasattr(self.self_attention_cache, "reset"):
self.self_attention_cache.reset()
if hasattr(self.cross_attention_cache, "reset"):
self.cross_attention_cache.reset()
elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
raise ValueError(
"Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
"only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
f"{self.cross_attention_cache.__str__()} for the cross attention cache."
)
for layer_idx in self.is_updated:
self.is_updated[layer_idx] = False
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
self.self_attention_cache.reorder_cache(beam_idx)
self.cross_attention_cache.reorder_cache(beam_idx)
def check_dynamic_cache(self, method: str):
if not (
isinstance(self.self_attention_cache, DynamicCache)
and isinstance(self.cross_attention_cache, DynamicCache)
):
raise ValueError(
f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
)
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out
@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
self_attention_cache.update(layer_keys, layer_values, idx)
layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
cross_attention_cache.update(layer_keys, layer_values, idx)
return cls(self_attention_cache, cross_attention_cache)
def batch_repeat_interleave(self, repeats: int):
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
self.self_attention_cache.batch_repeat_interleave(repeats)
self.cross_attention_cache.batch_repeat_interleave(repeats)
def batch_select_indices(self, indices: torch.Tensor):
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
self.check_dynamic_cache(self.batch_select_indices.__name__)
self.self_attention_cache.batch_select_indices(indices)
self.cross_attention_cache.batch_select_indices(indices)
class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
......
......@@ -27,6 +27,7 @@ from torch import nn
from ..cache_utils import (
Cache,
DynamicCache,
EncoderDecoderCache,
HQQQuantizedCache,
HybridCache,
QuantizedCacheConfig,
......@@ -1409,7 +1410,7 @@ class GenerationMixin:
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache.
......@@ -1417,28 +1418,46 @@ class GenerationMixin:
Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if hasattr(self, "_cache"):
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len)
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size != max_batch_size
or self._cache.max_cache_len < max_cache_len
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.max_cache_len < max_cache_len
)
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
need_new_cache
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
)
if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
cache_kwargs = {
"config": self.config,
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": self.device,
"dtype": cache_dtype,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
else:
self._cache.reset()
return self._cache
......@@ -1745,6 +1764,7 @@ class GenerationMixin:
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
......@@ -1776,11 +1796,22 @@ class GenerationMixin:
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
model_kwargs["past_key_values"] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
model_kwargs["past_key_values"] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
)
use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
......@@ -2064,7 +2095,7 @@ class GenerationMixin:
# Convert to legacy cache if needed
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
if isinstance(result.past_key_values, DynamicCache):
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
......@@ -2234,7 +2265,7 @@ class GenerationMixin:
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None or (
isinstance(model_kwargs["past_key_values"], Cache)
isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
and model_kwargs["past_key_values"].get_seq_length() == 0
):
# prepare inputs
......@@ -2323,7 +2354,9 @@ class GenerationMixin:
# Replicates the new past_key_values to match the `top_k` candidates
past = model_kwargs["past_key_values"]
# If it is a static cache, modify it in-place layer after layer to save memory
if isinstance(past, DynamicCache):
if isinstance(past, DynamicCache) or (
isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
):
past.batch_repeat_interleave(top_k)
else:
new_key_values = []
......@@ -2350,7 +2383,10 @@ class GenerationMixin:
output_hidden_states=True,
output_attentions=output_attentions,
)
if isinstance(outputs["past_key_values"], DynamicCache):
if isinstance(outputs["past_key_values"], DynamicCache) or (
isinstance(outputs["past_key_values"], EncoderDecoderCache)
and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
):
# Remove past K-V from output since we don't need to stack later
outputs["past_key_values"] = None
# Remove last token from past K-V since we don't want to append it at this point
......@@ -2425,7 +2461,10 @@ class GenerationMixin:
else:
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
# Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache):
if isinstance(next_past_key_values, DynamicCache) or (
isinstance(next_past_key_values, EncoderDecoderCache)
and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
):
next_past_key_values.batch_select_indices(augmented_idx)
else:
new_key_values = []
......@@ -2498,7 +2537,10 @@ class GenerationMixin:
# Contrastive search works by forward looking at the next token, so we need to exclude it from
# `past_key_values` to be consistent with the other decoding methods
if model_kwargs.get("past_key_values") is not None:
if isinstance(model_kwargs["past_key_values"], DynamicCache):
if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
):
model_kwargs["past_key_values"].crop(-1)
else:
past_key_values = []
......@@ -2757,7 +2799,7 @@ class GenerationMixin:
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache):
if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
......@@ -3703,8 +3745,12 @@ class GenerationMixin:
# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
if len(model_kwargs["past_key_values"]) == 0:
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if len(past_key_values) == 0:
start_from_empty_dynamic_cache = True
this_peer_finished = False
......@@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache):
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size)
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
......@@ -4130,6 +4178,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
......
......@@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig):
model_type = "whisper"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
attribute_map = {
"num_key_value_heads": "encoder_attention_heads",
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
}
def __init__(
self,
......
......@@ -37,6 +37,13 @@ class DynamicCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class EncoderDecoderCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HQQQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -57,7 +57,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
......@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
......@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
set_seed(seed)
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
......@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
)
new_cache = new_results.past_key_values
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(
......
......@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
model.eval()
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self._get_custom_4d_mask_test_data()
with torch.no_grad():
logits = model.forward(
decoder_input_ids=input_ids,
input_features=input_dict["input_features"],
decoder_position_ids=position_ids,
).logits
# logits.shape == torch.Size([3, 4, ...])
logits_shared_prefix = model(
decoder_input_ids=input_ids_shared_prefix,
input_features=input_dict["input_features"],
decoder_attention_mask=mask_shared_prefix,
decoder_position_ids=position_ids_shared_prefix,
)[0]
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing greedily-chosen tokens:
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
# comparing softmax-normalized logits:
normalized_0 = torch.nn.functional.softmax(out_last_tokens)
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@require_torch
@require_torchaudio
......@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch.manual_seed(0)
model.generate(**inputs, **gen_kwargs)
@slow
def test_tiny_static_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
input_features = input_features.to(torch_device)
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# compile the forward pass and assert equivalence
static_generated_ids = model.generate(input_features, max_new_tokens=64)
assert (eager_generated_ids == static_generated_ids).all()
# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
static_generated_ids = model.generate(input_features, max_new_tokens=64)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
......@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
config=config, input_ids=inputs_dict["input_ids"]
)
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="Generate needs input ids")
def test_generate_without_input_ids(self):
# generate only works with input ids for whisper
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment