Unverified Commit 965803c9 authored by Kirthi Shankar Sivamani, committed by GitHub

Update FA version to 2.5.6 (#714)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a3ba77b8
@@ -265,7 +265,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
+        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
...
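For reference, the tightened pip constraint above can be evaluated with packaging's SpecifierSet. This is a standalone sketch for illustration only, not part of the commit:

# Illustrative check of the new flash-attn constraint (not part of the diff).
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0")
print("2.5.6" in spec)   # True  -- newly allowed by this commit
print("2.4.2" in spec)   # True  -- previous upper bound, still allowed
print("2.1.0" in spec)   # False -- explicitly excluded
print("2.5.7" in spec)   # False -- above the new cap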
@@ -58,6 +58,7 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("2.0.6")
+_flash_attn_max_version = packaging.version.Version("2.5.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
@@ -1656,6 +1657,9 @@ class FlashAttention(torch.nn.Module):
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
+        assert (
+            _flash_attn_version <= _flash_attn_max_version
+        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
...
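Taken together, the two hunks above bound the installed flash-attn to the [2.0.6, 2.5.6] range. Below is a minimal self-contained sketch of the resulting version gate; it is illustrative only (the import of `version` from importlib.metadata is assumed from context, and the real checks run inside FlashAttention's constructor):

# Illustrative sketch of the version gate introduced by this commit.
from importlib.metadata import version  # assumed import path for version()
import packaging.version

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.6")

# Installed flash-attn must fall inside the supported [2.0.6, 2.5.6] range.
assert (
    _flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
assert (
    _flash_attn_version <= _flash_attn_max_version
), f"FlashAttention maximum version {_flash_attn_max_version} is supported."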