"tests/models/mpnet/test_tokenization_mpnet.py" did not exist on "df2af6d8b8765b1ac2cda12d2ece09bf7240fba8"
Unverified commit 4b3eb19f, authored by Edoardo Cetin, committed by GitHub

Fix llama model sdpa attention forward function masking bug when output_attentions=True (#30652)



* Fix llama model forward function with output_attentions=True and same-length encoded sequences.

* Fix style

* Propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same SDPA masking logic from llama)

* Fix style

* Ignore unnecessary SDPA mask converter shortcut when output_attentions=True

* Add tests checking that SDPA and eager outputs match when output_attentions=True

* Split if statements into two lines
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix formatting

* Add fix to new jetmoe model

* Add missing output_attentions argument to jetmoe mask creation

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2d83324e
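
Background for the change: torch.nn.functional.scaled_dot_product_attention cannot return attention weights, so when output_attentions=True the SDPA attention modules in these models fall back to the eager attention implementation. Before this fix, _update_causal_mask still took the SDPA-only shortcut of skipping mask creation for same-length (unpadded) batches, so the eager fallback ran without a causal mask and produced wrong attentions and hidden states. The snippet below is a minimal way to sanity-check the behaviour; it is not part of the commit, and the checkpoint name is only a placeholder.

# Minimal check (not part of this commit): SDPA and eager attention should agree
# when output_attentions=True. The checkpoint is a placeholder assumption; any
# Llama-style causal LM works. The batch uses identical, same-length sequences,
# so no padding is involved -- exactly the case the bug affected.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["The quick brown fox jumps over the lazy dog"] * 2, return_tensors="pt")

model_eager = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="eager")
model_sdpa = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="sdpa")

with torch.no_grad():
    out_eager = model_eager(**inputs, output_attentions=True)
    out_sdpa = model_sdpa(**inputs, output_attentions=True)

# Before the fix, the SDPA model skipped building the causal mask here and its eager
# fallback attended to future tokens, so these comparisons could fail.
torch.testing.assert_close(out_eager.logits, out_sdpa.logits, atol=1e-4, rtol=1e-4)
for att_eager, att_sdpa in zip(out_eager.attentions, out_sdpa.attentions):
    torch.testing.assert_close(att_eager, att_sdpa, atol=1e-4, rtol=1e-4)
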
src/transformers/models/cohere/modeling_cohere.py
@@ -889,7 +889,9 @@ class CohereModel(CoherePreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -958,6 +960,7 @@ class CohereModel(CoherePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -974,7 +977,9 @@ class CohereModel(CoherePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1020,6 +1025,7 @@ class CohereModel(CoherePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

src/transformers/models/dbrx/modeling_dbrx.py
@@ -1123,7 +1123,10 @@ class DbrxModel(DbrxPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
         # embed positions
         hidden_states = inputs_embeds
@@ -1204,6 +1207,7 @@ class DbrxModel(DbrxPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1220,7 +1224,9 @@ class DbrxModel(DbrxPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1266,6 +1272,7 @@ class DbrxModel(DbrxPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

src/transformers/models/gemma/modeling_gemma.py
@@ -873,7 +873,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -948,6 +950,7 @@ class GemmaModel(GemmaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -964,7 +967,9 @@ class GemmaModel(GemmaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1008,6 +1013,7 @@ class GemmaModel(GemmaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1103,7 +1103,9 @@ class JetMoeModel(JetMoePreTrainedModel):
                 " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         hidden_states = inputs_embeds
@@ -1178,6 +1180,7 @@ class JetMoeModel(JetMoePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1194,7 +1197,9 @@ class JetMoeModel(JetMoePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1240,6 +1245,7 @@ class JetMoeModel(JetMoePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

src/transformers/models/llama/modeling_llama.py
@@ -967,7 +967,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -1036,6 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1052,7 +1055,9 @@ class LlamaModel(LlamaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1098,6 +1103,7 @@ class LlamaModel(LlamaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

src/transformers/models/olmo/modeling_olmo.py
@@ -945,7 +945,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -1015,6 +1017,7 @@ class OlmoModel(OlmoPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1031,7 +1034,9 @@ class OlmoModel(OlmoPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1077,6 +1082,7 @@ class OlmoModel(OlmoPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.

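The common thread in all of the modeling changes above: F.scaled_dot_product_attention can enforce causality internally (is_causal=True) even when no mask tensor is built, but the eager path computes softmax(QK^T / sqrt(d) + mask) explicitly and is only causal if it receives a real additive mask. Since these SDPA modules fall back to the eager path when output_attentions=True, the mask shortcuts must be skipped in that case. A self-contained illustration with plain tensors (single head, random values; this snippet is not taken from the repository):

# Why the `not output_attentions` guards are needed: SDPA can be causal with no
# mask tensor, but the eager softmax(QK^T/sqrt(d) + mask) path is only causal if
# an explicit additive mask is supplied.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 4, 8)  # (batch, seq_len, head_dim), a single attention head

# SDPA path: causality handled internally, attn_mask can be None.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Eager path without a mask: every position attends to the whole sequence -- not causal.
scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
eager_no_mask = F.softmax(scores, dim=-1) @ v

# Eager path with an explicit additive causal mask: matches SDPA again.
causal_mask = torch.full((4, 4), float("-inf")).triu(diagonal=1)
eager_masked = F.softmax(scores + causal_mask, dim=-1) @ v

print(torch.allclose(sdpa_out, eager_no_mask, atol=1e-5))  # False
print(torch.allclose(sdpa_out, eager_masked, atol=1e-5))   # True

The second guard (in the CUDA unmasking block) serves the same purpose: the fully-masked-row adjustment is only needed for SDPA's memory-efficient kernel and should not be applied when the eager fallback will run. The test change below adds output_attentions to the eager-vs-SDPA equivalence sweep so this case stays covered.
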
tests/test_modeling_common.py
@@ -3757,11 +3757,15 @@ class ModelTesterMixin:
         if not has_sdpa and model_sdpa.config.model_type != "falcon":
             raise ValueError("The SDPA model should have SDPA attention layers")

-        # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
+        # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
         # but it would be nicer to have an efficient way to use parameterized.expand
         fail_cases = []
         for padding_side in ["left", "right"]:
             for use_mask in [False, True]:
+                for output_attentions in [True, False]:
+                    can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                    if not (self.has_attentions and can_output_attn) and output_attentions:
+                        continue
                     for batch_size in [1, 5]:
                         dummy_input = inputs_dict[model.main_input_name]

@@ -3822,7 +3826,9 @@ class ModelTesterMixin:
                         for enable_kernels in [False, True]:
                             failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
                             if is_encoder_decoder:
-                                decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
+                                decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
+                                    :batch_size
+                                ]
                                 if decoder_input_ids.shape[0] != batch_size:
                                     extension = torch.ones(
                                         batch_size - decoder_input_ids.shape[0],
@@ -3850,6 +3856,12 @@ class ModelTesterMixin:
                             if "attention_mask" in inspect.signature(model_eager.forward).parameters:
                                 processed_inputs["attention_mask"] = dummy_attention_mask

+                            if (
+                                self.has_attentions
+                                and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                            ):
+                                processed_inputs["output_attentions"] = output_attentions
+
                             # TODO: test gradients as well (& for FA2 as well!)
                             with torch.no_grad():
                                 with torch.backends.cuda.sdp_kernel(