Commit 47276e1b authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani
Browse files

Revert "Update FA version to 2.5.6 (#714)"

This reverts commit 965803c9.
parent 2dd6b146
......@@ -265,7 +265,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
......
......@@ -58,7 +58,6 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.6")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
......@@ -1658,9 +1657,6 @@ class FlashAttention(torch.nn.Module):
assert (
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
assert (
_flash_attn_version <= _flash_attn_max_version
), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment