Unverified Commit 4b3eb19f authored by Edoardo Cetin, committed by GitHub

Fix llama model sdpa attention forward function masking bug when output_attentions=True (#30652)



* Fix llama model forward function with output_attentions=True and a same-length encoded sequence.

* Fix style

* propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama)

* Fix style

* ignore unnecessary sdpa mask converter when output_attentions=True

* add tests checking sdpa and eager outputs match when output_attentions=True

* Split if statements in two lines
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix formatting

* Add fix to new jetmoe model

* Add missing output_attentions argument to jetmoe mask creation

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2d83324e
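
The bug only manifests when a model is loaded with `attn_implementation="sdpa"` and its forward pass is called with `output_attentions=True`: SDPA cannot return attention weights, so it falls back to the eager attention code path, which needs the full causal mask that the SDPA-specific mask handling would otherwise skip or convert. A minimal way to check the fix, sketched under the assumption of a Llama-architecture checkpoint (the checkpoint name below is purely illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any Llama-family model reproduces the scenario.
checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# A single unpadded sequence: before the fix, the SDPA path dropped the causal
# mask here (it is unnecessary for torch's SDPA kernel), but the eager fallback
# used for output_attentions=True then ran with no mask at all.
inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")

model_sdpa = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="sdpa").eval()
model_eager = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="eager").eval()

with torch.no_grad():
    out_sdpa = model_sdpa(**inputs, output_attentions=True)
    out_eager = model_eager(**inputs, output_attentions=True)

# After the fix, logits and per-layer attention weights agree up to numerical noise.
print(torch.allclose(out_sdpa.logits, out_eager.logits, atol=1e-4))
print(all(torch.allclose(a, b, atol=1e-4) for a, b in zip(out_sdpa.attentions, out_eager.attentions)))
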
src/transformers/models/cohere/modeling_cohere.py
@@ -889,7 +889,9 @@ class CohereModel(CoherePreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -958,6 +960,7 @@ class CohereModel(CoherePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -974,7 +977,9 @@ class CohereModel(CoherePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1020,6 +1025,7 @@ class CohereModel(CoherePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
src/transformers/models/dbrx/modeling_dbrx.py
@@ -1123,7 +1123,10 @@ class DbrxModel(DbrxPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -1204,6 +1207,7 @@ class DbrxModel(DbrxPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1220,7 +1224,9 @@ class DbrxModel(DbrxPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1266,6 +1272,7 @@ class DbrxModel(DbrxPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
src/transformers/models/gemma/modeling_gemma.py
@@ -873,7 +873,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -948,6 +950,7 @@ class GemmaModel(GemmaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -964,7 +967,9 @@ class GemmaModel(GemmaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1008,6 +1013,7 @@ class GemmaModel(GemmaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1103,7 +1103,9 @@ class JetMoeModel(JetMoePreTrainedModel):
                 " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         hidden_states = inputs_embeds
@@ -1178,6 +1180,7 @@ class JetMoeModel(JetMoePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1194,7 +1197,9 @@ class JetMoeModel(JetMoePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1240,6 +1245,7 @@ class JetMoeModel(JetMoePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
src/transformers/models/llama/modeling_llama.py
@@ -967,7 +967,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -1036,6 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1052,7 +1055,9 @@ class LlamaModel(LlamaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1098,6 +1103,7 @@ class LlamaModel(LlamaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
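
For context on the comment added above ("When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward"): `torch.nn.functional.scaled_dot_product_attention` cannot return attention weights, so the SDPA attention classes delegate to the eager implementation in that case, and the eager path requires the full additive causal mask that `_update_causal_mask` would otherwise skip or post-process for SDPA. A simplified sketch of that fallback pattern (paraphrased, not the verbatim transformers source):

class LlamaSdpaAttention(LlamaAttention):
    def forward(self, hidden_states, attention_mask=None, output_attentions=False, **kwargs):
        if output_attentions:
            # scaled_dot_product_attention has no way to expose the attention
            # weights, so reuse the eager implementation, which computes and
            # returns them and applies `attention_mask` as an additive mask.
            return super().forward(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=True,
                **kwargs,
            )
        # ... normal path using torch.nn.functional.scaled_dot_product_attention ...

This is why every SDPA-only shortcut in `_update_causal_mask` now also checks `not output_attentions`.
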
src/transformers/models/olmo/modeling_olmo.py
@@ -945,7 +945,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -1015,6 +1017,7 @@ class OlmoModel(OlmoPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1031,7 +1034,9 @@ class OlmoModel(OlmoPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1077,6 +1082,7 @@ class OlmoModel(OlmoPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
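
The tests referenced in the commit message (checking that SDPA and eager outputs match when `output_attentions=True`) are not shown above. A rough sketch of such a check, assuming a tiny randomly initialized Llama config rather than the actual test added to the transformers suite, and a transformers version whose `from_config` accepts `attn_implementation`:

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

# Tiny illustrative config so the comparison runs quickly on CPU.
config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

model_eager = AutoModelForCausalLM.from_config(config, attn_implementation="eager").eval()
model_sdpa = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa").eval()
model_sdpa.load_state_dict(model_eager.state_dict())  # identical random weights

# Same-length batch with no padding: the case where the SDPA mask shortcut used
# to leave the eager fallback without any causal mask.
input_ids = torch.randint(0, config.vocab_size, (2, 16))

with torch.no_grad():
    out_eager = model_eager(input_ids, output_attentions=True)
    out_sdpa = model_sdpa(input_ids, output_attentions=True)

assert torch.allclose(out_eager.logits, out_sdpa.logits, atol=1e-5)
for eager_attn, sdpa_attn in zip(out_eager.attentions, out_sdpa.attentions):
    assert torch.allclose(eager_attn, sdpa_attn, atol=1e-5)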