Unverified Commit 388fd314 authored by Joao Gante, committed by GitHub

Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037)

parent 0ede7626
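The diff below makes the same two changes in the Mistral and Mixtral FlashAttention-2 classes: raise an informative error when the attention layer was built without a `layer_idx` (required by the cache API introduced in v4.36), and only slice the sliding-window cache when this layer's cache actually has contents, indexing it per layer instead of treating it as a flat `(key, value)` tuple. As a minimal sketch of the slicing arithmetic (toy tensors, not the real transformers cache): `slicing_tokens = 1 - sliding_window` keeps the last `sliding_window - 1` cached positions, leaving exactly one slot for the token currently being decoded.

```python
import torch

sliding_window = 4                     # stand-in for config.sliding_window
past_key = torch.randn(1, 2, 7, 8)     # (batch, heads, cached_tokens, head_dim)

slicing_tokens = 1 - sliding_window    # -3: keep the last sliding_window - 1 tokens
past_key = past_key[:, :, slicing_tokens:, :].contiguous()

# 3 kept slots + 1 token being decoded = exactly one full window of 4
assert past_key.shape[-2] == sliding_window - 1
```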
src/transformers/models/mistral/modeling_mistral.py

```diff
@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
```
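To see why the new `cache_has_contents` guard matters, consider the first forward pass with a prompt longer than the window: `kv_seq_len` already exceeds `sliding_window`, but this layer's cache is still empty, so the old code would try to slice a cache entry that does not exist yet. A hedged sketch of the guard, using a hypothetical `DummyCache` rather than the real transformers `Cache` class:

```python
class DummyCache:
    """Hypothetical per-layer k/v store, loosely mimicking the v4.36 Cache API."""
    def __init__(self, num_layers):
        self.layers = [None] * num_layers

    def get_seq_length(self, layer_idx):
        entry = self.layers[layer_idx]
        return 0 if entry is None else entry[0].shape[-2]

    def __getitem__(self, layer_idx):
        return self.layers[layer_idx]

sliding_window, layer_idx = 4, 0
cache = DummyCache(num_layers=1)
kv_seq_len = 9  # prefill of a 9-token prompt: already past the window

cache_has_contents = cache.get_seq_length(layer_idx) > 0
if kv_seq_len > sliding_window and cache_has_contents:
    # Only reached once the cache is non-empty, so indexing it is safe.
    past_key = cache[layer_idx][0][:, :, 1 - sliding_window:, :].contiguous()
else:
    print("prefill with an empty cache: slicing is skipped")  # the fixed path
```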
src/transformers/models/mixtral/modeling_mixtral.py

```diff
@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
```
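An illustrative way to exercise the fixed path is to decode past the model's sliding window with FA2 enabled. This is a sketch, not a test from the PR: it assumes a CUDA machine with flash-attn installed, and the checkpoint and token counts are examples only.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Once upon a time", return_tensors="pt").to(model.device)
# Decode far enough that kv_seq_len crosses config.sliding_window.
out = model.generate(**inputs, max_new_tokens=model.config.sliding_window + 64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```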