"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c4bc66886d2e910cba2230e931c6b3ab1094dc47"
Unverified Commit d475f767 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

SlidingWindowCache: reduce differences to other Cache classes (#30970)

* tmp commit

* sliding window with fewer differences

* make fixup + rebase

* missing overwrite
parent 221aaec6
......@@ -829,22 +829,22 @@ class StaticCache(Cache):
self.value_cache[layer_idx].zero_()
class SlidingWindowCache(Cache):
class SlidingWindowCache(StaticCache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
if true (which means the cache cannot hold all the old key value states and new states together because of the sliding window constraint),
we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)
Parameters:
config (`PretrainedConfig`):
......@@ -866,38 +866,11 @@ class SlidingWindowCache(Cache):
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__(
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
def update(
self,
key_states: torch.Tensor,
......@@ -909,20 +882,21 @@ class SlidingWindowCache(Cache):
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
to_shift = cache_position >= self.max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
......@@ -930,21 +904,16 @@ class SlidingWindowCache(Cache):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
return k_out, v_out
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
# `cache_position` will be used in other cases
return 0
return k_out, v_out
def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
......@@ -1368,18 +1368,15 @@ class GenerationMixin:
Returns the resulting cache object.
"""
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
if cache_implementation == "sliding_window":
max_cache_len = min(self.config.sliding_window, max_cache_len)
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
or self._cache.max_batch_size != max_batch_size
or self._cache.max_cache_len < max_cache_len
)
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len
if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
......
......@@ -1290,9 +1290,9 @@ class MistralForCausalLM(MistralPreTrainedModel):
past_length > 0
and attention_mask is not None
and isinstance(past_key_values, SlidingWindowCache)
and attention_mask.shape[1] > past_key_values.sliding_window_size
and attention_mask.shape[1] > past_key_values.max_cache_len
):
attention_mask = attention_mask[:, -past_key_values.sliding_window_size :]
attention_mask = attention_mask[:, -past_key_values.max_cache_len :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment