Unverified commit 616bb11d authored by Longjie Zheng, committed by GitHub

Add torch.compile for Mistral (#30642)

* first version

* fix sliding window

* fix style

* add sliding window cache

* fix style

* address comments

* fix test

* fix style

* move sliding window check inside cache init

* revert changes on irrelevant files & add comment on SlidingWindowCache

* address comments & fix style

fix style

* update causal mask

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] llama

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* revert CI from a10 to t4

* wrap up
parent 92d1d97c
@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
 The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.

 > [!WARNING]
-> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
+> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.

 For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
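As an illustration of the recipe this doc snippet describes, here is a minimal sketch (not part of the diff; it assumes a CUDA device and the `google/gemma-2b` checkpoint the page loads next):

```python
# Minimal static kv-cache + torch.compile sketch, assuming a CUDA device is available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float16).to("cuda")

# Pre-allocate the kv-cache to a fixed maximum size so its shapes stay static.
model.generation_config.cache_implementation = "static"

# Compile the forward pass; fixed cache shapes avoid recompilation at each decode step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The theory of special relativity states ", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```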
......
@@ -448,3 +448,124 @@ class StaticCache(Cache):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
class SlidingWindowCache(Cache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
    Every time we update the cache, we compute `indices` from `cache_position >= self.sliding_window_size - 1`.
    If that condition is true (meaning the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
    we do a cyclic shift based on `indices` so that the oldest states are replaced by the new key/value states passed in.
    `to_shift` only becomes true once we reach the end of the sliding window. Thus with `sliding_window_size == 64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
Parameters:
        config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
        # `cache_position` will be used in other cases
return 0
def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
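To make the docstring's cycle-shift concrete, here is a small standalone sketch (a toy window of 8 and plain tensors, not taken from the commit) showing how `indices` switches from the identity permutation to a rotate-left-by-one once the clamped `cache_position` reaches the last slot:

```python
# Toy illustration of the `indices` computation used by SlidingWindowCache.update.
import torch

sliding_window_size = 8
slicing = torch.ones(sliding_window_size, dtype=torch.long).cumsum(0)  # tensor([1, 2, ..., 8])

for pos in (3, 7, 20):  # a position inside the window, at its edge, and far beyond it
    cache_position = torch.tensor([pos]).clamp(0, sliding_window_size - 1)
    to_shift = cache_position >= sliding_window_size - 1
    indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size
    print(pos, indices.tolist())
# pos 3  -> [0, 1, 2, 3, 4, 5, 6, 7]   identity: just write at cache_position
# pos 7  -> [1, 2, 3, 4, 5, 6, 7, 0]   rotate left by one, then write the last slot
# pos 20 -> [1, 2, 3, 4, 5, 6, 7, 0]   cache_position is clamped to 7 from here on
```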
@@ -24,7 +24,7 @@ import torch
 import torch.distributed as dist
 from torch import nn

-from ..cache_utils import Cache, DynamicCache, StaticCache
+from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..models.auto import (
@@ -96,9 +96,7 @@ logger = logging.get_logger(__name__)
 if is_accelerate_available():
     from accelerate.hooks import AlignDevicesHook, add_hook_to_module

-NEED_SETUP_CACHE_CLASSES_MAPPING = {
-    "static": StaticCache,
-}
+NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}


 @dataclass
@@ -1326,24 +1324,33 @@ class GenerationMixin:
         model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
         return model_kwargs

-    def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
+    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
         """
-        Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
+        Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
         new `generate` call requires a larger cache.

-        Returns the resulting static cache object.
+        Returns the resulting cache object.
         """
-        needs_new_cache = (
-            not hasattr(self, "_static_cache")
-            or self._static_cache.max_batch_size < max_batch_size
-            or self._static_cache.max_cache_len < max_cache_len
-        )
-        if needs_new_cache:
+        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        need_new_cache = (
+            not hasattr(self, "_cache")
+            or (not isinstance(self._cache, cache_cls))
+            or self._cache.max_batch_size < max_batch_size
+        )
+        if cache_implementation == "sliding_window":
+            need_new_cache = need_new_cache or (
+                self._cache.sliding_window_size < self._cache.model_sliding_window_size
+                and max_cache_len > self._cache.max_cache_len
+            )
+        elif cache_implementation == "static":
+            need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len
+
+        if need_new_cache:
             if hasattr(self.config, "_pre_quantization_dtype"):
                 cache_dtype = self.config._pre_quantization_dtype
             else:
                 cache_dtype = self.dtype
-            self._static_cache = StaticCache(
+            self._cache = cache_cls(
                 config=self.config,
                 max_batch_size=max_batch_size,
                 max_cache_len=max_cache_len,
@@ -1351,8 +1358,8 @@ class GenerationMixin:
                 dtype=cache_dtype,
             )
         else:
-            self._static_cache.reset()  # reset the cache for a new generation
-        return self._static_cache
+            self._cache.reset()
+        return self._cache
     def _prepare_special_tokens(
         self,
@@ -1615,14 +1622,14 @@ class GenerationMixin:
                     "This model does not support the `cache_implementation` argument. Please check the following "
                     "issue: https://github.com/huggingface/transformers/issues/28981."
                 )
-            if generation_config.cache_implementation == "static":
-                if not self._supports_static_cache:
-                    raise ValueError(
-                        "This model does not support `cache_implementation='static'`. Please check the following "
-                        "issue: https://github.com/huggingface/transformers/issues/28981"
-                    )
-                model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
+            if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+                raise ValueError(
+                    "This model does not support `cache_implementation='static'`. Please check the following "
+                    "issue: https://github.com/huggingface/transformers/issues/28981"
+                )
+            model_kwargs["past_key_values"] = self._get_cache(
+                generation_config.cache_implementation, batch_size, generation_config.max_length
+            )

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

         # 7. determine generation mode
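For reference, a hedged usage sketch of how the new mapping is exercised from `generate`'s side: requesting the `"sliding_window"` implementation makes `_get_cache` build a `SlidingWindowCache`. The checkpoint name and device setup below are illustrative, not part of the diff:

```python
# Sketch: ask generate for the sliding-window cache so it can be combined with torch.compile.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)

# Selects SlidingWindowCache via NEED_SETUP_CACHE_CLASSES_MAPPING / _get_cache.
model.generation_config.cache_implementation = "sliding_window"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```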
......
@@ -63,7 +63,7 @@ class OpenLlamaRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)

-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama
 class OpenLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
......
@@ -81,7 +81,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
@@ -123,7 +123,7 @@ def _get_unpad_data(attention_mask):
     )

-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon
 class FalconRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
......
@@ -253,7 +253,6 @@ class GemmaAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -265,8 +264,8 @@ class GemmaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
......
@@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask):

 class GPTNeoXRotaryEmbedding(nn.Module):
-    # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
+    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -614,7 +614,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
......
@@ -230,7 +230,7 @@ class GPTNeoXJapaneseAttention(nn.Module):

 # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
 class RotaryEmbedding(nn.Module):
-    # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
+    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
......
@@ -478,7 +478,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
......
@@ -303,7 +303,6 @@ class LlamaAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -402,7 +401,6 @@ class LlamaFlashAttention2(LlamaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
......
@@ -17,7 +17,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Mistral model."""
+"""PyTorch Mistral model."""
 import inspect
 import math
 from typing import List, Optional, Tuple, Union
@@ -29,8 +30,8 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -92,8 +93,6 @@ class MistralRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)

-# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
-# TODO @Arthur no longer copied from LLama after static cache
 class MistralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -104,30 +103,22 @@ class MistralRotaryEmbedding(nn.Module):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
+    @torch.no_grad()
+    # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
+    def forward(self, x, position_ids):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -138,9 +129,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)

-# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-# TODO @Arthur no longer copied from LLama after static cache
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -148,9 +138,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -161,8 +150,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
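A small illustrative sketch of the new call pattern (toy shapes, helper redefined locally; not part of the diff): `cos`/`sin` now arrive from the rotary embedding already gathered per position, so `apply_rotary_pos_emb` only unsqueezes the head dimension and rotates:

```python
import torch

def rotate_half(x):
    # Same helper as in the modeling file: rotates half of the hidden dims.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, n_heads, seq_len, head_dim = 1, 4, 6, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

# cos/sin are already indexed per position, shaped (bsz, seq_len, head_dim)
position_ids = torch.arange(seq_len)[None, :].float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = position_ids[..., None] * inv_freq
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# New signature: no position_ids; just unsqueeze the head dim (unsqueeze_dim=1) and rotate.
q_embed = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
k_embed = (k * cos.unsqueeze(1)) + (rotate_half(k) * sin.unsqueeze(1))
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 4, 6, 8]) twice
```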
@@ -213,6 +202,7 @@ class MistralAttention(nn.Module):
                 "when creating this class."
             )

+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -221,7 +211,6 @@ class MistralAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
-        self.attention_dropout = config.attention_dropout

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -231,7 +220,7 @@ class MistralAttention(nn.Module):
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.rotary_emb = MistralRotaryEmbedding(
             self.head_dim,
@@ -239,9 +228,7 @@ class MistralAttention(nn.Module):
             base=self.rope_theta,
         )

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
+    # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -250,6 +237,7 @@ class MistralAttention(nn.Module): ...@@ -250,6 +237,7 @@ class MistralAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -261,41 +249,22 @@ class MistralAttention(nn.Module): ...@@ -261,41 +249,22 @@ class MistralAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] cos, sin = self.rotary_emb(value_states, position_ids)
if past_key_value is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models # sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): if attention_mask is not None: # no matter the length, we just slice it
raise ValueError( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" attn_weights = attn_weights + causal_mask
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32 # upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
...@@ -309,8 +278,8 @@ class MistralAttention(nn.Module): ...@@ -309,8 +278,8 @@ class MistralAttention(nn.Module):
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
...@@ -343,7 +312,16 @@ class MistralFlashAttention2(MistralAttention): ...@@ -343,7 +312,16 @@ class MistralFlashAttention2(MistralAttention):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
): ):
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
...@@ -356,19 +334,10 @@ class MistralFlashAttention2(MistralAttention): ...@@ -356,19 +334,10 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
if self.layer_idx is None: kv_seq_len += cache_position[0]
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
use_sliding_windows = ( use_sliding_windows = (
_flash_supports_window_size _flash_supports_window_size
...@@ -605,8 +574,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -605,8 +574,7 @@ class MistralFlashAttention2(MistralAttention):
) )
-# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
 class MistralSdpaAttention(MistralAttention):
""" """
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -623,6 +591,7 @@ class MistralSdpaAttention(MistralAttention): ...@@ -623,6 +591,7 @@ class MistralSdpaAttention(MistralAttention):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
...@@ -637,6 +606,7 @@ class MistralSdpaAttention(MistralAttention): ...@@ -637,6 +606,7 @@ class MistralSdpaAttention(MistralAttention):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position,
) )
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -649,43 +619,38 @@ class MistralSdpaAttention(MistralAttention): ...@@ -649,43 +619,38 @@ class MistralSdpaAttention(MistralAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] cos, sin = self.rotary_emb(value_states, position_ids)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models # sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577. # Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None: if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous() query_states = query_states.contiguous()
key_states = key_states.contiguous() key_states = key_states.contiguous()
value_states = value_states.contiguous() value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal = True if causal_mask is None and q_len > 1 else False
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_mask=attention_mask, attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0, dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal, is_causal=is_causal,
) )
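A toy sketch (not part of the diff) of the dispatch logic above: when there is no mask and `q_len > 1`, SDPA can be asked to build the causal structure itself via `is_causal=True`; otherwise the prepared 4D additive mask is passed through `attn_mask`:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 5, 8)   # (bsz, heads, q_len, head_dim) -- toy shapes
k = torch.randn(1, 4, 5, 8)
v = torch.randn(1, 4, 5, 8)

# No mask and q_len > 1: let SDPA apply causal masking internally.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Otherwise: pass a 4D additive mask (0.0 = keep, large negative = block).
additive_mask = torch.zeros(1, 1, 5, 5)
additive_mask[..., 3:] = torch.finfo(torch.float32).min  # e.g. block the last two key positions
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(out_causal.shape, out_masked.shape)  # torch.Size([1, 4, 5, 8]) twice
```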
...@@ -705,12 +670,13 @@ MISTRAL_ATTENTION_CLASSES = { ...@@ -705,12 +670,13 @@ MISTRAL_ATTENTION_CLASSES = {
} }
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
class MistralDecoderLayer(nn.Module): class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int): def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = MistralMLP(config) self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -721,15 +687,17 @@ class MistralDecoderLayer(nn.Module): ...@@ -721,15 +687,17 @@ class MistralDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size attention_mask (`torch.FloatTensor`, *optional*):
`(batch, sequence_length)` where padding elements are indicated by 0. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. returned tensors for more detail.
...@@ -751,6 +719,7 @@ class MistralDecoderLayer(nn.Module): ...@@ -751,6 +719,7 @@ class MistralDecoderLayer(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -801,6 +770,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
...@@ -924,12 +894,13 @@ class MistralModel(MistralPreTrainedModel): ...@@ -924,12 +894,13 @@ class MistralModel(MistralPreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -940,72 +911,36 @@ class MistralModel(MistralPreTrainedModel): ...@@ -940,72 +911,36 @@ class MistralModel(MistralPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None: if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") raise ValueError(
elif input_ids is not None: "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
batch_size, seq_length = input_ids.shape )
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False
past_key_values_length = 0 if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache: return_legacy_cache = False
use_legacy_cache = not isinstance(past_key_values, Cache) if use_cache and not isinstance(past_key_values, Cache):
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length) return_legacy_cache = True
if position_ids is None: if cache_position is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
position_ids = torch.arange( cache_position = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
) )
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: if position_ids is None:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size position_ids = cache_position.unsqueeze(0)
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2": causal_mask = self._update_causal_mask(
# 2d mask is passed through the layers attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1023,20 +958,22 @@ class MistralModel(MistralPreTrainedModel): ...@@ -1023,20 +958,22 @@ class MistralModel(MistralPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, causal_mask,
position_ids, position_ids,
past_key_values, past_key_values,
output_attentions, output_attentions,
use_cache, use_cache,
cache_position,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=causal_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
...@@ -1053,9 +990,9 @@ class MistralModel(MistralPreTrainedModel): ...@@ -1053,9 +990,9 @@ class MistralModel(MistralPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache next_cache = next_cache.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
...@@ -1066,6 +1003,113 @@ class MistralModel(MistralPreTrainedModel): ...@@ -1066,6 +1003,113 @@ class MistralModel(MistralPreTrainedModel):
attentions=all_self_attns, attentions=all_self_attns,
) )
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
use_cache: bool,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
exclude_mask |= torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - self.config.sliding_window
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
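The sliding-window branch above is easiest to verify on toy shapes. The following standalone sketch (illustrative only, with made-up sizes `sequence_length=5`, `target_length=8`, `sliding_window=3`; it is not part of the diff) reproduces the same `exclude_mask` arithmetic and shows that each query row ends up attending to at most `sliding_window` positions ending at its own cache position:

```python
import torch

# Toy stand-ins for the values `_update_causal_mask` works with (illustrative only).
sequence_length, target_length, sliding_window = 5, 8, 3
cache_position = torch.arange(sequence_length)  # positions of the tokens being processed
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
# Exclude future positions (strictly after each query's cache position) ...
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ... and, because of the sliding window, positions that have fallen out of the window.
exclude_mask |= torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
causal_mask *= exclude_mask  # excluded slots keep min_dtype, attendable slots become 0.0

print((causal_mask == 0).int())
# Row i has ones only at positions i-2, i-1, i (clamped at 0): a window of width 3.
```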
class MistralForCausalLM(MistralPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1104,13 +1148,14 @@ class MistralForCausalLM(MistralPreTrainedModel):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1155,6 +1200,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
hidden_states = outputs[0]
@@ -1187,14 +1233,27 @@ class MistralForCausalLM(MistralPreTrainedModel):
)
def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=True,
+ **kwargs,
):
# Omit tokens covered by past_key_values
+ past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- past_length = past_key_values.seen_tokens
- max_cache_length = past_key_values.get_max_length()
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+ max_cache_length = (
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+ if past_key_values.get_max_length() is not None
+ else None
+ )
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1227,17 +1286,33 @@ class MistralForCausalLM(MistralPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
+ # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache
+ if (
+ past_length > 0
+ and attention_mask is not None
+ and isinstance(past_key_values, SlidingWindowCache)
+ and attention_mask.shape[1] > past_key_values.sliding_window_size
+ ):
+ attention_mask = attention_mask[:, -past_key_values.sliding_window_size :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
- model_inputs = {"input_ids": input_ids}
+ model_inputs = {"input_ids": input_ids.contiguous()}
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+ if cache_position is None:
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+ elif use_cache:
+ cache_position = cache_position[-input_length:]
model_inputs.update(
{
"position_ids": position_ids,
+ "cache_position": cache_position,
"past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
+ "use_cache": use_cache,
"attention_mask": attention_mask,
}
)
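The new `cache_position` bookkeeping above tracks the absolute positions of the tokens being processed, so the static or sliding-window cache knows exactly which slots to update: at prefill it spans the whole prompt, and at each decode step it collapses to the single slot being written. A minimal sketch of that timeline (illustrative only, plain tensors and a hypothetical 6-token prompt, no model involved):

```python
import torch

prompt_length = 6
past_length = 0

# Prefill: nothing is cached yet, so cache_position covers every prompt token.
cache_position = torch.arange(past_length, past_length + prompt_length)
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5])

# Decode: one token per step, so input_length == 1 and cache_position is the
# single index the new key/value pair is written to (mirroring the slicing above).
for _ in range(2):
    past_length = int(cache_position[-1]) + 1
    cache_position = torch.arange(past_length, past_length + 1)
    print(cache_position)  # tensor([6]), then tensor([7])
```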
......
@@ -181,7 +181,8 @@ class MixtralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+ # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+ # TODO @longjie no longer copied from Mistral after static cache
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -226,7 +227,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
- # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+ # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+ # TODO @longjie no longer copied from Mistral after static cache
class MixtralAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -392,7 +395,8 @@ class MixtralAttention(nn.Module):
return attn_output, attn_weights, past_key_value
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+ # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+ # TODO @longjie no longer copied from Mistral after static cache
class MixtralFlashAttention2(MixtralAttention):
"""
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
@@ -679,7 +683,8 @@ class MixtralFlashAttention2(MixtralAttention):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
+ # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
+ # TODO @longjie no longer copied from Mistral after static cache
class MixtralSdpaAttention(MixtralAttention):
"""
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -958,7 +963,7 @@ MIXTRAL_START_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
class MixtralPreTrainedModel(PreTrainedModel):
config_class = MixtralConfig
base_model_prefix = "model"
@@ -1052,7 +1057,8 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+ # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+ # TODO @longjie no longer copied from Mistral after static cache
class MixtralModel(MixtralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
......
@@ -45,7 +45,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -137,7 +137,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -76,7 +76,7 @@ def _get_unpad_data(attention_mask):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -168,7 +168,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -94,7 +94,7 @@ class Qwen2RMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -139,7 +139,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -611,7 +611,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......
@@ -170,7 +170,7 @@ class Qwen2MoeRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -215,7 +215,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -689,7 +689,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......
@@ -71,7 +71,7 @@ def _get_unpad_data(attention_mask):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
class StableLmRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -163,7 +163,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -74,7 +74,7 @@ def _get_unpad_data(attention_mask):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -119,7 +119,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -590,7 +590,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
"""
Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -703,7 +703,7 @@ class Starcoder2DecoderLayer(nn.Module):
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
- # Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -780,7 +780,7 @@ STARCODER2_START_DOCSTRING = r"""
"The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
STARCODER2_START_DOCSTRING,
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Starcoder2
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Starcoder2
class Starcoder2PreTrainedModel(PreTrainedModel):
config_class = Starcoder2Config
base_model_prefix = "model"
@@ -1057,7 +1057,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
)
- # Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1090,6 +1090,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
......
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """ Testing suite for the PyTorch Mistral model. """
+ """Testing suite for the PyTorch Mistral model."""
import gc
import tempfile
import unittest
import pytest
+ from packaging import version
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
@@ -648,6 +648,74 @@ class MistralIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device)
gc.collect()
@slow
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
}
prompts = ["My favourite condiment is "]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Sliding Window Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Static Cache + compile
forward_function = model.forward
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
# Sliding Window Cache + compile
torch._dynamo.reset()
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_gpu
......
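For reference, the user-facing flow that `test_compile_static_cache` exercises condenses to a few lines. This is a hedged sketch reusing the checkpoint and the compile/generate arguments from the test above, not extra documentation shipped with this diff:

```python
import torch
from transformers import AutoTokenizer, MistralForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)

# Compile the forward pass once; `generate` then drives the new SlidingWindowCache.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)
generated = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```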