Unverified Commit a9701953 authored by Sanchit Gandhi, committed by GitHub

[whisper] static kv cache (#31166)



* make work with cache abstraction

* correct for static cache

* hacks for compile

* make fast

* fix

* fix pos ids

* generate

* fix sdpa

* fix sdpa cache pos

* fix fa2

* clean fa2

* integrate cache into generate

* make style

* copies

* more copies

* update eager

* update sdpa

* update fa2

* simplify

* use cache pos

* always compute cross-cache for debug

* avoid recompiles
Co-authored-by: Arthur Zucker <arthur@huggingface.co>

* fix fix

* fix fix fix

* more fix

* try encoder-decoder cache (too messy)

* revert encoder-decoder cache

* check cross-attn cache

* use enc-dec dataclass

* use richer enc-dec dataclass

* clean-up

* revert static cache changes

* small fixes

* revert to cpu flag

* fix copies

* add static slow test

* past k/v docstring

* more docstrings

* cache_position docstrings

* add to docs

* add enc-dec cache to docs

* make style

* fix after rebase

* fix beam

* style

* fix generation strategies

* fix most decoder-only tests

* style

* skip test

* more clean up

* small docstrings

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add todo

* only crop self-attn

* check cache in mixin

* style

* fix re-compile after rebase

* move `is_updated` logic to enc-dec wrapper

* revert back

* revert cache back

* finalise design

* fix

* fix fix

* style

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* deprecate

* updates

* final updates

* style

* style

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 57d7594a
...@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens ...@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length - get_seq_length
- reset - reset
[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
- from_legacy_cache
- reset
- reorder_cache
## Watermark Utils ## Watermark Utils
......
...@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained ...@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Select an audio file and read it: >>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"] >>> audio_sample = ds[0]["audio"]
>>> waveform = audio_sample["array"]
>>> sampling_rate = audio_sample["sampling_rate"]
>>> # Load the Whisper model in Hugging Face format: >>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
...@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained ...@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Use the model and processor to transcribe the audio: >>> # Use the model and processor to transcribe the audio:
>>> input_features = processor( >>> input_features = processor(
... waveform, sampling_rate=sampling_rate, return_tensors="pt" ... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features ... ).input_features
>>> # Generate token ids >>> # Generate token ids
...@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained ...@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
``` ```
Whisper is compatible with the following optimisations:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:
```python
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> # Load the Whisper model with SDPA attention
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
>>> # Enable static cache and compile the forward pass
>>> model.generation_config.cache_implementation = "static"
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features
>>> # Warm-up run to compile the forward pass
>>> _ = model.generate(input_features)
>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
For more details on each optimisation, refer to the documentation linked above.
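The Flash Attention 2 option listed above is requested the same way, via `attn_implementation` at load time. The snippet below is an illustrative sketch rather than part of the original guide; it assumes the `flash-attn` package is installed and that the model runs on a CUDA device in half precision. Note that, per the attention code in this PR, Flash Attention 2 is not compatible with the `static` cache implementation.

```python
>>> import torch
>>> from transformers import WhisperForConditionalGeneration

>>> # Load Whisper with Flash Attention 2 (requires flash-attn, a CUDA device and half precision)
>>> model = WhisperForConditionalGeneration.from_pretrained(
...     "openai/whisper-tiny.en", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
... ).to("cuda")
```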
## Resources ## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
......
...@@ -1212,6 +1212,7 @@ else: ...@@ -1212,6 +1212,7 @@ else:
"Cache", "Cache",
"CacheConfig", "CacheConfig",
"DynamicCache", "DynamicCache",
"EncoderDecoderCache",
"HQQQuantizedCache", "HQQQuantizedCache",
"QuantizedCache", "QuantizedCache",
"QuantizedCacheConfig", "QuantizedCacheConfig",
...@@ -5895,6 +5896,7 @@ if TYPE_CHECKING: ...@@ -5895,6 +5896,7 @@ if TYPE_CHECKING:
Cache, Cache,
CacheConfig, CacheConfig,
DynamicCache, DynamicCache,
EncoderDecoderCache,
HQQQuantizedCache, HQQQuantizedCache,
QuantizedCache, QuantizedCache,
QuantizedCacheConfig, QuantizedCacheConfig,
......
...@@ -858,8 +858,12 @@ class StaticCache(Cache): ...@@ -858,8 +858,12 @@ class StaticCache(Cache):
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
k_out[:, :, cache_position] = key_states if cache_position is None:
v_out[:, :, cache_position] = value_states k_out.copy_(key_states)
v_out.copy_(value_states)
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out return k_out, v_out
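To make the new branch above concrete, here is a minimal, self-contained sketch (toy tensors only, not library code) of the two write modes of `StaticCache.update`: with `cache_position=None` the whole pre-allocated buffer is overwritten in one shot, which is what the cross-attention cache needs when it is filled once from the encoder output, while a `cache_position` tensor writes only the slots for the current decoding step.

```python
import torch

# toy pre-allocated buffers with shape (batch, num_heads, max_cache_len, head_dim)
k_out = torch.zeros(1, 2, 6, 4)
v_out = torch.zeros(1, 2, 6, 4)

# cache_position is None: overwrite the full buffer in one call
# (e.g. cross-attention key/values computed once from the encoder states)
key_states = torch.ones(1, 2, 6, 4)
value_states = torch.ones(1, 2, 6, 4)
k_out.copy_(key_states)
v_out.copy_(value_states)

# cache_position given: write only the slot(s) for the current step
# (e.g. decoder self-attention during auto-regressive generation)
cache_position = torch.tensor([3])
new_key = torch.full((1, 2, 1, 4), 2.0)
new_value = torch.full((1, 2, 1, 4), 2.0)
k_out[:, :, cache_position] = new_key
v_out[:, :, cache_position] = new_value
```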
...@@ -971,6 +975,158 @@ class SlidingWindowCache(StaticCache): ...@@ -971,6 +975,158 @@ class SlidingWindowCache(StaticCache):
# no matter how long the sentence is # no matter how long the sentence is
return None return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
class EncoderDecoderCache(Cache):
"""
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
cross-attention caches.
"""
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache
self.is_updated = {}
for layer_idx in range(len(cross_attention_cache.key_cache)):
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (
self.self_attention_cache.key_cache[layer_idx],
self.self_attention_cache.value_cache[layer_idx],
self.cross_attention_cache.key_cache[layer_idx],
self.cross_attention_cache.value_cache[layer_idx],
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.self_attention_cache)
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
legacy_cache = ()
if len(self.cross_attention_cache) > 0:
for self_attn, cross_attn in zip(
self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
):
legacy_cache += (self_attn + cross_attn,)
else:
legacy_cache = self.self_attention_cache.to_legacy_cache()
return legacy_cache
@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
cache.self_attention_cache.update(key_states, value_states, layer_idx)
if len(past_key_values[layer_idx]) > 2:
key_states, value_states = past_key_values[layer_idx][2:]
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
cache.is_updated[layer_idx] = True
return cache
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.self_attention_cache.key_cache) <= layer_idx:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
if hasattr(self.self_attention_cache, "reset"):
self.self_attention_cache.reset()
if hasattr(self.cross_attention_cache, "reset"):
self.cross_attention_cache.reset()
elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
raise ValueError(
"Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
"only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
f"{self.cross_attention_cache.__str__()} for the cross attention cache."
)
for layer_idx in self.is_updated:
self.is_updated[layer_idx] = False
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
self.self_attention_cache.reorder_cache(beam_idx)
self.cross_attention_cache.reorder_cache(beam_idx)
def check_dynamic_cache(self, method: str):
if not (
isinstance(self.self_attention_cache, DynamicCache)
and isinstance(self.cross_attention_cache, DynamicCache)
):
raise ValueError(
f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
)
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out
@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
self_attention_cache.update(layer_keys, layer_values, idx)
layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
cross_attention_cache.update(layer_keys, layer_values, idx)
return cls(self_attention_cache, cross_attention_cache)
def batch_repeat_interleave(self, repeats: int):
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
self.self_attention_cache.batch_repeat_interleave(repeats)
self.cross_attention_cache.batch_repeat_interleave(repeats)
def batch_select_indices(self, indices: torch.Tensor):
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
self.check_dynamic_cache(self.batch_select_indices.__name__)
self.self_attention_cache.batch_select_indices(indices)
self.cross_attention_cache.batch_select_indices(indices)
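A minimal usage sketch of the `EncoderDecoderCache` class added above, assuming `DynamicCache` for both sub-caches; the tensor shapes and single layer are illustrative only, and the cache updates shown here are normally performed inside the attention layers during the forward pass.

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# toy key/value states with shape (batch, num_heads, seq_len, head_dim)
key = torch.ones(1, 8, 4, 64)
value = torch.ones(1, 8, 4, 64)

past_key_values.self_attention_cache.update(key, value, layer_idx=0)
past_key_values.cross_attention_cache.update(key, value, layer_idx=0)
past_key_values.is_updated[0] = True  # mark the cross-attention cache for layer 0 as filled

print(len(past_key_values))              # number of layers with cached states
print(past_key_values.get_seq_length())  # self-attention sequence length

# round-trip through the legacy tuple format
legacy = past_key_values.to_legacy_cache()
restored = EncoderDecoderCache.from_legacy_cache(legacy)
```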
class HybridCache(Cache): class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
......
...@@ -27,6 +27,7 @@ from torch import nn ...@@ -27,6 +27,7 @@ from torch import nn
from ..cache_utils import ( from ..cache_utils import (
Cache, Cache,
DynamicCache, DynamicCache,
EncoderDecoderCache,
HQQQuantizedCache, HQQQuantizedCache,
HybridCache, HybridCache,
QuantizedCacheConfig, QuantizedCacheConfig,
...@@ -1409,7 +1410,7 @@ class GenerationMixin: ...@@ -1409,7 +1410,7 @@ class GenerationMixin:
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs return model_kwargs
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache: def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
""" """
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache. new `generate` call requires a larger cache.
...@@ -1417,28 +1418,46 @@ class GenerationMixin: ...@@ -1417,28 +1418,46 @@ class GenerationMixin:
Returns the resulting cache object. Returns the resulting cache object.
""" """
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if hasattr(self, "_cache"):
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
if cache_implementation == "sliding_window": if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len) max_cache_len = min(self.config.sliding_window, max_cache_len)
need_new_cache = ( need_new_cache = (
not hasattr(self, "_cache") not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls)) or (not isinstance(cache_to_check, cache_cls))
or self._cache.max_batch_size != max_batch_size or cache_to_check.max_batch_size != max_batch_size
or self._cache.max_cache_len < max_cache_len or cache_to_check.max_cache_len < max_cache_len
) )
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
need_new_cache
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
)
if need_new_cache: if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"): if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype cache_dtype = self.config._pre_quantization_dtype
else: else:
cache_dtype = self.dtype cache_dtype = self.dtype
self._cache = cache_cls( cache_kwargs = {
config=self.config, "config": self.config,
max_batch_size=max_batch_size, "max_batch_size": max_batch_size,
max_cache_len=max_cache_len, "max_cache_len": max_cache_len,
device=self.device, "device": self.device,
dtype=cache_dtype, "dtype": cache_dtype,
) }
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
else: else:
self._cache.reset() self._cache.reset()
return self._cache return self._cache
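As a sketch of what the updated `_get_cache` builds for an encoder-decoder model: the self-attention cache is sized for the generated sequence, while the cross-attention cache is sized to the encoder output length. The sizes below are illustrative (448 decoder positions and 1500 encoder frames for Whisper); the real method derives them from `generation_config` and `model_kwargs["encoder_outputs"]`.

```python
import torch
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import EncoderDecoderCache, StaticCache

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# illustrative sizes: max generated length and Whisper's fixed encoder sequence length
batch_size, max_cache_len, encoder_seq_len = 1, 448, 1500

self_attention_cache = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=max_cache_len,
    device=model.device,
    dtype=model.dtype,
)
cross_attention_cache = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=encoder_seq_len,
    device=model.device,
    dtype=model.dtype,
)
past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
```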
...@@ -1745,6 +1764,7 @@ class GenerationMixin: ...@@ -1745,6 +1764,7 @@ class GenerationMixin:
generation_config.cache_implementation, generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size, getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length, generation_config.max_length,
model_kwargs,
) )
elif generation_config.cache_implementation == "quantized": elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache: if not self._supports_quantized_cache:
...@@ -1776,11 +1796,22 @@ class GenerationMixin: ...@@ -1776,11 +1796,22 @@ class GenerationMixin:
# keeps copying the cache thus using much more memory # keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None) past = model_kwargs.get("past_key_values", None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if past is None: if past is None:
model_kwargs["past_key_values"] = DynamicCache() model_kwargs["past_key_values"] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
use_dynamic_cache_by_default = True use_dynamic_cache_by_default = True
elif isinstance(past, tuple): elif isinstance(past, tuple):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) model_kwargs["past_key_values"] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
)
use_dynamic_cache_by_default = True use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
...@@ -2064,7 +2095,7 @@ class GenerationMixin: ...@@ -2064,7 +2095,7 @@ class GenerationMixin:
# Convert to legacy cache if needed # Convert to legacy cache if needed
if use_dynamic_cache_by_default and generation_config.return_legacy_cache: if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
if isinstance(result.past_key_values, DynamicCache): if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
result.past_key_values = result.past_key_values.to_legacy_cache() result.past_key_values = result.past_key_values.to_legacy_cache()
return result return result
...@@ -2234,7 +2265,7 @@ class GenerationMixin: ...@@ -2234,7 +2265,7 @@ class GenerationMixin:
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None or ( if model_kwargs.get("past_key_values") is None or (
isinstance(model_kwargs["past_key_values"], Cache) isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
and model_kwargs["past_key_values"].get_seq_length() == 0 and model_kwargs["past_key_values"].get_seq_length() == 0
): ):
# prepare inputs # prepare inputs
...@@ -2323,7 +2354,9 @@ class GenerationMixin: ...@@ -2323,7 +2354,9 @@ class GenerationMixin:
# Replicates the new past_key_values to match the `top_k` candidates # Replicates the new past_key_values to match the `top_k` candidates
past = model_kwargs["past_key_values"] past = model_kwargs["past_key_values"]
# If it is a static cache, modify it in-place layer after layer to save memory # If it is a static cache, modify it in-place layer after layer to save memory
if isinstance(past, DynamicCache): if isinstance(past, DynamicCache) or (
isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
):
past.batch_repeat_interleave(top_k) past.batch_repeat_interleave(top_k)
else: else:
new_key_values = [] new_key_values = []
...@@ -2350,7 +2383,10 @@ class GenerationMixin: ...@@ -2350,7 +2383,10 @@ class GenerationMixin:
output_hidden_states=True, output_hidden_states=True,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
if isinstance(outputs["past_key_values"], DynamicCache): if isinstance(outputs["past_key_values"], DynamicCache) or (
isinstance(outputs["past_key_values"], EncoderDecoderCache)
and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
):
# Remove past K-V from output since we don't need to stack later # Remove past K-V from output since we don't need to stack later
outputs["past_key_values"] = None outputs["past_key_values"] = None
# Remove last token from past K-V since we don't want to append it at this point # Remove last token from past K-V since we don't want to append it at this point
...@@ -2425,7 +2461,10 @@ class GenerationMixin: ...@@ -2425,7 +2461,10 @@ class GenerationMixin:
else: else:
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
# Do it in-place layer per layer to save memory # Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache): if isinstance(next_past_key_values, DynamicCache) or (
isinstance(next_past_key_values, EncoderDecoderCache)
and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
):
next_past_key_values.batch_select_indices(augmented_idx) next_past_key_values.batch_select_indices(augmented_idx)
else: else:
new_key_values = [] new_key_values = []
...@@ -2498,7 +2537,10 @@ class GenerationMixin: ...@@ -2498,7 +2537,10 @@ class GenerationMixin:
# Contrastive search works by forward looking at the next token, so we need to exclude it from # Contrastive search works by forward looking at the next token, so we need to exclude it from
# `past_key_values` to be consistent with the other decoding methods # `past_key_values` to be consistent with the other decoding methods
if model_kwargs.get("past_key_values") is not None: if model_kwargs.get("past_key_values") is not None:
if isinstance(model_kwargs["past_key_values"], DynamicCache): if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
):
model_kwargs["past_key_values"].crop(-1) model_kwargs["past_key_values"].crop(-1)
else: else:
past_key_values = [] past_key_values = []
...@@ -2757,7 +2799,7 @@ class GenerationMixin: ...@@ -2757,7 +2799,7 @@ class GenerationMixin:
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# cache format is standardized, to avoid adding complexity to the codebase. # cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class: elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache): if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
raise ValueError( raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the " f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`" "legacy tuple format or `DynamicCache`"
...@@ -3703,8 +3745,12 @@ class GenerationMixin: ...@@ -3703,8 +3745,12 @@ class GenerationMixin:
# This is needed if return_dict_in_generate is True # This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False start_from_empty_dynamic_cache = False
if isinstance(model_kwargs.get("past_key_values", None), DynamicCache): past_key_values = model_kwargs.get("past_key_values", None)
if len(model_kwargs["past_key_values"]) == 0: if isinstance(past_key_values, DynamicCache) or (
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if len(past_key_values) == 0:
start_from_empty_dynamic_cache = True start_from_empty_dynamic_cache = True
this_peer_finished = False this_peer_finished = False
...@@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None): ...@@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format # New cache format
elif isinstance(data, DynamicCache): elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size) return data.batch_split(full_batch_size, split_size)
elif isinstance(data, tuple): elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
...@@ -4130,6 +4178,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput: ...@@ -4130,6 +4178,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
# New cache format # New cache format
elif isinstance(data[0], DynamicCache): elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data) return DynamicCache.from_batch_splits(data)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data)
elif isinstance(data[0], tuple): elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple): if isinstance(data[0][0], tuple):
......
...@@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig): ...@@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig):
model_type = "whisper" model_type = "whisper"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} attribute_map = {
"num_key_value_heads": "encoder_attention_heads",
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
}
def __init__( def __init__(
self, self,
......
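The `num_key_value_heads` entry added to `attribute_map` above matters because the static cache reads `config.num_key_value_heads` when pre-allocating its key/value buffers; with the mapping in place, that lookup resolves to `encoder_attention_heads`. A small sketch of the resolution (illustrative only):

```python
from transformers import WhisperConfig

config = WhisperConfig()
# attribute_map redirects the lookup, so both names return the same value
assert config.num_key_value_heads == config.encoder_attention_heads
```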
...@@ -25,7 +25,8 @@ from torch import nn ...@@ -25,7 +25,8 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
...@@ -244,6 +245,7 @@ class WhisperAttention(nn.Module): ...@@ -244,6 +245,7 @@ class WhisperAttention(nn.Module):
is_decoder: bool = False, is_decoder: bool = False,
bias: bool = True, bias: bool = True,
is_causal: bool = False, is_causal: bool = False,
layer_idx: Optional[int] = None,
config: Optional[WhisperConfig] = None, config: Optional[WhisperConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -262,6 +264,14 @@ class WhisperAttention(nn.Module): ...@@ -262,6 +264,14 @@ class WhisperAttention(nn.Module):
self.is_decoder = is_decoder self.is_decoder = is_decoder
self.is_causal = is_causal self.is_causal = is_causal
if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
...@@ -271,84 +281,56 @@ class WhisperAttention(nn.Module): ...@@ -271,84 +281,56 @@ class WhisperAttention(nn.Module):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size() bsz, tgt_len, _ = hidden_states.size()
# get query proj # get query proj
query_states = self.q_proj(hidden_states) * self.scaling query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]` if past_key_value is not None:
# is checking that the `sequence_length` of the `past_key_value` is the same as is_updated = past_key_value.is_updated.get(self.layer_idx)
# the provided `key_value_states` to support prefix tuning if is_cross_attention:
if ( # after the first generated id, we can subsequently re-use all key/value_states from cache
is_cross_attention past_key_value.is_updated[self.layer_idx] = True
and past_key_value is not None past_key_value = past_key_value.cross_attention_cache
and past_key_value[0].shape[2] == key_value_states.shape[1] else:
): past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1] value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else: else:
# self_attention key_states = self._shape(self.k_proj(current_states), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
if self.is_decoder: cache_position = cache_position if not is_cross_attention else None
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. key_states, value_states = past_key_value.update(
# Further calls to cross_attention layer can then reuse all cross-attention key_states, value_states, self.layer_idx, {"cache_position": cache_position}
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
) )
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1)
...@@ -358,42 +340,27 @@ class WhisperAttention(nn.Module): ...@@ -358,42 +340,27 @@ class WhisperAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}" f" {layer_head_mask.size()}"
) )
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}" f" {attn_output.size()}"
) )
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2) attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism. # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper
class WhisperFlashAttention2(WhisperAttention): class WhisperFlashAttention2(WhisperAttention):
""" """
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
...@@ -410,18 +377,21 @@ class WhisperFlashAttention2(WhisperAttention): ...@@ -410,18 +377,21 @@ class WhisperFlashAttention2(WhisperAttention):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
# WhisperFlashAttention2 attention does not support output_attentions # WhisperFlashAttention2 attention does not support output_attentions
if output_attentions: if output_attentions:
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
...@@ -429,51 +399,45 @@ class WhisperFlashAttention2(WhisperAttention): ...@@ -429,51 +399,45 @@ class WhisperFlashAttention2(WhisperAttention):
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
bsz, q_len, _ = hidden_states.size()
# get query proj # get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]` if past_key_value is not None:
# is checking that the `sequence_length` of the `past_key_value` is the same as is_updated = past_key_value.is_updated.get(self.layer_idx)
# the provided `key_value_states` to support prefix tuning if is_cross_attention:
if ( # after the first generated id, we can subsequently re-use all key/value_states from cache
is_cross_attention past_key_value.is_updated[self.layer_idx] = True
and past_key_value is not None past_key_value = past_key_value.cross_attention_cache
and past_key_value[0].shape[2] == key_value_states.shape[1] else:
): past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2) key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1].transpose(1, 2) value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else: else:
# self_attention key_states = self._shape(self.k_proj(current_states), -1, bsz)
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
if self.is_decoder: cache_position = cache_position if not is_cross_attention else None
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. key_states, value_states = past_key_value.update(
# Further calls to cross_attention layer can then reuse all cross-attention key_states, value_states, self.layer_idx, {"cache_position": cache_position}
# key/value_states (first "if" case) )
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
# if encoder bi-directional self-attention `past_key_value` is always `None` query_states = query_states.transpose(1, 2)
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None: causal_mask = attention_mask
kv_seq_len += past_key_value[0].shape[-2] if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons # In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need # therefore the input hidden states gets silently casted in float32. Hence, we need
...@@ -502,10 +466,10 @@ class WhisperFlashAttention2(WhisperAttention): ...@@ -502,10 +466,10 @@ class WhisperFlashAttention2(WhisperAttention):
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout
) )
attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if not output_attentions: if not output_attentions:
...@@ -614,15 +578,15 @@ class WhisperFlashAttention2(WhisperAttention): ...@@ -614,15 +578,15 @@ class WhisperFlashAttention2(WhisperAttention):
class WhisperSdpaAttention(WhisperAttention): class WhisperSdpaAttention(WhisperAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None: if output_attentions or layer_head_mask is not None:
...@@ -638,59 +602,50 @@ class WhisperSdpaAttention(WhisperAttention): ...@@ -638,59 +602,50 @@ class WhisperSdpaAttention(WhisperAttention):
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size() bsz, tgt_len, _ = hidden_states.size()
# get query proj # get query proj
query_states = self.q_proj(hidden_states) query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]` if past_key_value is not None:
# is checking that the `sequence_length` of the `past_key_value` is the same as is_updated = past_key_value.is_updated.get(self.layer_idx)
# the provided `key_value_states` to support prefix tuning if is_cross_attention:
if ( # after the first generated id, we can subsequently re-use all key/value_states from cache
is_cross_attention past_key_value.is_updated[self.layer_idx] = True
and past_key_value is not None past_key_value = past_key_value.cross_attention_cache
and past_key_value[0].shape[2] == key_value_states.shape[1] else:
): past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1] value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else: else:
# self_attention key_states = self._shape(self.k_proj(current_states), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
if self.is_decoder: cache_position = cache_position if not is_cross_attention else None
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. key_states, value_states = past_key_value.update(
# Further calls to cross_attention layer can then reuse all cross-attention key_states, value_states, self.layer_idx, {"cache_position": cache_position}
# key/value_states (first "if" case) )
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention causal_mask = attention_mask
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) if attention_mask is not None: # no matter the length, we just slice it
# if encoder bi-directional self-attention `past_key_value` is always `None` causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
...@@ -698,7 +653,7 @@ class WhisperSdpaAttention(WhisperAttention): ...@@ -698,7 +653,7 @@ class WhisperSdpaAttention(WhisperAttention):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_mask=attention_mask, attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0, dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal, is_causal=is_causal,
) )
...@@ -798,9 +753,8 @@ class WhisperEncoderLayer(nn.Module): ...@@ -798,9 +753,8 @@ class WhisperEncoderLayer(nn.Module):
return outputs return outputs
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperDecoderLayer(nn.Module): class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig): def __init__(self, config: WhisperConfig, layer_idx: int = None):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
...@@ -810,6 +764,7 @@ class WhisperDecoderLayer(nn.Module): ...@@ -810,6 +764,7 @@ class WhisperDecoderLayer(nn.Module):
dropout=config.attention_dropout, dropout=config.attention_dropout,
is_decoder=True, is_decoder=True,
is_causal=True, is_causal=True,
layer_idx=layer_idx,
config=config,
)
self.dropout = config.dropout
...@@ -822,6 +777,7 @@ class WhisperDecoderLayer(nn.Module):
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
layer_idx=layer_idx,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -837,9 +793,10 @@ class WhisperDecoderLayer(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
cache_position: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Args:
...@@ -863,41 +820,35 @@ class WhisperDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 1 of present_key_value tuple
present_key_value = (present_key_value, cross_attn_present_key_value)
# Fully Connected
residual = hidden_states
...@@ -927,6 +878,8 @@ class WhisperPreTrainedModel(PreTrainedModel):
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
...@@ -1024,14 +977,18 @@ WHISPER_INPUTS_DOCSTRING = r"""
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
...@@ -1051,6 +1008,9 @@ WHISPER_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
""" """
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
...@@ -1256,7 +1216,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList(
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
...@@ -1286,6 +1248,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
):
r"""
Args:
...@@ -1320,13 +1283,17 @@ class WhisperDecoder(WhisperPreTrainedModel):
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
...@@ -1344,6 +1311,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -1363,26 +1333,38 @@ class WhisperDecoder(WhisperPreTrainedModel): ...@@ -1363,26 +1333,38 @@ class WhisperDecoder(WhisperPreTrainedModel):
else: else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
return_self_attention_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# embed positions
if input_ids is not None:
positions = self.embed_positions(
...@@ -1396,6 +1378,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
...@@ -1406,7 +1396,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
...@@ -1424,13 +1413,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
...@@ -1438,25 +1425,24 @@ class WhisperDecoder(WhisperPreTrainedModel):
None, # past_key_value
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values if use_cache else None,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
...@@ -1468,7 +1454,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(
v
...@@ -1483,6 +1473,87 @@ class WhisperDecoder(WhisperPreTrainedModel):
cross_attentions=all_cross_attentions,
)
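As a rough illustration of the legacy-tuple handling in the forward pass above, the sketch below converts a legacy cache to an `EncoderDecoderCache` and back. The tensor shapes are invented for the example.

```python
# Sketch of converting between the legacy tuple cache and EncoderDecoderCache.
# Shapes are arbitrary: 2 layers, each with self-attn k/v plus cross-attn k/v.
import torch
from transformers.cache_utils import EncoderDecoderCache

batch, heads, seq_len, head_dim, num_layers = 1, 6, 4, 64, 2
legacy = tuple(
    tuple(torch.zeros(batch, heads, seq_len, head_dim) for _ in range(4))
    for _ in range(num_layers)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())   # 4 (length of the self-attention cache)
roundtrip = cache.to_legacy_cache()
assert len(roundtrip) == num_layers and len(roundtrip[0]) == 4
```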
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
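A standalone sketch of the mask arithmetic used above, with made-up sizes: for a single new token at `cache_position = 5` in a static cache of length 8, only the slots beyond the current position stay masked.

```python
# Illustration of the causal-mask construction in `_update_causal_mask` (sizes invented).
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 1, 8   # one new token, static cache of 8 slots
cache_position = torch.tensor([5])      # the new token occupies slot 5

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

print(causal_mask)  # slots 0-5 are zeroed (attendable), slots 6-7 keep min_dtype (masked)
```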
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
...@@ -1571,13 +1642,14 @@ class WhisperModel(WhisperPreTrainedModel):
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
Returns:
...@@ -1637,6 +1709,7 @@ class WhisperModel(WhisperPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
...@@ -1704,7 +1777,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
...@@ -1712,6 +1785,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
...@@ -1766,6 +1840,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
lm_logits = self.proj_out(outputs[0])
...@@ -1800,14 +1875,19 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
cache_position=None,
**kwargs,
):
decoder_position_ids = None
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
...@@ -1821,6 +1901,13 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
)
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1] :]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
...@@ -1828,6 +1915,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"cache_position": cache_position,
}
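For intuition, here is a tiny sketch (token ids invented) of the cropping performed above: once `past_length` tokens are cached, only the new decoder tokens and a matching `cache_position` are forwarded to the model.

```python
# Illustrative only: crop already-cached tokens and build the matching cache_position.
import torch

decoder_input_ids = torch.tensor([[50258, 50363, 703, 318]])  # 4 tokens so far (ids invented)
past_length = 3                                               # 3 of them are already cached

decoder_input_ids = decoder_input_ids[:, past_length:]        # -> tensor([[318]])
cache_position = torch.arange(past_length, past_length + decoder_input_ids.shape[1])
print(decoder_input_ids, cache_position)                      # tensor([[318]]) tensor([3])
```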
@staticmethod
...@@ -1914,6 +2002,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
Args:
...@@ -1968,6 +2057,9 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
Returns:
...@@ -2019,6 +2111,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = self.proj_out(outputs[0])
...@@ -2049,10 +2142,15 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
use_cache=None,
encoder_outputs=None,
attention_mask=None,
cache_position=None,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, (Cache, EncoderDecoderCache)):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
...@@ -2063,12 +2161,18 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
input_ids = input_ids[:, remove_prefix_length:]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
elif use_cache:
cache_position = cache_position[-input_ids.shape[1] :]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"input_ids": input_ids,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
@staticmethod
...
...@@ -37,6 +37,13 @@ class DynamicCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class EncoderDecoderCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HQQQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -57,7 +57,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
...@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
...@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
set_seed(seed)
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
...@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
)
new_cache = new_results.past_key_values
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(
...
...@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
model.eval()
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self._get_custom_4d_mask_test_data()
with torch.no_grad():
logits = model.forward(
decoder_input_ids=input_ids,
input_features=input_dict["input_features"],
decoder_position_ids=position_ids,
).logits
# logits.shape == torch.Size([3, 4, ...])
logits_shared_prefix = model(
decoder_input_ids=input_ids_shared_prefix,
input_features=input_dict["input_features"],
decoder_attention_mask=mask_shared_prefix,
decoder_position_ids=position_ids_shared_prefix,
)[0]
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing greedily-chosen tokens:
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
# comparing softmax-normalized logits:
normalized_0 = torch.nn.functional.softmax(out_last_tokens)
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@require_torch
@require_torchaudio
...@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch.manual_seed(0)
model.generate(**inputs, **gen_kwargs)
@slow
def test_tiny_static_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
input_features = input_features.to(torch_device)
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# compile the forward pass and assert equivalence
static_generated_ids = model.generate(input_features, max_new_tokens=64)
assert (eager_generated_ids == static_generated_ids).all()
# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
static_generated_ids = model.generate(input_features, max_new_tokens=64)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
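A usage sketch in the spirit of the test above: enable the static cache via `generation_config.cache_implementation` and compile the forward pass. The CUDA device and the zero-valued dummy features are assumptions for the example, not part of the test.

```python
# Sketch: static KV cache + torch.compile for Whisper generation (assumes a CUDA device).
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en").to("cuda")

# Dummy 30-second log-mel input; replace with real features from the processor.
input_features = torch.zeros(1, model.config.num_mel_bins, 3000, device="cuda")

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

generated_ids = model.generate(input_features, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```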
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
...@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
config=config, input_ids=inputs_dict["input_ids"]
)
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="Generate needs input ids")
def test_generate_without_input_ids(self):
# generate only works with input ids for whisper
...