"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "badb9d2aaa58df2fddc09a868d8e3e5655b101a3"
Unverified Commit 046c2ad7 authored by Benjamin Warner, committed by GitHub

Finish adding support for torch.compile dynamic shapes (#30919)

add torch.compile dynamic support
parent 6739e1d2
@@ -615,13 +615,17 @@ class DbrxSdpaAttention(DbrxAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attn_pdrop if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -725,14 +725,18 @@ class JambaSdpaAttention(JambaAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
...
@@ -624,15 +624,17 @@ class JetMoeSdpaAttention(JetMoeAttention):
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
...
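The same pattern is applied in all three attention classes: the `is_causal` decision is hoisted out of the `scaled_dot_product_attention` call so that `torch.compile` can trace the forward pass with dynamic shapes and `fullgraph=True`, instead of hitting the inline conditional that previously prevented dynamic-shape compilation. A minimal standalone sketch of the pattern follows; the `sdpa_forward` function name and the tensor shapes are hypothetical illustrations, not part of this commit.

import torch

def sdpa_forward(query, key, value, causal_mask=None, dropout_p=0.0):
    q_len = query.shape[-2]
    # Precompute the boolean before the call rather than writing an inline
    # conditional in the SDPA argument list, mirroring the change above.
    is_causal = True if causal_mask is None and q_len > 1 else False
    return torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )

# With the hoisted boolean, compiling with dynamic shapes and a full graph
# should trace cleanly (illustrative shapes: batch=2, heads=8, seq=16, dim=64).
compiled_sdpa = torch.compile(sdpa_forward, dynamic=True, fullgraph=True)
q = k = v = torch.randn(2, 8, 16, 64)
out = compiled_sdpa(q, k, v)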