"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f6cb0f806efecb64df40c946dacaad0adad33d53"
Unverified Commit d475f767 authored by Joao Gante, committed by GitHub

SlidingWindowCache: reduce differences to other Cache classes (#30970)

* tmp commit

* sliding window with fewer differences

* make fixup + rebase

* missing overwrite
parent 221aaec6
@@ -829,22 +829,22 @@ class StaticCache(Cache):
             self.value_cache[layer_idx].zero_()


-class SlidingWindowCache(Cache):
+class SlidingWindowCache(StaticCache):
     """
     Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
-    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
+    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
     if true (which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
     we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.

-    The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:
+    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

-    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
+    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
     tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
             19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
             37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
             55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

-    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
+    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)

     Parameters:
         config (`PretrainedConfig):
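The rolling-index arithmetic described in the docstring can be checked in isolation. The snippet below is a standalone sketch (the `cache_position` value is made up for illustration) that reproduces the `indices` tensor shown above for `sliding_window == 64`:

```python
import torch

# Standalone sketch of the docstring's index computation, assuming
# sliding_window == 64 and a decode step that is already past the window.
sliding_window = 64
cache_position = torch.tensor([100])  # hypothetical position beyond the window

slicing = torch.ones(sliding_window, dtype=torch.long).cumsum(0)  # [1, 2, ..., 64]
cache_position = cache_position.clamp(0, sliding_window - 1)      # clamps to 63
to_shift = cache_position >= sliding_window - 1                   # True once past the window
indices = (slicing + to_shift[-1].int() - 1) % sliding_window     # [1, 2, ..., 63, 0]

print(indices)  # matches the tensor in the docstring: a roll by one slot to the left
```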
@@ -866,38 +866,11 @@ class SlidingWindowCache(Cache):
                 "sliding window attention, please check if there is a `sliding_window` field in the model "
                 "config and it's not set to None."
             )
-        super().__init__()
-        self.max_batch_size = max_batch_size
-        # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
-        # when we do short-sentence generation
-        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.model_sliding_window_size = config.sliding_window
-        self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-        self.dtype = dtype if dtype is not None else torch.float32
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )
-        cache_shape = (
-            config.num_hidden_layers,
-            max_batch_size,
-            self.num_key_value_heads,
-            self.sliding_window_size,
-            self.head_dim,
-        )
-        self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        torch._dynamo.mark_static_address(self.key_cache)
-        torch._dynamo.mark_static_address(self.value_cache)
+        max_cache_len = min(config.sliding_window, max_cache_len)
+        super().__init__(
+            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
+        )

     def update(
         self,
         key_states: torch.Tensor,
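With the constructor now delegating to `StaticCache.__init__`, building the cache by hand looks the same as for `StaticCache`, except that `max_cache_len` is capped at `config.sliding_window`. A hedged construction sketch (checkpoint name and sizes are illustrative, not part of the diff):

```python
import torch
from transformers import AutoConfig
from transformers.cache_utils import SlidingWindowCache

# Illustrative checkpoint; any config with a non-None `sliding_window` field works.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

cache = SlidingWindowCache(
    config=config,
    max_batch_size=1,
    max_cache_len=1024,  # internally capped to min(config.sliding_window, 1024)
    device="cpu",
    dtype=torch.float16,
)
```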
@@ -909,20 +882,21 @@ class SlidingWindowCache(Cache):
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]

-        # assume this only happens in prefill phase when prompt length > sliding_window_size
-        if cache_position.shape[0] > self.sliding_window_size:
-            k_out = key_states[:, :, -self.sliding_window_size :, :]
-            v_out = value_states[:, :, -self.sliding_window_size :, :]
-            self.key_cache[layer_idx] = k_out
-            self.value_cache[layer_idx] = v_out
+        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
+        if cache_position.shape[0] > self.max_cache_len:
+            k_out = key_states[:, :, -self.max_cache_len :, :]
+            v_out = value_states[:, :, -self.max_cache_len :, :]
+            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
+            self.key_cache[layer_idx] += k_out
+            self.value_cache[layer_idx] += v_out
             # we should return the whole states instead of k_out, v_out to take the whole prompt
             # into consideration when building kv cache instead of just throwing away tokens outside of the window
             return key_states, value_states

-        slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
-        cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
-        to_shift = cache_position >= self.sliding_window_size - 1
-        indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
+        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
+        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
+        to_shift = cache_position >= self.max_cache_len - 1
+        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
@@ -930,21 +904,16 @@ class SlidingWindowCache(Cache):
         k_out[:, :, cache_position] = key_states
         v_out[:, :, cache_position] = value_states

-        self.key_cache[layer_idx] = k_out
-        self.value_cache[layer_idx] = v_out
+        # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
+        self.key_cache[layer_idx].zero_()
+        self.value_cache[layer_idx].zero_()
+        self.key_cache[layer_idx] += k_out
+        self.value_cache[layer_idx] += v_out

         return k_out, v_out

-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        # assume this will be called only in the first generation step
-        # `cache_position` will be used in other cases
-        return 0
-
     def get_max_length(self) -> Optional[int]:
         # in theory there is no limit because the sliding window size is fixed
         # no matter how long the sentence is
         return None
-
-    def reset(self):
-        self.key_cache.zero_()
-        self.value_cache.zero_()
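Both hunks above replace re-assignment of the cache tensors with in-place writes. A minimal sketch of that pattern, on a made-up buffer, shows why it is compile-friendly: the pre-allocated storage (which `StaticCache` pins with `torch._dynamo.mark_static_address`) is never swapped out, only mutated.

```python
import torch

# Minimal sketch of the in-place write pattern used in `update` above, assuming
# the cache is a Python list of pre-allocated per-layer tensors (as in StaticCache).
# Shapes are made up; the point is that the pre-allocated tensor is never replaced,
# only mutated, so torch.compile can keep treating its address as static.
batch, heads, window, head_dim = 1, 4, 8, 16
key_cache = [torch.zeros(batch, heads, window, head_dim) for _ in range(2)]

layer_idx = 0
k_out = torch.randn(batch, heads, window, head_dim)

# `key_cache[layer_idx] = k_out` would rebind the list slot to a brand-new tensor;
# zeroing and accumulating writes into the existing storage instead.
key_cache[layer_idx].zero_()
key_cache[layer_idx] += k_out
```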
@@ -1368,18 +1368,15 @@ class GenerationMixin:
         Returns the resulting cache object.
         """
         cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        if cache_implementation == "sliding_window":
+            max_cache_len = min(self.config.sliding_window, max_cache_len)
+
         need_new_cache = (
             not hasattr(self, "_cache")
             or (not isinstance(self._cache, cache_cls))
-            or self._cache.max_batch_size < max_batch_size
+            or self._cache.max_batch_size != max_batch_size
+            or self._cache.max_cache_len < max_cache_len
         )
-        if cache_implementation == "sliding_window":
-            need_new_cache = need_new_cache or (
-                self._cache.sliding_window_size < self._cache.model_sliding_window_size
-                and max_cache_len > self._cache.max_cache_len
-            )
-        elif cache_implementation == "static":
-            need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len

         if need_new_cache:
             if hasattr(self.config, "_pre_quantization_dtype"):
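Capping `max_cache_len` up front lets the generic reuse check (`self._cache.max_cache_len < max_cache_len`) cover the sliding-window case too, with no cache-class-specific branches. End to end, this code path is normally reached through `generate`; a hedged usage sketch (checkpoint and generation arguments are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model whose config defines `sliding_window` qualifies.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

inputs = tokenizer("The sliding window cache", return_tensors="pt")

# `cache_implementation="sliding_window"` is looked up in
# NEED_SETUP_CACHE_CLASSES_MAPPING and routed through the cache-setup logic above.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="sliding_window")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```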
@@ -1290,9 +1290,9 @@ class MistralForCausalLM(MistralPreTrainedModel):
             past_length > 0
             and attention_mask is not None
             and isinstance(past_key_values, SlidingWindowCache)
-            and attention_mask.shape[1] > past_key_values.sliding_window_size
+            and attention_mask.shape[1] > past_key_values.max_cache_len
         ):
-            attention_mask = attention_mask[:, -past_key_values.sliding_window_size :]
+            attention_mask = attention_mask[:, -past_key_values.max_cache_len :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
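The mask cropping above keeps `attention_mask` aligned with the at most `max_cache_len` slots that actually survive in the cache once the prompt is longer than the window. A tiny sketch with made-up sizes:

```python
import torch

# Made-up sizes: a 7-token prompt with a 4-slot sliding-window cache.
max_cache_len = 4
attention_mask = torch.ones(1, 7, dtype=torch.long)

if attention_mask.shape[1] > max_cache_len:
    attention_mask = attention_mask[:, -max_cache_len:]

print(attention_mask.shape)  # torch.Size([1, 4]) -- only the last window's worth of mask
```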