"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "18319d5598e874168b00048bf093721a8755c611"
Unverified Commit f2bd53c4 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Bump FlashAttn version and add deterministic option for FAv2 (#585)



* Deterministic FA, bump minimum supported version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MQA/GQA
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e2a75314
...@@ -284,7 +284,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -284,7 +284,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if "pytorch" in frameworks(): if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6,<=2.3.3,!=2.0.9,!=2.1.0"]) add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks(): if "jax" in frameworks():
if not found_pybind11(): if not found_pybind11():
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment