Unverified Commit e17c31c3 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Update FA version (#279)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c67bb2fc
......@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn==1.0.6"])
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=1.0.7"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
if not found_pybind11():
......
......@@ -35,7 +35,7 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.export import is_in_onnx_export_mode
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.2")
_flash_attn_version_required = packaging.version.Version("1.0.6")
__all__ = ["DotProductAttention"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment