Unverified Commit 9df8b301 authored by Benjamin Warner, committed by GitHub

Reenable SDPA's FA2 During Training with torch.compile (#30442)

* Reenable SDPA's FA2 during training with torch.compile

* fix Olmo's SDPA FA2 dispatching too

* update formatting

* improved SDPA comment

* formatting and explanatory comment

* is_causal if statement to one-liner
parent 87be06ca
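Background for the diff below: PyTorch's SDPA can only select its Flash Attention 2 backend when no explicit attention mask is passed, so causal attention has to be expressed through the `is_causal` flag rather than a materialized mask. A minimal sketch of that equivalence, illustrative only and not part of this commit (shapes are arbitrary, and the FA2 backend itself is only picked on supported GPUs):

import torch
import torch.nn.functional as F

# Illustrative only: with attn_mask=None, SDPA is free to pick its Flash Attention 2
# backend on supported hardware; passing an explicit mask rules that backend out.
query = torch.randn(1, 8, 128, 64)
key = torch.randn(1, 8, 128, 64)
value = torch.randn(1, 8, 128, 64)

# Causal attention without a materialized mask: FA2-eligible.
out_causal = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=True)

# The same computation with an explicit additive mask: correct, but not FA2-eligible.
bias = torch.zeros(128, 128).masked_fill(torch.triu(torch.ones(128, 128, dtype=torch.bool), 1), float("-inf"))
out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=bias, is_causal=False)

assert torch.allclose(out_causal, out_masked, atol=1e-5)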
@@ -240,6 +240,7 @@ class AttentionMaskConverter:
         inputs_embeds: torch.Tensor,
         past_key_values_length: int,
         sliding_window: Optional[int] = None,
+        is_training: bool = False,
     ) -> bool:
         """
         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
@@ -263,11 +264,11 @@ class AttentionMaskConverter:
         if attention_mask is None:
             # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
             # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
-            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
+            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
             #
             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
             if (
-                not is_tracing
+                (is_training or not is_tracing)
                 and (query_length == 1 or key_value_length == query_length)
                 and (sliding_window is None or key_value_length < sliding_window)
             ):
@@ -279,7 +280,7 @@ class AttentionMaskConverter:
                     raise ValueError(
                         f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                     )
-            elif not is_tracing and torch.all(attention_mask == 1):
+            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
...
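Condensed, the condition changed in the hunks above reads as below. The function name and flattened signature are illustrative, not the actual `AttentionMaskConverter._ignore_causal_mask_sdpa` code; the net effect is that a training run under torch.compile may now drop the causal mask and rely on `is_causal`, since the export concern described in the TODO comment applies to inference-time export rather than training.

def can_rely_on_is_causal(is_tracing, is_training, query_length, key_value_length, sliding_window=None):
    # Illustrative restatement of the updated condition: tracing/compiling no longer
    # blocks mask-dropping when the model is in training mode.
    return (
        (is_training or not is_tracing)
        and (query_length == 1 or key_value_length == query_length)
        and (sliding_window is None or key_value_length < sliding_window)
    )

# True after this change, False before it: a compiled training step with q_len == kv_len.
assert can_rely_on_is_causal(is_tracing=True, is_training=True, query_length=512, key_value_length=512)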
@@ -590,15 +590,17 @@ class CohereSdpaAttention(CohereAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -996,7 +998,10 @@ class CohereModel(CoherePreTrainedModel):
             # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
             # in order to dispatch on Flash Attention 2.
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...
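The attention-side pattern above, where the `is_causal` decision is hoisted into a plain Python bool before the SDPA call instead of being written inline as its argument, is what the new comment credits with keeping both `dynamic=True` and `fullgraph=True` compiles working. A standalone sketch of the same shape, an illustrative module rather than the transformers implementation:

import torch
import torch.nn.functional as F

class TinySdpa(torch.nn.Module):
    # Illustrative stand-in for the *SdpaAttention.forward bodies touched by this commit.
    def forward(self, query, key, value, causal_mask=None):
        q_len = query.shape[-2]
        # The causal/non-causal choice is made as a Python-level bool before the call,
        # mirroring the diff, rather than inline in the SDPA invocation.
        is_causal = True if causal_mask is None and q_len > 1 else False
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
        )

compiled = torch.compile(TinySdpa(), dynamic=True, fullgraph=True)
x = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
out = compiled(x, x, x)  # no mask passed, so is_causal=True; FA2-eligible on GPU

The Gemma, Llama, and Olmo hunks below apply the identical pattern.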
@@ -570,15 +570,17 @@ class GemmaSdpaAttention(GemmaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -982,7 +984,10 @@ class GemmaModel(GemmaPreTrainedModel):
             # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
             # in order to dispatch on Flash Attention 2.
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...
@@ -666,15 +666,17 @@ class LlamaSdpaAttention(LlamaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1074,7 +1076,10 @@ class LlamaModel(LlamaPreTrainedModel):
             # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
             # in order to dispatch on Flash Attention 2.
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...
@@ -647,13 +647,17 @@ class OlmoSdpaAttention(OlmoAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1057,7 +1061,10 @@ class OlmoModel(OlmoPreTrainedModel):
             # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
             # in order to dispatch on Flash Attention 2.
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
             ):
                 return None
...