Unverified Commit ddcda1ff authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Manage dependencies and add missing `einops` req (#1859)



* Manage deps and add einops
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update build.yml
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc185200
...@@ -87,7 +87,7 @@ jobs: ...@@ -87,7 +87,7 @@ jobs:
with: with:
submodules: recursive submodules: recursive
- name: 'Build' - name: 'Build'
run: pip install --no-build-isolation . -v run: pip install --no-build-isolation . -v --no-deps
env: env:
NVTE_FRAMEWORK: all NVTE_FRAMEWORK: all
MAX_JOBS: 1 MAX_JOBS: 1
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
"""JAX related extensions.""" """JAX related extensions."""
import os import os
import shutil
from pathlib import Path from pathlib import Path
import setuptools import setuptools
...@@ -13,6 +12,16 @@ from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled ...@@ -13,6 +12,16 @@ from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
from typing import List from typing import List
def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
return ["jax[cuda12]", "flax>=0.7.1"]
def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy"]
def xla_path() -> str: def xla_path() -> str:
"""XLA root path lookup. """XLA root path lookup.
Throws FileNotFoundError if XLA source is not found.""" Throws FileNotFoundError if XLA source is not found."""
......
...@@ -9,6 +9,22 @@ from pathlib import Path ...@@ -9,6 +9,22 @@ from pathlib import Path
import setuptools import setuptools
from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
from typing import List
def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
reqs = ["torch>=2.1", "einops"]
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
return reqs
def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "torchvision", "transformers"]
def setup_pytorch_extension( def setup_pytorch_extension(
......
...@@ -120,19 +120,17 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -120,19 +120,17 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import install_requirements, test_requirements
setup_reqs.extend(["torch>=2.1"]) setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(["torch>=2.1"]) install_reqs.extend(install_requirements())
install_reqs.append( test_reqs.extend(test_requirements())
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "transformers"])
if "jax" in frameworks: if "jax" in frameworks:
from build_tools.jax import install_requirements, test_requirements
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"]) install_reqs.extend(install_requirements())
test_reqs.extend(["numpy"]) test_reqs.extend(test_requirements())
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
......
...@@ -46,7 +46,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -46,7 +46,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension from build_tools.jax import setup_jax_extension, install_requirements, test_requirements
install_and_import("pybind11") install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
...@@ -116,8 +116,8 @@ if __name__ == "__main__": ...@@ -116,8 +116,8 @@ if __name__ == "__main__":
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires, setup_requires=setup_requires,
install_requires=["jax", "flax>=0.7.1"], install_requires=install_requirements(),
tests_require=["numpy"], tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir) shutil.rmtree(common_headers_dir)
......
...@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, cuda_toolkit_include_path from build_tools.utils import copy_common_headers, cuda_toolkit_include_path
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
...@@ -70,8 +70,8 @@ if __name__ == "__main__": ...@@ -70,8 +70,8 @@ if __name__ == "__main__":
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires, setup_requires=setup_requires,
install_requires=["torch>=2.1"], install_requires=install_requirements(),
tests_require=["numpy", "torchvision"], tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir) shutil.rmtree(common_headers_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment