OpenDAS / TransformerEngine / Commits

Commit 965803c9 (unverified)
Authored Mar 20, 2024 by Kirthi Shankar Sivamani; committed via GitHub on Mar 20, 2024
Parent: a3ba77b8

Update FA version to 2.5.6 (#714)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing 2 changed files with 5 additions and 1 deletion:

- setup.py (+1, -1)
- transformer_engine/pytorch/attention.py (+4, -0)
setup.py

```diff
@@ -265,7 +265,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
+        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
```
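For context, the requirement string above is a standard pip version specifier. A quick sketch using the `packaging` library (the candidate versions listed here are illustrative, not an exhaustive set of flash-attn releases) shows which releases the updated pin admits:

```python
# Sketch: which flash-attn versions satisfy the updated specifier.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0")

for v in ["2.0.6", "2.0.9", "2.1.0", "2.4.2", "2.5.6", "2.5.7"]:
    print(v, Version(v) in spec)
# 2.0.6 True          (lower bound, inclusive)
# 2.0.9 False         (explicitly excluded)
# 2.1.0 False         (explicitly excluded)
# 2.4.2 True          (the previous upper bound)
# 2.5.6 True          (the new upper bound, inclusive)
# 2.5.7 False         (above the cap)
```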
transformer_engine/pytorch/attention.py

```diff
@@ -58,6 +58,7 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("2.0.6")
+_flash_attn_max_version = packaging.version.Version("2.5.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")

@@ -1656,6 +1657,9 @@ class FlashAttention(torch.nn.Module):
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
+        assert (
+            _flash_attn_version <= _flash_attn_max_version
+        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
```
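The pattern in this file resolves the installed flash-attn version once at import time, from distribution metadata rather than by importing the package, and then gates behavior on it. A minimal standalone sketch of the same idea (the names `check_flash_attn`, `_FA_MIN`, and `_FA_MAX` are illustrative, not TransformerEngine's API):

```python
# Minimal sketch of an import-time version gate, mirroring the
# min/max assertions above. flash-attn need not be importable;
# the installed distribution metadata is enough.
from importlib.metadata import version, PackageNotFoundError
import packaging.version

try:
    _fa_version = packaging.version.Version(version("flash-attn"))
except PackageNotFoundError:
    _fa_version = None  # flash-attn is not installed

_FA_MIN = packaging.version.Version("2.0.6")
_FA_MAX = packaging.version.Version("2.5.6")  # the cap raised by this commit

def check_flash_attn() -> None:
    """Fail fast if the installed flash-attn is outside the supported range."""
    assert _fa_version is not None, "flash-attn is not installed."
    assert _fa_version >= _FA_MIN, f"FlashAttention minimum version {_FA_MIN} is required."
    assert _fa_version <= _FA_MAX, f"FlashAttention maximum version {_FA_MAX} is supported."
```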