Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
046c2ad7
Unverified
Commit
046c2ad7
authored
May 23, 2024
by
Benjamin Warner
Committed by
GitHub
May 23, 2024
Browse files
Finish adding support for torch.compile dynamic shapes (#30919)
add torch.compile dynamic support
parent
6739e1d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
6 deletions
+16
-6
src/transformers/models/dbrx/modeling_dbrx.py
src/transformers/models/dbrx/modeling_dbrx.py
+5
-1
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jamba/modeling_jamba.py
+6
-2
src/transformers/models/jetmoe/modeling_jetmoe.py
src/transformers/models/jetmoe/modeling_jetmoe.py
+5
-3
No files found.
src/transformers/models/dbrx/modeling_dbrx.py
View file @
046c2ad7
...
...
@@ -615,13 +615,17 @@ class DbrxSdpaAttention(DbrxAttention):
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal
=
True
if
causal_mask
is
None
and
q_len
>
1
else
False
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
attn_mask
=
causal_mask
,
dropout_p
=
self
.
attn_pdrop
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
is_causal
=
is_causal
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
src/transformers/models/jamba/modeling_jamba.py
View file @
046c2ad7
...
...
@@ -725,14 +725,18 @@ class JambaSdpaAttention(JambaAttention):
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal
=
True
if
self
.
is_causal
and
causal_mask
is
None
and
q_len
>
1
else
False
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
attn_mask
=
causal_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
is_causal
=
is_causal
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
src/transformers/models/jetmoe/modeling_jetmoe.py
View file @
046c2ad7
...
...
@@ -624,15 +624,17 @@ class JetMoeSdpaAttention(JetMoeAttention):
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
# than relying on the `is_causal` argument.
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal
=
True
if
causal_mask
is
None
and
q_len
>
1
else
False
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
attn_mask
=
causal_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
is_causal
=
is_causal
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment