OpenDAS / TransformerEngine

Commit f0ddab82 (unverified), authored Jul 19, 2023 by Kirthi Shankar Sivamani, committed by GitHub on Jul 19, 2023
Relax FA 2.0 checks for Ada (#331)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Parent: 10eb13e2
Showing 1 changed file with 6 additions and 1 deletion.

transformer_engine/pytorch/attention.py (+6, -1)
@@ -885,10 +885,15 @@ class DotProductAttention(torch.nn.Module):
         if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
             or value_layer.dtype not in [torch.bfloat16, torch.float16]
-            or (self.device_compute_capability in (8.6, 8.7, 8.9) and key_layer.shape[-1] > 64)
         ):
             use_flash_attention = False

+        if key_layer.shape[-1] > 64:
+            if self.device_compute_capability in (8.6, 8.7):
+                use_flash_attention = False
+            elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
+                use_flash_attention = False
+
         if self.attn_mask_type == "padding" and attention_mask is not None:
             use_flash_attention = False
             use_fused_attention = False
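In effect, the change lets Ada (compute capability 8.9) keep the FlashAttention backend for head dimensions above 64 whenever FlashAttention 2.x is installed, while 8.6/8.7 retain the previous restriction. Below is a minimal, self-contained sketch of that gating decision; the function name and arguments are illustrative only, not TransformerEngine's API.

# Illustrative sketch of the relaxed backend check (hypothetical helper,
# not part of TransformerEngine). It mirrors the logic added in this commit.
def can_use_flash_attention(compute_capability: float,
                            head_dim: int,
                            flash_attn_2_available: bool) -> bool:
    """Return True if the FlashAttention backend may still be selected."""
    if head_dim > 64:
        # Compute capability 8.6 / 8.7: head_dim > 64 remains unsupported.
        if compute_capability in (8.6, 8.7):
            return False
        # Ada (8.9): allowed only when FlashAttention 2.x is installed.
        if compute_capability == 8.9 and not flash_attn_2_available:
            return False
    return True

# Ada with FlashAttention 2 installed may now handle head_dim = 128.
assert can_use_flash_attention(8.9, 128, flash_attn_2_available=True)
# Without FlashAttention 2, Ada keeps the previous restriction.
assert not can_use_flash_attention(8.9, 128, flash_attn_2_available=False)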