Unverified commit 4b3eb19f, authored by Edoardo Cetin and committed by GitHub

Fix Llama model SDPA attention forward function masking bug when output_attentions=True (#30652)



* Fix the Llama model forward pass with output_attentions=True and a same-length encoded sequence.

* Fix style

* Propagate the fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same SDPA masking logic from llama)

* Fix style

* Skip the unnecessary SDPA mask-converter shortcut when output_attentions=True

* Add tests checking that SDPA and eager outputs match when output_attentions=True

* Split if statements over two lines
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix formatting

* Apply the fix to the new JetMoe model

* Add the missing output_attentions argument to the JetMoe mask creation

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2d83324e
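Why the mask cannot simply be dropped: as the comment added in the hunks below notes, with attn_implementation="sdpa" a request for output_attentions=True makes the attention layers fall back to the eager forward (torch.nn.functional.scaled_dot_product_attention cannot return attention weights), and that eager path needs the explicit causal mask that the SDPA-only shortcuts skip. A standalone illustration in plain PyTorch (a sketch, not Transformers code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)

# SDPA path: no explicit mask needed, the kernel is simply told is_causal=True.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Eager path: softmax(QK^T / sqrt(d)) @ V, which only stays causal if an explicit
# mask is applied. Skipping the mask "because SDPA does not need it" silently lets
# every position attend to future tokens once the eager fallback is taken.
scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))
eager_masked = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v
eager_unmasked = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(sdpa_out, eager_masked, atol=1e-5))    # True
print(torch.allclose(sdpa_out, eager_unmasked, atol=1e-5))  # False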
@@ -889,7 +889,9 @@ class CohereModel(CoherePreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -958,6 +960,7 @@ class CohereModel(CoherePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -974,7 +977,9 @@ class CohereModel(CoherePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1020,6 +1025,7 @@ class CohereModel(CoherePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -1123,7 +1123,10 @@ class DbrxModel(DbrxPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1204,6 +1207,7 @@ class DbrxModel(DbrxPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1220,7 +1224,9 @@ class DbrxModel(DbrxPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1266,6 +1272,7 @@ class DbrxModel(DbrxPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -873,7 +873,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -948,6 +950,7 @@ class GemmaModel(GemmaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -964,7 +967,9 @@ class GemmaModel(GemmaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1008,6 +1013,7 @@ class GemmaModel(GemmaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -1103,7 +1103,9 @@ class JetMoeModel(JetMoePreTrainedModel):
                 " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         hidden_states = inputs_embeds
@@ -1178,6 +1180,7 @@ class JetMoeModel(JetMoePreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1194,7 +1197,9 @@ class JetMoeModel(JetMoePreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1240,6 +1245,7 @@ class JetMoeModel(JetMoePreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -967,7 +967,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1036,6 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
    ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1052,7 +1055,9 @@ class LlamaModel(LlamaPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1098,6 +1103,7 @@ class LlamaModel(LlamaPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
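A quick way to check the effect of this series of changes end to end is to compare the eager and SDPA attention implementations with output_attentions=True on one of the patched architectures. A sketch, assuming the transformers AutoModelForCausalLM API; the tiny checkpoint name and tolerance below are placeholders, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder tiny Llama checkpoint
input_ids = torch.arange(8)[None]  # any small batch of token ids works for this smoke test

eager = AutoModelForCausalLM.from_pretrained(name, attn_implementation="eager").eval()
sdpa = AutoModelForCausalLM.from_pretrained(name, attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = eager(input_ids, output_attentions=True)
    out_sdpa = sdpa(input_ids, output_attentions=True)

# Before this fix, the SDPA model could drop the causal mask before falling back to
# the eager attention forward, so logits and attentions could silently disagree here.
print(torch.allclose(out_eager.logits, out_sdpa.logits, atol=1e-5))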
@@ -945,7 +945,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1015,6 +1017,7 @@ class OlmoModel(OlmoPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
+        output_attentions: bool,
     ):
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
@@ -1031,7 +1034,9 @@ class OlmoModel(OlmoPreTrainedModel):
         # to infer the attention mask.
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache:
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -1077,6 +1082,7 @@ class OlmoModel(OlmoPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -3757,177 +3757,189 @@ class ModelTesterMixin:
                 if not has_sdpa and model_sdpa.config.model_type != "falcon":
                     raise ValueError("The SDPA model should have SDPA attention layers")
 
-                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
+                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
                 # but it would be nicer to have an efficient way to use parameterized.expand
                 fail_cases = []
                 for padding_side in ["left", "right"]:
                     for use_mask in [False, True]:
-                        for batch_size in [1, 5]:
+                        for output_attentions in [True, False]:
+                            can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                            if not (self.has_attentions and can_output_attn) and output_attentions:
+                                continue
+                            for batch_size in [1, 5]:
                                 dummy_input = inputs_dict[model.main_input_name]
 
                                 if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
                                     dummy_input = dummy_input.to(torch_dtype)
 
                                 dummy_input = dummy_input[:batch_size]
                                 if dummy_input.shape[0] != batch_size:
                                     if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
                                         extension = torch.rand(
                                             batch_size - dummy_input.shape[0],
                                             *dummy_input.shape[1:],
                                             dtype=torch_dtype,
                                             device=torch_device,
                                         )
                                         dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
                                     else:
                                         extension = torch.randint(
                                             high=5,
                                             size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
                                             dtype=dummy_input.dtype,
                                             device=torch_device,
                                         )
                                         dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
 
                                 if not use_mask:
                                     dummy_attention_mask = None
                                 else:
                                     dummy_attention_mask = inputs_dict.get("attention_mask", None)
                                     if dummy_attention_mask is None:
                                         if is_encoder_decoder:
                                             seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
                                         else:
                                             seqlen = dummy_input.shape[-1]
                                         dummy_attention_mask = (
                                             torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
                                         )
 
                                     dummy_attention_mask = dummy_attention_mask[:batch_size]
                                     if dummy_attention_mask.shape[0] != batch_size:
                                         extension = torch.ones(
                                             batch_size - dummy_attention_mask.shape[0],
                                             *dummy_attention_mask.shape[1:],
                                             dtype=dummy_attention_mask.dtype,
                                             device=torch_device,
                                         )
                                         dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
                                         dummy_attention_mask = dummy_attention_mask.to(torch_device)
 
                                     dummy_attention_mask[:] = 1
                                     if padding_side == "left":
                                         dummy_attention_mask[-1, :-1] = 1
                                         dummy_attention_mask[-1, -4:] = 0
                                     elif padding_side == "right":
                                         dummy_attention_mask[-1, 1:] = 1
                                         dummy_attention_mask[-1, :3] = 0
 
                                 for enable_kernels in [False, True]:
                                     failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
                                     if is_encoder_decoder:
                                         decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
                                             :batch_size
                                         ]
                                         if decoder_input_ids.shape[0] != batch_size:
                                             extension = torch.ones(
                                                 batch_size - decoder_input_ids.shape[0],
                                                 *decoder_input_ids.shape[1:],
                                                 dtype=decoder_input_ids.dtype,
                                                 device=torch_device,
                                             )
                                             decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
                                             decoder_input_ids = decoder_input_ids.to(torch_device)
 
                                         # TODO: never an `attention_mask` arg here?
                                         processed_inputs = {
                                             model.main_input_name: dummy_input,
                                             "decoder_input_ids": decoder_input_ids,
                                             "decoder_attention_mask": dummy_attention_mask,
                                             "output_hidden_states": True,
                                         }
                                     else:
                                         processed_inputs = {
                                             model.main_input_name: dummy_input,
                                             "output_hidden_states": True,
                                         }
 
                                         # Otherwise fails for e.g. WhisperEncoderModel
                                         if "attention_mask" in inspect.signature(model_eager.forward).parameters:
                                             processed_inputs["attention_mask"] = dummy_attention_mask
 
+                                        if (
+                                            self.has_attentions
+                                            and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                                        ):
+                                            processed_inputs["output_attentions"] = output_attentions
+
                                     # TODO: test gradients as well (& for FA2 as well!)
                                     with torch.no_grad():
                                         with torch.backends.cuda.sdp_kernel(
                                             enable_flash=enable_kernels,
                                             enable_math=True,
                                             enable_mem_efficient=enable_kernels,
                                         ):
                                             prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
                                             outputs_eager = model_eager(**prepared_inputs)
                                             outputs_sdpa = model_sdpa(**prepared_inputs)
 
                                     logits_eager = (
                                         outputs_eager.hidden_states[-1]
                                         if not is_encoder_decoder
                                         else outputs_eager.decoder_hidden_states[-1]
                                     )
                                     logits_sdpa = (
                                         outputs_sdpa.hidden_states[-1]
                                         if not is_encoder_decoder
                                         else outputs_sdpa.decoder_hidden_states[-1]
                                     )
 
                                     if torch_device in ["cpu", "cuda"]:
                                         atol = atols[torch_device, enable_kernels, torch_dtype]
                                         rtol = rtols[torch_device, enable_kernels, torch_dtype]
                                     else:
                                         atol = 1e-7
                                         rtol = 1e-4
 
                                     # Masked tokens output slightly deviates - we don't mind that.
                                     if use_mask:
                                         if padding_side == "left":
                                             sub_sdpa = logits_sdpa[:-1]
                                             sub_eager = logits_eager[:-1]
                                             if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                 fail_cases.append(
                                                     get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                 )
 
                                             sub_sdpa = logits_sdpa[-1, :-4]
                                             sub_eager = logits_eager[-1, :-4]
                                             if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                 fail_cases.append(
                                                     get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                 )
 
                                             # Testing the padding tokens is not really meaningful but anyway
                                             # sub_sdpa = logits_sdpa[-1, -4:]
                                             # sub_eager = logits_eager[-1, -4:]
                                             # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                             #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
                                         elif padding_side == "right":
                                             sub_sdpa = logits_sdpa[:-1]
                                             sub_eager = logits_eager[:-1]
                                             if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                 fail_cases.append(
                                                     get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                 )
 
                                             sub_sdpa = logits_sdpa[-1, 3:]
                                             sub_eager = logits_eager[-1, 3:]
                                             if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                 fail_cases.append(
                                                     get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                 )
 
                                             # Testing the padding tokens is not really meaningful but anyway
                                             # sub_sdpa = logits_sdpa[-1, :3]
                                             # sub_eager = logits_eager[-1, :3]
                                             # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                             #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
                                     else:
                                         if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
                                             fail_cases.append(
                                                 get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
                                             )
 
                 self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
 
     @require_torch_sdpa
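The updated common test can be exercised locally against one of the patched models with something along these lines, assuming the usual transformers test layout (the test name is inferred from the ModelTesterMixin hunk above, not stated on this page):

pytest tests/models/llama/test_modeling_llama.py -k "eager_matches_sdpa_inference"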