Unverified Commit 388fd314 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037)

parent 0ede7626
......@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
......@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
......@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
......
......@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
......@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
......@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
f" {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment