Unverified Commit 4b3eb19f authored by Edoardo Cetin's avatar Edoardo Cetin Committed by GitHub
Browse files

Fix llama model sdpa attention forward function masking bug when output_attentions=True (#30652)



* Fix llama model forward function with attention=True, same-length encoded sequence.

* Fix style

* propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama)

* Fix style

* ignore unnecessary sdpa mask converter when output_attentions=True

* add tests checking sdpa and eager outputs match when output_attentions=True

* Split if statements in two lines
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix formatting

* Add fix to new jetmoe model

* Add missing output_attentions argument to jetmoe mask creation

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2d83324e
...@@ -889,7 +889,9 @@ class CohereModel(CoherePreTrainedModel): ...@@ -889,7 +889,9 @@ class CohereModel(CoherePreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -958,6 +960,7 @@ class CohereModel(CoherePreTrainedModel): ...@@ -958,6 +960,7 @@ class CohereModel(CoherePreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -974,7 +977,9 @@ class CohereModel(CoherePreTrainedModel): ...@@ -974,7 +977,9 @@ class CohereModel(CoherePreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1020,6 +1025,7 @@ class CohereModel(CoherePreTrainedModel): ...@@ -1020,6 +1025,7 @@ class CohereModel(CoherePreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -1123,7 +1123,10 @@ class DbrxModel(DbrxPreTrainedModel): ...@@ -1123,7 +1123,10 @@ class DbrxModel(DbrxPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1204,6 +1207,7 @@ class DbrxModel(DbrxPreTrainedModel): ...@@ -1204,6 +1207,7 @@ class DbrxModel(DbrxPreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1220,7 +1224,9 @@ class DbrxModel(DbrxPreTrainedModel): ...@@ -1220,7 +1224,9 @@ class DbrxModel(DbrxPreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1266,6 +1272,7 @@ class DbrxModel(DbrxPreTrainedModel): ...@@ -1266,6 +1272,7 @@ class DbrxModel(DbrxPreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -873,7 +873,9 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -873,7 +873,9 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -948,6 +950,7 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -948,6 +950,7 @@ class GemmaModel(GemmaPreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -964,7 +967,9 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -964,7 +967,9 @@ class GemmaModel(GemmaPreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1008,6 +1013,7 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -1008,6 +1013,7 @@ class GemmaModel(GemmaPreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -1103,7 +1103,9 @@ class JetMoeModel(JetMoePreTrainedModel): ...@@ -1103,7 +1103,9 @@ class JetMoeModel(JetMoePreTrainedModel):
" this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to " " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. " " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
) )
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1178,6 +1180,7 @@ class JetMoeModel(JetMoePreTrainedModel): ...@@ -1178,6 +1180,7 @@ class JetMoeModel(JetMoePreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1194,7 +1197,9 @@ class JetMoeModel(JetMoePreTrainedModel): ...@@ -1194,7 +1197,9 @@ class JetMoeModel(JetMoePreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1240,6 +1245,7 @@ class JetMoeModel(JetMoePreTrainedModel): ...@@ -1240,6 +1245,7 @@ class JetMoeModel(JetMoePreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -967,7 +967,9 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -967,7 +967,9 @@ class LlamaModel(LlamaPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1036,6 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1036,6 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1052,7 +1055,9 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1052,7 +1055,9 @@ class LlamaModel(LlamaPreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1098,6 +1103,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1098,6 +1103,7 @@ class LlamaModel(LlamaPreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -945,7 +945,9 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -945,7 +945,9 @@ class OlmoModel(OlmoPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1015,6 +1017,7 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1015,6 +1017,7 @@ class OlmoModel(OlmoPreTrainedModel):
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool,
): ):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
...@@ -1031,7 +1034,9 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1031,7 +1034,9 @@ class OlmoModel(OlmoPreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa( if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, attention_mask,
inputs_embeds=input_tensor, inputs_embeds=input_tensor,
...@@ -1077,6 +1082,7 @@ class OlmoModel(OlmoPreTrainedModel): ...@@ -1077,6 +1082,7 @@ class OlmoModel(OlmoPreTrainedModel):
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type == "cuda" and attention_mask.device.type == "cuda"
and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
......
...@@ -3757,11 +3757,15 @@ class ModelTesterMixin: ...@@ -3757,11 +3757,15 @@ class ModelTesterMixin:
if not has_sdpa and model_sdpa.config.model_type != "falcon": if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers") raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand # but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = [] fail_cases = []
for padding_side in ["left", "right"]: for padding_side in ["left", "right"]:
for use_mask in [False, True]: for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
for batch_size in [1, 5]: for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name] dummy_input = inputs_dict[model.main_input_name]
...@@ -3822,7 +3826,9 @@ class ModelTesterMixin: ...@@ -3822,7 +3826,9 @@ class ModelTesterMixin:
for enable_kernels in [False, True]: for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder: if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size] decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size
]
if decoder_input_ids.shape[0] != batch_size: if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones( extension = torch.ones(
batch_size - decoder_input_ids.shape[0], batch_size - decoder_input_ids.shape[0],
...@@ -3850,6 +3856,12 @@ class ModelTesterMixin: ...@@ -3850,6 +3856,12 @@ class ModelTesterMixin:
if "attention_mask" in inspect.signature(model_eager.forward).parameters: if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask processed_inputs["attention_mask"] = dummy_attention_mask
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
# TODO: test gradients as well (& for FA2 as well!) # TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad(): with torch.no_grad():
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment