"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fbdb978eb5b42686f8ad858c13c8f6f7209b5c84"
Unverified commit 616bb11d, authored by Longjie Zheng, committed by GitHub

Add torch.compile for Mistral (#30642)

* first version

* fix sliding window

* fix style

* add sliding window cache

* fix style

* address comments

* fix test

* fix style

* move sliding window check inside cache init

* revert changes on irrelevant files & add comment on SlidingWindowCache

* address comments & fix style

* fix style

* update causal mask

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] llama

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* revert CI from a10 to t4

* wrap up
parent 92d1d97c
@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.
> [!WARNING]
-> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
+> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
......
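As context for the documentation change above, here is a minimal sketch of the usage pattern it describes, mirroring the integration test added later in this diff; the checkpoint and generation settings are illustrative, not prescribed by the docs.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model that supports the static kv-cache works the same way.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)

# Compile the forward pass; the pre-allocated static cache keeps tensor shapes fixed,
# which is what lets torch.compile avoid recompiling at every decoding step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

generated = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="static"
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```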
@@ -448,3 +448,124 @@ class StaticCache(Cache):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
class SlidingWindowCache(Cache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window_size - 1`.
If true (meaning the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
we perform a cyclic shift based on `indices` to replace the oldest states with the new key/value states passed in.
`to_shift` is only True once we are above `sliding_window_size`. Thus with `sliding_window_size==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these indices, then always write the new states at `cache_position` (clamped to `sliding_window_size - 1`).
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# return the whole key/value states instead of k_out, v_out so that the entire prompt is taken
# into account when building the kv cache, instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
# `cache_position` will be used in other cases
return 0
def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
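To make the cycle-shift indexing described in the `SlidingWindowCache` docstring above concrete, here is a small standalone sketch that replays just the `slicing`/`to_shift`/`indices` arithmetic with a window of 8; the window size is shrunk for readability and this snippet is not part of the cache class itself.

```python
import torch

sliding_window_size = 8
slicing = torch.ones(sliding_window_size, dtype=torch.long).cumsum(0)  # [1, 2, ..., 8]

# Before the window is full: to_shift is False, so indices is the identity permutation.
cache_position = torch.tensor([3]).clamp(0, sliding_window_size - 1)
to_shift = cache_position >= sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size
print(indices)  # tensor([0, 1, 2, 3, 4, 5, 6, 7]) -> no rotation, write in place at position 3

# Once the window is full: to_shift is True, so indices rotates the cache left by one slot,
# dropping the oldest entry; the new key/value states are then written at the last position.
cache_position = torch.tensor([20]).clamp(0, sliding_window_size - 1)  # clamped to 7
to_shift = cache_position >= sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size
print(indices)  # tensor([1, 2, 3, 4, 5, 6, 7, 0]) -> cyclic shift, oldest slot is reused
```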
@@ -24,7 +24,7 @@ import torch
import torch.distributed as dist
from torch import nn
-from ..cache_utils import Cache, DynamicCache, StaticCache
+from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -96,9 +96,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
-NEED_SETUP_CACHE_CLASSES_MAPPING = {
-"static": StaticCache,
-}
+NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
@dataclass
@@ -1326,24 +1324,33 @@ class GenerationMixin:
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs
-def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
+def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
"""
-Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
+Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.
-Returns the resulting static cache object.
+Returns the resulting cache object.
"""
-needs_new_cache = (
-not hasattr(self, "_static_cache")
-or self._static_cache.max_batch_size < max_batch_size
-or self._static_cache.max_cache_len < max_cache_len
-)
-if needs_new_cache:
+cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+need_new_cache = (
+not hasattr(self, "_cache")
+or (not isinstance(self._cache, cache_cls))
+or self._cache.max_batch_size < max_batch_size
+)
+if cache_implementation == "sliding_window":
+need_new_cache = need_new_cache or (
+self._cache.sliding_window_size < self._cache.model_sliding_window_size
+and max_cache_len > self._cache.max_cache_len
+)
+elif cache_implementation == "static":
+need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len
+if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
-self._static_cache = StaticCache(
+self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
@@ -1351,8 +1358,8 @@ class GenerationMixin:
dtype=cache_dtype,
)
else:
-self._static_cache.reset() # reset the cache for a new generation
+self._cache.reset()
-return self._static_cache
+return self._cache
def _prepare_special_tokens(
self,
@@ -1615,14 +1622,14 @@ class GenerationMixin:
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
-if generation_config.cache_implementation == "static":
-if not self._supports_static_cache:
-raise ValueError(
-"This model does not support `cache_implementation='static'`. Please check the following "
-"issue: https://github.com/huggingface/transformers/issues/28981"
-)
-model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
+if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+raise ValueError(
+"This model does not support `cache_implementation='static'`. Please check the following "
+"issue: https://github.com/huggingface/transformers/issues/28981"
+)
+model_kwargs["past_key_values"] = self._get_cache(
+generation_config.cache_implementation, batch_size, generation_config.max_length
+)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
......
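For orientation, a hedged sketch of how a caller reaches the new `SlidingWindowCache` through the `_get_cache` dispatch above; the checkpoint and lengths are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)

# cache_implementation="sliding_window" makes generate() look up SlidingWindowCache in
# NEED_SETUP_CACHE_CLASSES_MAPPING and build it via _get_cache(); the cache is kept on the
# model (self._cache) and reset on later generate() calls unless a larger max_cache_len or
# batch size forces a re-allocation.
out = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="sliding_window"
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```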
@@ -63,7 +63,7 @@ class OpenLlamaRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -81,7 +81,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -123,7 +123,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
......
@@ -253,7 +253,6 @@ class GemmaAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
-**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -265,8 +264,8 @@ class GemmaAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+cos, sin = self.rotary_emb(value_states, position_ids)
+query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
......
@@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask):
class GPTNeoXRotaryEmbedding(nn.Module):
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -614,7 +614,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -230,7 +230,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
......
@@ -478,7 +478,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -303,7 +303,6 @@ class LlamaAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
-**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -402,7 +401,6 @@ class LlamaFlashAttention2(LlamaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
-**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
......
@@ -181,7 +181,8 @@ class MixtralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -226,7 +227,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -268,7 +270,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -392,7 +395,8 @@ class MixtralAttention(nn.Module):
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralFlashAttention2(MixtralAttention):
"""
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
@@ -679,7 +683,8 @@ class MixtralFlashAttention2(MixtralAttention):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralSdpaAttention(MixtralAttention):
"""
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -958,7 +963,7 @@ MIXTRAL_START_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
class MixtralPreTrainedModel(PreTrainedModel):
config_class = MixtralConfig
base_model_prefix = "model"
@@ -1052,7 +1057,8 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralModel(MixtralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
......
@@ -45,7 +45,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -137,7 +137,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -76,7 +76,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -168,7 +168,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -94,7 +94,7 @@ class Qwen2RMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -139,7 +139,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -611,7 +611,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......
@@ -170,7 +170,7 @@ class Qwen2MoeRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -215,7 +215,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -689,7 +689,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2Moe
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
......
@@ -71,7 +71,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
class StableLmRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -163,7 +163,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
......
@@ -74,7 +74,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -119,7 +119,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -590,7 +590,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Starcoder2
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
"""
Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -703,7 +703,7 @@ class Starcoder2DecoderLayer(nn.Module):
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
-# Copied from transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -780,7 +780,7 @@ STARCODER2_START_DOCSTRING = r"""
"The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
STARCODER2_START_DOCSTRING,
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Starcoder2
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Starcoder2
class Starcoder2PreTrainedModel(PreTrainedModel):
config_class = Starcoder2Config
base_model_prefix = "model"
@@ -1057,7 +1057,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
)
-# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM with MISTRAL->STARCODER2,Mistral-7B-v0.1->starcoder2-7b_16k,Mistral->Starcoder2,mistralai->bigcode
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1090,6 +1090,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
......
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Testing suite for the PyTorch Mistral model. """
+"""Testing suite for the PyTorch Mistral model."""
import gc
import tempfile
import unittest
import pytest
+from packaging import version
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
@@ -648,6 +648,74 @@ class MistralIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device)
gc.collect()
@slow
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
}
prompts = ["My favourite condiment is "]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Sliding Window Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Static Cache + compile
forward_function = model.forward
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
# Sliding Window Cache + compile
torch._dynamo.reset()
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_gpu
......