Unverified Commit 75bbfd5b authored by Joao Gante, committed by GitHub

Cache: Static cache as a standalone object (#30476)

parent 0ae789e0
@@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] StaticCache
     - update
     - get_seq_length
+    - reorder_cache
@@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
 ```
-</hfoption>
-<hfoption id="setup_cache">
-
-> [!WARNING]
-> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future.
-
-The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
+
+Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.
+
+</hfoption>
+<hfoption id="Static Cache">
+
+A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.
+
 ```py
 from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token
 model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
 inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

-def decode_one_tokens(model, cur_token, input_pos, cache_position):
-    logits = model(
-        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
-    )[0]
+def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
+    logits = model(
+        cur_token,
+        position_ids=input_pos,
+        cache_position=cache_position,
+        past_key_values=past_key_values,
+        return_dict=False,
+        use_cache=True
+    )[0]
     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
     return new_token
 ```

-There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method:
-1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length.
+There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` class:
+1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
 2. Call torch.compile on the model to compile the forward pass with the static kv-cache.
@@ -109,24 +113,28 @@ There are a few important things you must do to enable static kv-cache and torch
 ```py
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
-    model._setup_cache(StaticCache, 2, max_cache_len=4096)
+    past_key_values = StaticCache(
+        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+    )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(
         batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
     )
     generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

-    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
+    logits = model(
+        **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
+    )[0]
     next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
     generated_ids[:, seq_length] = next_token[:, 0]

     decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
     cache_position = torch.tensor([seq_length + 1], device=torch_device)
     for _ in range(1, NUM_TOKENS_TO_GENERATE):
         with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
+            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
             generated_ids[:, cache_position] = next_token.int()
         cache_position += 1

 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 text
@@ -134,6 +142,9 @@ text
 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
 ```
+
+> [!TIP]
+> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method.
 </hfoption>
 </hfoptions>
...
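As a quick illustration of the documentation change above, here is a minimal sketch of handing a [`StaticCache`] to `generate` and reusing it across calls. It assumes the same Llama-2 checkpoint the docs use; the prompt and generation lengths are placeholders.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

# Pre-allocate the cache once; batch size and max length are upper bounds for later calls.
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=512, device=model.device, dtype=model.dtype
)

inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)

# Reusing the same object on a new prompt: reset its contents first, as the tip above notes.
past_key_values.reset()
inputs = tokenizer("My favorite condiment is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
```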
@@ -44,6 +44,7 @@ class Cache:
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

     def get_max_length(self) -> Optional[int]:
@@ -61,6 +62,14 @@ class Cache:
             return max_length - new_seq_length
         return previous_seq_length

+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
     @property
     def seen_tokens(self):
         logger.warning_once(
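The `reorder_cache` helper promoted to the base class above simply gathers beams along the batch dimension of each layer's tensors. A toy illustration (shapes invented for the example):

```py
import torch

# Pretend per-layer key tensor for 3 beams: (batch, num_heads, seq_len, head_dim)
key_cache_layer = torch.arange(3.0).reshape(3, 1, 1, 1)
beam_idx = torch.tensor([2, 0, 0])  # beams selected at this decoding step

# index_select along dim 0 keeps beam 2 and duplicates beam 0, as beam search requires.
reordered = key_cache_layer.index_select(0, beam_idx)
print(reordered.flatten())  # tensor([2., 0., 0.])
```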
@@ -150,6 +159,7 @@ class DynamicCache(Cache):
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +168,6 @@ class DynamicCache(Cache):
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return None

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -244,6 +246,7 @@ class SinkCache(Cache):
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
@@ -332,14 +335,6 @@ class SinkCache(Cache):
         return self.key_cache[layer_idx], self.value_cache[layer_idx]

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
 class StaticCache(Cache):
     """
@@ -347,8 +342,7 @@ class StaticCache(Cache):
     Parameters:
         config (`PretrainedConfig):
-            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
-            required to initialize the static cache.
+            The configuration file defining the shape-related attributes required to initialize the static cache.
         max_batch_size (`int`):
             The maximum batch size with which the model will be used.
         max_cache_len (`int`):
@@ -373,9 +367,18 @@ class StaticCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )

+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        for _ in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+            # breaks when updating the cache.
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
@@ -394,42 +397,37 @@ class StaticCache(Cache):
             value_states (`torch.Tensor`):
                 The new value states to cache.
             layer_idx (`int`):
-                The index of the layer to cache the states for. Kept for backward compatibility
+                The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
-                to know how much of the cache it should overwrite.
+                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.

         Return:
             A tuple containing the updated key and value states.
         """
-        new_cache_positions = cache_kwargs.get("cache_position")
-        k_out = self.key_cache
-        v_out = self.value_cache
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]

-        k_out[:, :, new_cache_positions] = key_states
-        v_out[:, :, new_cache_positions] = value_states
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states

         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+        """Returns the sequence length of the cached states that were seen by the model."""
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
-        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
-        # https://github.com/pytorch/pytorch/issues/120248 is fixed
-        return (self.key_cache[0, 0].any(dim=-1)).sum()
+        # TODO: deprecate this function in favor of `cache_position`
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        """Returns the maximum sequence length of the cached states."""
         return self.max_cache_len

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.key_cache.device
-        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
-        device = self.value_cache.device
-        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
-
-    def to_legacy_cache(self):
-        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
-        return None
+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
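To make the new per-layer `StaticCache` semantics above concrete, here is a small hypothetical sketch (toy `LlamaConfig`, random tensors; not from the PR itself) of how the buffers behave:

```py
import torch
from transformers import LlamaConfig, StaticCache

# Tiny config purely for illustration: head_dim = hidden_size / num_attention_heads = 16
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# One pre-allocated, zero-filled buffer per layer: (max_batch_size, num_kv_heads, max_cache_len, head_dim)
print(cache.key_cache[0].shape)  # torch.Size([1, 4, 16, 16])

# `update` writes the new states at `cache_position` and returns the full fixed-shape buffers.
k = torch.randn(1, 4, 3, 16)
v = torch.randn(1, 4, 3, 16)
k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(3)})
print(k_out.shape, int(cache.get_seq_length()))  # torch.Size([1, 4, 16, 16]) 3

# `reset()` zeroes the buffers in place, so the same object (and its static addresses) can be reused.
cache.reset()
```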
@@ -1310,6 +1310,34 @@ class GenerationMixin:
         model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
         return model_kwargs

+    def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
+        """
+        Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized if
+        a new `generate` call requires a larger cache.
+
+        Returns the resulting static cache object.
+        """
+        needs_new_cache = (
+            not hasattr(self, "_static_cache")
+            or self._static_cache.max_batch_size < max_batch_size
+            or self._static_cache.max_cache_len < max_cache_len
+        )
+        if needs_new_cache:
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                cache_dtype = self.config._pre_quantization_dtype
+            else:
+                cache_dtype = self.dtype
+            self._static_cache = StaticCache(
+                config=self.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_cache_len,
+                device=self.device,
+                dtype=cache_dtype,
+            )
+        else:
+            self._static_cache.reset()  # reset the cache for a new generation
+        return self._static_cache
+
     @torch.no_grad()
     def generate(
         self,
@@ -1514,19 +1542,19 @@ class GenerationMixin:
             input_ids_length=input_ids_length,
         )

-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+        if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
+            raise ValueError(
+                "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
+                "Cache object) is unsupported. Please use only one of the two."
+            )
+        elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if not self._supports_cache_class:
+                raise ValueError(
+                    "This model does not support the `cache_implementation` argument. Please check the following "
+                    "issue: https://github.com/huggingface/transformers/issues/28981."
+                )
             if generation_config.cache_implementation == "static":
-                if model_kwargs.get("past_key_values", False) is not False:
-                    raise ValueError(
-                        "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
-                    )
-                cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
-                if not callable(getattr(self, "_setup_cache", None)):
-                    raise ValueError(
-                        "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
-                        " Make sure it has a `_setup_cache` function."
-                    )
-                self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
+                model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
@@ -1844,14 +1872,6 @@ class GenerationMixin:
             **model_kwargs,
         )

-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
-            if not callable(getattr(self, "_reset_cache", None)):
-                raise ValueError(
-                    "A `static_cache` was used to generate but there was a failure when trying to release the cache. "
-                    " Make sure this model implements a `_reset_cache` function."
-                )
-            self._reset_cache()
-
         return result

     def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
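A hedged usage sketch of the `generate`-side changes above (checkpoint and generation arguments are only placeholders): setting `cache_implementation="static"` on the generation config makes `generate` build and reuse the lazily created static cache via `_get_static_cache`.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)

# A second call reuses `model._static_cache` (reset between calls); the cache is only rebuilt
# if the new call needs a larger batch size or a longer maximum length.
out = model.generate(**inputs, max_new_tokens=16)

# Note: passing both a `past_key_values` Cache object and `cache_implementation` now raises a ValueError.
```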
...
@@ -340,6 +340,11 @@ class CohereFlashAttention2(CohereAttention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         output_attentions = False

         bsz, q_len, _ = hidden_states.size()
@@ -734,27 +739,6 @@ class CoherePreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
-        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
-            raise ValueError(
-                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
-                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
-            )
-
-        for layer in self.model.layers:
-            device = layer.input_layernorm.weight.device
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                dtype = self.config._pre_quantization_dtype
-            else:
-                dtype = layer.self_attn.o_proj.weight.dtype
-            layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
-            )
-
-    def _reset_cache(self):
-        for layer in self.model.layers:
-            layer.self_attn.past_key_value = None
-
 COHERE_INPUTS_DOCSTRING = r"""
     Args:
@@ -898,14 +882,11 @@ class CohereModel(CoherePreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

-        past_seen_tokens = 0
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
-            if isinstance(past_key_values, StaticCache):
-                raise ValueError("cache_position is a required argument when using StaticCache.")
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
@@ -913,7 +894,7 @@ class CohereModel(CoherePreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

         # embed positions
         hidden_states = inputs_embeds
@@ -982,7 +963,7 @@ class CohereModel(CoherePreTrainedModel):
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
-        past_seen_tokens: int,
+        past_key_values: Cache,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -994,9 +975,12 @@ class CohereModel(CoherePreTrainedModel):
                 return attention_mask
             return None

-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        if self.config._attn_implementation == "sdpa" and not using_static_cache:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1008,9 +992,9 @@ class CohereModel(CoherePreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
-            target_length = self.config.max_position_embeddings
-        else:  # dynamic cache
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
             target_length = (
                 attention_mask.shape[-1]
                 if isinstance(attention_mask, torch.Tensor)
@@ -1032,6 +1016,10 @@ class CohereModel(CoherePreTrainedModel):
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
             if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                logger.warning_once(
+                    "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
+                    "transformers v4.42.0"
+                )
                 offset = cache_position[0]
             else:
                 offset = 0
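The `_update_causal_mask` rewrite above keys the mask width to `past_key_values.get_max_length()` when a static cache is used. A simplified, stand-alone sketch of that idea (helper name and shapes are invented for illustration, not the library's implementation):

```py
import torch

def sketch_causal_mask(sequence_length, target_length, cache_position, dtype=torch.float32):
    # Start fully masked, then unmask every key position at or before the query's cache position.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    allowed = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    return mask.masked_fill(allowed, 0.0)

# With a StaticCache, target_length equals the cache's max length, so the mask shape is identical
# at every decoding step and torch.compile does not recapture the graph for new shapes.
step_mask = sketch_causal_mask(sequence_length=1, target_length=4096, cache_position=torch.tensor([10]))
print(step_mask.shape)  # torch.Size([1, 4096])
```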
@@ -1189,13 +1177,6 @@ class CohereForCausalLM(CoherePreTrainedModel):
         use_cache=True,
         **kwargs,
     ):
-        # With static cache, the `past_key_values` is None
-        # TODO joao: standardize interface for the different Cache classes and remove of this if
-        has_static_cache = False
-        if past_key_values is None:
-            past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
-            has_static_cache = past_key_values is not None
-
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
@@ -1213,8 +1194,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        # input)
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
@@ -1254,9 +1234,6 @@ class CohereForCausalLM(CoherePreTrainedModel):
         elif use_cache:
             cache_position = cache_position[-input_length:]

-        if has_static_cache:
-            past_key_values = None
-
         model_inputs.update(
             {
                 "position_ids": position_ids,
...
@@ -15,7 +15,7 @@
 """ PyTorch DBRX model. """

 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -354,6 +354,11 @@ class DbrxFlashAttention2(DbrxAttention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Any,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
         output_attentions = False
@@ -622,6 +627,7 @@ class DbrxSdpaAttention(DbrxAttention):
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attn_pdrop if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
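The `is_causal` flag added to the SDPA call above lets PyTorch build the causal mask internally when no explicit mask is supplied. A small stand-alone sketch (shapes invented for illustration):

```py
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# Equivalent of the modeling code when `causal_mask is None and q_len > 1`: no attn_mask is passed
# and SDPA applies causal masking itself, which keeps the fused flash / memory-efficient kernels
# eligible for dispatch.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 8, 16])
```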
@@ -957,28 +963,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
             module.v1.data.normal_(mean=0.0, std=std)
             module.w2.data.normal_(mean=0.0, std=std)

-    def _setup_cache(self, cache_cls: Any, max_batch_size: int, max_cache_len: int):
-        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
-            raise ValueError(
-                "`static` cache implementation is not compatible with "
-                + "`attn_implementation==flash_attention_2`. Make sure to use "
-                + "`spda` in the mean time and open an issue at https://github.com/huggingface/transformers."
-            )
-
-        for block in self.transformer.blocks:
-            device = block.norm_attn_norm.norm_1.weight.device
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                dtype = self.config._pre_quantization_dtype
-            else:
-                dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
-            block.norm_attn_norm.attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
-            )
-
-    def _reset_cache(self):
-        for block in self.transformer.blocks:
-            block.norm_attn_norm.attn.past_key_value = None
-
 DBRX_INPUTS_DOCSTRING = r"""
     Args:
@@ -1131,22 +1115,18 @@ class DbrxModel(DbrxPreTrainedModel):
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

-        past_seen_tokens = 0
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
-            if isinstance(past_key_values, StaticCache):
-                raise ValueError("cache_position is a required argument when using StaticCache.")
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

         # embed positions
         hidden_states = inputs_embeds
@@ -1205,7 +1185,9 @@ class DbrxModel(DbrxPreTrainedModel):
         next_cache = None
         if use_cache:
             next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+                next_decoder_cache.to_legacy_cache()
+                if isinstance(next_decoder_cache, DynamicCache)
+                else next_decoder_cache
             )
         if not return_dict:
             return tuple(
@@ -1221,28 +1203,45 @@ class DbrxModel(DbrxPreTrainedModel):
             router_logits=all_router_logits,
         )

-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
     def _update_causal_mask(
-        self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
-    ) -> Optional[torch.Tensor]:
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None

+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None
+
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
-            target_length = self.config.max_position_embeddings
-        else:  # dynamic cache
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
             )
-        target_length = int(target_length)
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
         if sequence_length != 1:
@@ -1259,6 +1258,10 @@ class DbrxModel(DbrxPreTrainedModel):
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
             if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                logger.warning_once(
+                    "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
+                    "transformers v4.42.0"
+                )
                 offset = cache_position[0]
             else:
                 offset = 0
@@ -1273,17 +1276,10 @@ class DbrxModel(DbrxPreTrainedModel):
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
         ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
@@ -1431,28 +1427,35 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
             router_logits=outputs.router_logits,
         )

+    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self,
-        input_ids: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs: Any,
-    ) -> Dict[str, Any]:
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        use_cache=True,
+        **kwargs,
+    ):
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None

         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        # input)
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
@@ -1477,22 +1480,6 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]

-        if self.generation_config.cache_implementation == "static":
-            # generation with static cache
-            cache_position = kwargs.get("cache_position", None)
-            if cache_position is None:
-                past_length = 0
-            else:
-                past_length = cache_position[-1] + 1
-            input_ids = input_ids[:, past_length:]
-            position_ids = position_ids[:, past_length:] if position_ids is not None else None
-
-        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
-        # same goes for position ids. Could also help with continued generation.
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        position_ids = position_ids.contiguous() if position_ids is not None else None
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1502,12 +1489,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}

+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        elif use_cache:
+            cache_position = cache_position[-input_length:]
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": use_cache,
                 "attention_mask": attention_mask,
             }
         )
...
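The reworked `prepare_inputs_for_generation` above derives the past length from `cache_position` instead of per-layer cache attributes. A toy sketch of that bookkeeping (values invented):

```py
import torch

# Prefill: no tokens cached yet, the whole prompt is passed through the model.
past_length = 0
input_length = 5  # prompt length
cache_position = torch.arange(past_length, past_length + input_length)  # tensor([0, 1, 2, 3, 4])

# Decoding: only the newest token is fed back, so the positions are sliced to the last entry
# and advanced by one at each step (this is what `cache_position[-input_length:]` keeps).
cache_position = cache_position[-1:] + 1   # tensor([5])
past_length = int(cache_position[0])       # 5, matching `past_length = cache_position[0]` above
```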
@@ -332,6 +332,11 @@ class GemmaFlashAttention2(GemmaAttention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         output_attentions = False

         bsz, q_len, _ = hidden_states.size()
@@ -615,7 +620,7 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -717,23 +722,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
-        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
-            raise ValueError(
-                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
-                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
-            )
-
-        for layer in self.model.layers:
-            weights = layer.self_attn.o_proj.weight
-            layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
-            )
-
-    def _reset_cache(self):
-        for layer in self.model.layers:
-            layer.self_attn.past_key_value = None
-
 GEMMA_INPUTS_DOCSTRING = r"""
     Args:
@@ -850,7 +838,7 @@ class GemmaModel(GemmaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -879,13 +867,11 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        past_seen_tokens = 0
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
@@ -893,7 +879,7 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

         # embed positions
         hidden_states = inputs_embeds
@@ -952,7 +938,9 @@ class GemmaModel(GemmaPreTrainedModel):
         next_cache = None
         if use_cache:
             next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+                next_decoder_cache.to_legacy_cache()
+                if isinstance(next_decoder_cache, DynamicCache)
+                else next_decoder_cache
             )
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -968,7 +956,7 @@ class GemmaModel(GemmaPreTrainedModel):
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
-        past_seen_tokens: int,
+        past_key_values: Cache,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -980,9 +968,12 @@ class GemmaModel(GemmaPreTrainedModel):
                 return attention_mask
             return None

-        if self.config._attn_implementation == "sdpa":
-            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
-            # in order to dispatch on Flash Attention 2.
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        if self.config._attn_implementation == "sdpa" and not using_static_cache:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -994,9 +985,9 @@ class GemmaModel(GemmaPreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
-            target_length = self.config.max_position_embeddings
-        else:  # dynamic cache
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
             target_length = (
                 attention_mask.shape[-1]
                 if isinstance(attention_mask, torch.Tensor)
@@ -1018,6 +1009,10 @@ class GemmaModel(GemmaPreTrainedModel):
             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
             # cache. In that case, the 4D attention mask attends to the newest tokens only.
             if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                logger.warning_once(
+                    "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
+                    "transformers v4.42.0"
+                )
                 offset = cache_position[0]
             else:
                 offset = 0
@@ -1079,7 +1074,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -1171,13 +1166,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         use_cache=True,
         **kwargs,
     ):
-        # With static cache, the `past_key_values` is None
-        # TODO joao: standardize interface for the different Cache classes and remove of this if
-        has_static_cache = False
-        if past_key_values is None:
-            past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
-            has_static_cache = past_key_values is not None
-
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
@@ -1195,8 +1183,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        # input)
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
@@ -1236,9 +1223,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         elif use_cache:
             cache_position = cache_position[-input_length:]

-        if has_static_cache:
-            past_key_values = None
-
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1298,7 +1282,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
......
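In the new `_update_causal_mask` path above, the target length of the causal mask is read from the cache object itself (`past_key_values.get_max_length()`) instead of being probed from per-layer attributes or `config.max_position_embeddings`. A minimal sketch of what that relies on, assuming only the `StaticCache` constructor arguments used later in this diff (the config class and sizes here are illustrative):

```py
import torch
from transformers import GemmaConfig, StaticCache

# Tiny config so the pre-allocated cache stays small; any decoder config with the usual
# attention fields (num_hidden_layers, num_key_value_heads, head_dim, ...) would do.
config = GemmaConfig(num_hidden_layers=2)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=512, device="cpu", dtype=torch.float16)

print(isinstance(cache, StaticCache))  # True -> takes the `using_static_cache` branch above
print(cache.get_max_length())          # 512 -> used as `target_length` for the causal mask
print(cache.get_seq_length())          # 0 -> no tokens have been written to the cache yet
```

Because `get_max_length()` returns the pre-allocated `max_cache_len`, the mask keeps a fixed shape across decode steps, which is what lets `torch.compile` avoid re-capturing cudagraphs as noted in the TODO above.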
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import (
AttentionMaskConverter, AttentionMaskConverter,
) )
...@@ -1807,7 +1807,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel): ...@@ -1807,7 +1807,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -428,6 +428,12 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -428,6 +428,12 @@ class LlamaFlashAttention2(LlamaAttention):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False output_attentions = False
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -710,7 +716,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -710,7 +716,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
...@@ -811,27 +817,6 @@ class LlamaPreTrainedModel(PreTrainedModel): ...@@ -811,27 +817,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
dtype = self.config._pre_quantization_dtype
else:
dtype = layer.self_attn.o_proj.weight.dtype
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
)
def _reset_cache(self):
for layer in self.model.layers:
layer.self_attn.past_key_value = None
LLAMA_INPUTS_DOCSTRING = r""" LLAMA_INPUTS_DOCSTRING = r"""
Args: Args:
...@@ -946,7 +931,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -946,7 +931,7 @@ class LlamaModel(LlamaPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
...@@ -975,23 +960,18 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -975,23 +960,18 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0 if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache: # kept for BC (cache positions) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None: if cache_position is None:
if isinstance(past_key_values, StaticCache): past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange( cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
) )
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1044,7 +1024,9 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1044,7 +1024,9 @@ class LlamaModel(LlamaPreTrainedModel):
next_cache = None next_cache = None
if use_cache: if use_cache:
next_cache = ( next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
) )
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
...@@ -1060,7 +1042,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1060,7 +1042,7 @@ class LlamaModel(LlamaPreTrainedModel):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_seen_tokens: int, past_key_values: Cache,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1072,9 +1054,12 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1072,9 +1054,12 @@ class LlamaModel(LlamaPreTrainedModel):
return attention_mask return attention_mask
return None return None
if self.config._attn_implementation == "sdpa": # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# in order to dispatch on Flash Attention 2. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1086,9 +1071,9 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1086,9 +1071,9 @@ class LlamaModel(LlamaPreTrainedModel):
dtype, device = input_tensor.dtype, input_tensor.device dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1] sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache if using_static_cache:
target_length = self.config.max_position_embeddings target_length = past_key_values.get_max_length()
else: # dynamic cache else:
target_length = ( target_length = (
attention_mask.shape[-1] attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor) if isinstance(attention_mask, torch.Tensor)
...@@ -1110,6 +1095,10 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1110,6 +1095,10 @@ class LlamaModel(LlamaPreTrainedModel):
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only. # cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length: if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0] offset = cache_position[0]
else: else:
offset = 0 offset = 0
...@@ -1169,7 +1158,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): ...@@ -1169,7 +1158,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
...@@ -1267,13 +1256,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): ...@@ -1267,13 +1256,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
use_cache=True, use_cache=True,
**kwargs, **kwargs,
): ):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None
past_length = 0 past_length = 0
if past_key_values is not None: if past_key_values is not None:
if isinstance(past_key_values, Cache): if isinstance(past_key_values, Cache):
...@@ -1291,8 +1273,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): ...@@ -1291,8 +1273,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
# Keep only the unprocessed tokens: # Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
...@@ -1332,9 +1313,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): ...@@ -1332,9 +1313,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
elif use_cache: elif use_cache:
cache_position = cache_position[-input_length:] cache_position = cache_position[-input_length:]
if has_static_cache:
past_key_values = None
model_inputs.update( model_inputs.update(
{ {
"position_ids": position_ids, "position_ids": position_ids,
...@@ -1393,7 +1371,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): ...@@ -1393,7 +1371,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
...@@ -1510,7 +1488,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): ...@@ -1510,7 +1488,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None, start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None,
......
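The `_setup_cache`/`_reset_cache` helpers removed above attached a cache instance to every `self_attn` module; after this change the cache is a standalone object that the caller constructs and passes through `past_key_values`. A rough sketch of the new calling pattern, assuming the constructor arguments shown in the test changes further down (checkpoint name, sizes, and prompt are illustrative):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# Previously: model._setup_cache(StaticCache, max_batch_size, max_cache_len=...).
# Now the caller builds the cache and hands it to the model via `past_key_values`.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,
    device=model.device,
    dtype=model.dtype,
)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=40, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

Note that the diff also adds an explicit `ValueError` in `LlamaFlashAttention2` when a `StaticCache` is passed, so this pattern assumes `sdpa` or eager attention rather than `attn_implementation="flash_attention_2"`.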
...@@ -1301,7 +1301,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel): ...@@ -1301,7 +1301,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1525,7 +1525,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): ...@@ -1525,7 +1525,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -692,7 +692,7 @@ class OlmoDecoderLayer(nn.Module): ...@@ -692,7 +692,7 @@ class OlmoDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
...@@ -794,27 +794,6 @@ class OlmoPreTrainedModel(PreTrainedModel): ...@@ -794,27 +794,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
dtype = self.config._pre_quantization_dtype
else:
dtype = layer.self_attn.o_proj.weight.dtype
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
)
def _reset_cache(self):
for layer in self.model.layers:
layer.self_attn.past_key_value = None
OLMO_INPUTS_DOCSTRING = r""" OLMO_INPUTS_DOCSTRING = r"""
Args: Args:
...@@ -930,7 +909,7 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -930,7 +909,7 @@ class OlmoModel(OlmoPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
...@@ -959,23 +938,18 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -959,23 +938,18 @@ class OlmoModel(OlmoPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0 if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache: # kept for BC (cache positions) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None: if cache_position is None:
if isinstance(past_key_values, StaticCache): past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange( cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
) )
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1028,7 +1002,9 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1028,7 +1002,9 @@ class OlmoModel(OlmoPreTrainedModel):
next_cache = None next_cache = None
if use_cache: if use_cache:
next_cache = ( next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
) )
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
...@@ -1045,7 +1021,7 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1045,7 +1021,7 @@ class OlmoModel(OlmoPreTrainedModel):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_seen_tokens: int, past_key_values: Cache,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1057,9 +1033,12 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1057,9 +1033,12 @@ class OlmoModel(OlmoPreTrainedModel):
return attention_mask return attention_mask
return None return None
if self.config._attn_implementation == "sdpa": # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# in order to dispatch on Flash Attention 2. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1071,9 +1050,9 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1071,9 +1050,9 @@ class OlmoModel(OlmoPreTrainedModel):
dtype, device = input_tensor.dtype, input_tensor.device dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1] sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache if using_static_cache:
target_length = self.config.max_position_embeddings target_length = past_key_values.get_max_length()
else: # dynamic cache else:
target_length = ( target_length = (
attention_mask.shape[-1] attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor) if isinstance(attention_mask, torch.Tensor)
...@@ -1095,6 +1074,10 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1095,6 +1074,10 @@ class OlmoModel(OlmoPreTrainedModel):
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only. # cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length: if attention_mask.shape[-2] < cache_position[0] + sequence_length:
logger.warning_once(
"Passing a 4d mask shorter than the input length is deprecated and will be removed in "
"transformers v4.42.0"
)
offset = cache_position[0] offset = cache_position[0]
else: else:
offset = 0 offset = 0
...@@ -1250,13 +1233,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel): ...@@ -1250,13 +1233,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
use_cache=True, use_cache=True,
**kwargs, **kwargs,
): ):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None
past_length = 0 past_length = 0
if past_key_values is not None: if past_key_values is not None:
if isinstance(past_key_values, Cache): if isinstance(past_key_values, Cache):
...@@ -1274,8 +1250,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel): ...@@ -1274,8 +1250,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
# Keep only the unprocessed tokens: # Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
...@@ -1315,9 +1290,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel): ...@@ -1315,9 +1290,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
elif use_cache: elif use_cache:
cache_position = cache_position[-input_length:] cache_position = cache_position[-input_length:]
if has_static_cache:
past_key_values = None
model_inputs.update( model_inputs.update(
{ {
"position_ids": position_ids, "position_ids": position_ids,
......
...@@ -927,7 +927,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): ...@@ -927,7 +927,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1313,7 +1313,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel): ...@@ -1313,7 +1313,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1419,7 +1419,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): ...@@ -1419,7 +1419,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1509,7 +1509,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): ...@@ -1509,7 +1509,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1299,7 +1299,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): ...@@ -1299,7 +1299,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -1292,7 +1292,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): ...@@ -1292,7 +1292,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
......
...@@ -18,11 +18,11 @@ import tempfile ...@@ -18,11 +18,11 @@ import tempfile
import unittest import unittest
import pytest import pytest
from packaging import version
from parameterized import parameterized from parameterized import parameterized
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes, require_bitsandbytes,
require_flash_attn, require_flash_attn,
require_read_token, require_read_token,
...@@ -684,15 +684,28 @@ class LlamaIntegrationTest(unittest.TestCase): ...@@ -684,15 +684,28 @@ class LlamaIntegrationTest(unittest.TestCase):
@require_torch_gpu @require_torch_gpu
@require_read_token @require_read_token
def test_compile_static_cache(self): def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40 NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = { EXPECTED_TEXT_COMPLETION = {
7: [
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
],
8: [ 8: [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
"theory of relativ",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
],
7: [
"Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory "
"goes that nothing travels faster than light, but the faster you go, the slower everything else will "
"be.\nThe theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, "
"and even on a good old fashioned cheeseburger. I love it on everything. I love it so",
], ],
} }
...@@ -706,38 +719,25 @@ class LlamaIntegrationTest(unittest.TestCase): ...@@ -706,38 +719,25 @@ class LlamaIntegrationTest(unittest.TestCase):
) )
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
def decode_one_tokens(model, cur_token, input_pos, cache_position): # Dynamic Cache
logits = model( generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
)[0] self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
return new_token # Static Cache
generated_ids = model.generate(
batch_size, seq_length = inputs["input_ids"].shape **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
with torch.no_grad(): )
model._setup_cache(StaticCache, 2, max_cache_len=4096) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
cache_position = torch.arange(seq_length, device=torch_device) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
generated_ids = torch.zeros(
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device # Static Cache + compile
) model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] )
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
generated_ids[:, seq_length] = next_token[:, 0] self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
with CaptureLogger(logging.get_logger(__name__)) as cl:
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
self.assertNotIn("skipping cudagraphs due to", cl.out)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
@require_torch @require_torch
......
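The rewritten integration test above exercises three decoding paths: a dynamic cache, a static cache requested through `cache_implementation="static"`, and a static cache with a compiled forward. Outside the test harness, the same flow looks roughly like this (model name and generation length are illustrative):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

# 1) dynamic (default) cache
dynamic_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

# 2) static cache allocated internally by `generate`
static_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 3) static cache + compiled forward: the fixed cache shape avoids recompilation at every step
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

print(tokenizer.batch_decode(compiled_ids, skip_special_tokens=True))
```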
...@@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase): ...@@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase):
""" """
# Sample tokens greedily # Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position): def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model( logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True cur_token,
position_ids=input_pos,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0] )[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
...@@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase): ...@@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase):
seq_length = input_ids.shape[1] seq_length = input_ids.shape[1]
# Setup static KV cache for generation # Setup static KV cache for generation
self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1) past_key_values = StaticCache(
config=self.quantized_model.config,
max_batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,
)
# Allocate token ids to be generated and copy prefix ids # Allocate token ids to be generated and copy prefix ids
cache_position = torch.arange(seq_length, device=torch_device) cache_position = torch.arange(seq_length, device=torch_device)
...@@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase): ...@@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase):
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
# Do a forward pass to fill the prefix cache and compile the kernels if necessary # Do a forward pass to fill the prefix cache and compile the kernels if necessary
logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] logits = self.quantized_model(
input_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token generated_ids[:, [seq_length]] = next_token
...@@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase): ...@@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase):
cache_position = torch.tensor([seq_length + 1], device=torch_device) cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens): for _ in range(1, self.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position) next_token = decode_one_tokens(
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
)
generated_ids.index_copy_(1, cache_position, next_token) generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1 cache_position += 1
......
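With `_setup_cache` gone, the AQLM test now has to choose the cache dtype itself, using `config._pre_quantization_dtype` for quantized weights. A small helper in the same spirit as the removed code (the function name is made up for illustration):

```py
import torch

def pick_cache_dtype(model) -> torch.dtype:
    """Hypothetical helper mirroring the dtype choice the removed `_setup_cache` made."""
    # Quantized checkpoints record the pre-quantization dtype on the config; otherwise fall
    # back to the model's own dtype (the old helper read it from the o_proj weight instead).
    if hasattr(model.config, "_pre_quantization_dtype"):
        return model.config._pre_quantization_dtype
    return model.dtype
```

A `StaticCache` built with `dtype=pick_cache_dtype(model)` should then match the key/value projections for both quantized and unquantized checkpoints.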