OpenDAS / TransformerEngine · Commits · f70b4bbf

Unverified commit f70b4bbf, authored Jun 06, 2023 by Ming-Xu Huang; committed by GitHub on Jun 05, 2023.
[JAX] Enhance fall-back conditions for fMHA. (#260)

Signed-off-by: Ming Huang <mingh@nvidia.com>

Parent: 144e4888
Showing 1 changed file with 3 additions and 1 deletion.

transformer_engine/jax/flax/transformer.py (+3, -1)
@@ -374,7 +374,7 @@ class MultiHeadAttention(nn.Module):
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
                          self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
                          q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
-                         and is_fused_attn_kernel_available() and enable_fused_attn
+                         and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn

         if enable_fused_attn and not use_fused_attn:
             reason = ""
@@ -399,6 +399,8 @@ class MultiHeadAttention(nn.Module):
                     f"but got {kv_seqlen=}, "
             if not is_fused_attn_kernel_available():
                 reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
+            if self.head_dim != 64:
+                reason += f"head_dim should be 64 but got {self.head_dim}, "
             warnings.warn(
                 f"Fused attention is not enabled, " \
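For context, here is a minimal, self-contained sketch of how the fall-back logic reads after this patch. Only the condition structure and the reason strings mirror the diff; the stub for is_fused_attn_kernel_available, the concrete attribute values, and the contents of fused_attn_supported_seqlen are hypothetical stand-ins added to make the snippet runnable, and the tail of the warning message is paraphrased because the diff is truncated at that point.

import warnings

import jax.numpy as jnp


def is_fused_attn_kernel_available():
    # Hypothetical stub: the real helper checks GPU arch / cuDNN version.
    return True


# Hypothetical stand-ins for module attributes and call-time values.
decode = False
transpose_batch_sequence = False
fuse_qkv = True
dropout_rate = 0.0
head_dim = 128                       # != 64, so fused attention must fall back
canonicalize_dtype = jnp.bfloat16
q_seqlen, kv_seqlen = 512, 512
fused_attn_supported_seqlen = (128, 256, 384, 512)   # hypothetical values
enable_fused_attn = True

# The updated condition: this patch adds the (head_dim == 64) guard.
use_fused_attn = not decode and not transpose_batch_sequence and fuse_qkv and \
    dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
    q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
    and is_fused_attn_kernel_available() and (head_dim == 64) and enable_fused_attn

# The enhanced warning path: each failed condition appends its own reason,
# so the user sees every constraint that blocked the fused kernel at once.
if enable_fused_attn and not use_fused_attn:
    reason = ""
    if not is_fused_attn_kernel_available():
        reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
    if head_dim != 64:
        reason += f"head_dim should be 64 but got {head_dim}, "
    # Message tail is paraphrased; the diff truncates after "not enabled, ".
    warnings.warn(f"Fused attention is not enabled, {reason}falling back to the unfused path.")

Running the sketch as-is emits a warning naming head_dim, since the stand-in value of 128 violates the new (head_dim == 64) guard while every other condition passes.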