"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "ab7e6006d62d77dca72b29721b7b346eeb6563d4"
Unverified Commit 2bdeb6f5 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Update FA version (#838)



Bump FA version to 2.5.8
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a51ff542
...@@ -265,7 +265,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -265,7 +265,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if "pytorch" in frameworks(): if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"]) add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks(): if "jax" in frameworks():
if not found_pybind11(): if not found_pybind11():
......
...@@ -69,6 +69,7 @@ from transformer_engine.pytorch.graph import is_graph_capturing ...@@ -69,6 +69,7 @@ from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6") _flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.8")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
...@@ -1931,6 +1932,9 @@ class FlashAttention(torch.nn.Module): ...@@ -1931,6 +1932,9 @@ class FlashAttention(torch.nn.Module):
assert ( assert (
_flash_attn_version >= _flash_attn_version_required _flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required." ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
assert (
_flash_attn_version <= _flash_attn_max_version
), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
self.norm_factor = norm_factor self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx self.attention_dropout_ctx = attention_dropout_ctx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment