Unverified Commit 47caafb2 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Re-add framework specific required dependencies for source build (#1124)



* Re-add framework specific required dependencies for source build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 467b39a3
...@@ -89,6 +89,18 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -89,6 +89,18 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not found_pybind11(): if not found_pybind11():
setup_reqs.append("pybind11") setup_reqs.append("pybind11")
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy", "praxis"])
if "paddle" in frameworks:
install_reqs.append("paddlepaddle-gpu")
test_reqs.append("numpy")
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment