"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "83e59d8e0bd0a523d2459d932d3c9adb68c49cc1"
Unverified commit cd6bd0af, authored by Benjamin Warner, committed by GitHub

Add support for torch.compile dynamic shapes (#30560)

* add torch.compile dynamic support

* Add SDPA dynamic shapes compile test & improve SDPA comment

* comment consistency
parent fce78fd0
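
Every hunk below applies the same change: the `is_causal` decision is hoisted out of the `scaled_dot_product_attention` argument list into a standalone assignment before the call, because an inline conditional in the argument list prevents torch.compile's dynamic shapes from compiling. A minimal, self-contained sketch of the pattern, assuming a toy attention function (`toy_sdpa_attention` and its arguments are illustrative, not code from this diff):

import torch
import torch.nn.functional as F


def toy_sdpa_attention(query, key, value, attention_mask=None, module_is_causal=True):
    # Hypothetical stand-in for the SDPA forward passes touched by this PR.
    q_len = query.shape[-2]

    # Hoisted assignment (the pattern this PR adopts): deciding `is_causal` before
    # the call keeps both torch.compile's dynamic shapes and fullgraph mode working.
    # The q_len > 1 check mirrors AttentionMaskConverter.to_causal_4d, which does
    # not build a causal mask when q_len == 1.
    is_causal = True if module_is_causal and attention_mask is None and q_len > 1 else False

    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=is_causal,
    )


if __name__ == "__main__":
    # Compile once with dynamic shapes and run two sequence lengths so the
    # compiled graph is exercised with a symbolic sequence dimension.
    compiled = torch.compile(toy_sdpa_attention, dynamic=True, fullgraph=True)
    for seq_len in (8, 16):
        q = k = v = torch.randn(1, 4, seq_len, 32)
        print(compiled(q, k, v).shape)

The hunks below apply this same hoisting to each model's SDPA attention class.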
@@ -592,6 +592,11 @@ class BartSdpaAttention(BartAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -600,8 +605,7 @@ class BartSdpaAttention(BartAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
@@ -428,9 +428,11 @@ class BertSdpaSelfAttention(BertSelfAttention):
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
- # mask in case tgt_len == 1.
- is_causal = self.is_decoder and attention_mask is None and tgt_len > 1
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+ # a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
...
@@ -587,8 +587,8 @@ class CohereSdpaAttention(CohereAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
- # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
...
@@ -788,6 +788,11 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -796,8 +801,7 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
@@ -434,16 +434,19 @@ class FalconAttention(nn.Module):
if alibi is None:
if self._use_sdpa and not output_attentions:
- attn_output = F.scaled_dot_product_attention(
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+ # create a causal mask in case query_length == 1.
+ is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
- attention_mask,
- 0.0,
- # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
- is_causal=self.is_causal and attention_mask is None and query_length > 1,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=is_causal,
)
attention_scores = None
else:
attention_scores = query_layer @ key_layer.transpose(-1, -2)
@@ -466,13 +469,16 @@ class FalconAttention(nn.Module):
else:
if self._use_sdpa and not output_attentions and head_mask is None:
- attn_output = F.scaled_dot_product_attention(
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
- is_causal=self.is_causal and attention_mask is None and query_length > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
...
@@ -571,8 +571,8 @@ class GemmaSdpaAttention(GemmaAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
- # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
...
@@ -549,14 +549,19 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
key = key.contiguous()
value = value.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+ # create a causal mask in case query_length == 1.
+ is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
- # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
- is_causal=self.is_causal and attention_mask is None and query_length > 1,
+ is_causal=is_causal,
scale=scale,
)
...
@@ -852,6 +852,11 @@ class HubertSdpaAttention(HubertAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -860,8 +865,7 @@ class HubertSdpaAttention(HubertAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
@@ -660,14 +660,18 @@ class IdeficsAttention(nn.Module):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
- attn_output = nn.functional.scaled_dot_product_attention(
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
- dropout_p=self.dropout,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
...
@@ -643,8 +643,8 @@ class LlamaSdpaAttention(LlamaAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
- # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
...
@@ -676,14 +676,18 @@ class MistralSdpaAttention(MistralAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -749,14 +749,18 @@ class MixtralSdpaAttention(MixtralAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -618,6 +618,11 @@ class MusicgenSdpaAttention(MusicgenAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -626,8 +631,7 @@ class MusicgenSdpaAttention(MusicgenAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
@@ -634,6 +634,11 @@ class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -642,8 +647,7 @@ class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
@@ -617,8 +617,8 @@ class OlmoSdpaAttention(OlmoAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
- # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
...
@@ -709,13 +709,17 @@ class PhiSdpaAttention(PhiAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -778,14 +778,18 @@ class Phi3SdpaAttention(Phi3Attention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -681,14 +681,18 @@ class Qwen2SdpaAttention(Qwen2Attention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -759,14 +759,18 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -852,6 +852,11 @@ class SEWSdpaAttention(SEWAttention):
query_states = self._shape(query_states, tgt_len, bsz)
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -860,8 +865,7 @@ class SEWSdpaAttention(SEWAttention):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+ is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
...
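
The commit message also mentions an SDPA dynamic shapes compile test, which is not part of this excerpt. A rough, hypothetical sketch of the kind of check such a test performs (the tiny checkpoint name and the tolerances are assumptions, not values taken from this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any small SDPA-capable causal LM would do.
checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

compiled_model = torch.compile(model, dynamic=True)

with torch.no_grad():
    for prompt in ("short prompt", "a somewhat longer prompt so the sequence length changes"):
        inputs = tokenizer(prompt, return_tensors="pt")
        eager_logits = model(**inputs).logits
        compiled_logits = compiled_model(**inputs).logits
        # Compiled (dynamic shapes) and eager outputs should agree at both lengths.
        torch.testing.assert_close(eager_logits, compiled_logits, rtol=1e-4, atol=1e-4)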