"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "c49f90d34eb04fedc348eb5f7a61481ad19c09c3"
Unverified Commit aedd7e10 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

pyproject.toml (#1852)



* Initial basic setup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm setup reqs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* buil-isolation support
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm not needed funcs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix workflows
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix wheel
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix invalid wheel
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX build in baremetal env
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update install inst in readme
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update build.yml
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* docstring fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent faee0e8b
...@@ -18,8 +18,8 @@ jobs: ...@@ -18,8 +18,8 @@ jobs:
- name: 'Dependencies' - name: 'Dependencies'
run: | run: |
apt-get update apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pip install cmake==3.21.0 pybind11[global] ninja
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
...@@ -42,8 +42,8 @@ jobs: ...@@ -42,8 +42,8 @@ jobs:
- name: 'Dependencies' - name: 'Dependencies'
run: | run: |
apt-get update apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
...@@ -62,6 +62,8 @@ jobs: ...@@ -62,6 +62,8 @@ jobs:
image: ghcr.io/nvidia/jax:jax image: ghcr.io/nvidia/jax:jax
options: --user root options: --user root
steps: steps:
- name: 'Dependencies'
run: pip install pybind11[global]
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
......
...@@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository: ...@@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository:
.. code-block:: bash .. code-block:: bash
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
When installing from GitHub, you can explicitly specify frameworks using the environment variable: When installing from GitHub, you can explicitly specify frameworks using the environment variable:
.. code-block:: bash .. code-block:: bash
NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
conda Installation conda Installation
^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^
......
...@@ -14,7 +14,7 @@ from typing import List ...@@ -14,7 +14,7 @@ from typing import List
def install_requirements() -> List[str]: def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions.""" """Install dependencies for TE/JAX extensions."""
return ["jax[cuda12]", "flax>=0.7.1"] return ["jax", "flax>=0.7.1"]
def test_requirements() -> List[str]: def test_requirements() -> List[str]:
...@@ -75,20 +75,9 @@ def setup_jax_extension( ...@@ -75,20 +75,9 @@ def setup_jax_extension(
# Define TE/JAX as a Pybind11Extension # Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension from pybind11.setup_helpers import Pybind11Extension
class Pybind11CPPExtension(Pybind11Extension): return Pybind11Extension(
"""Modified Pybind11Extension to allow custom CXX flags."""
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags
return Pybind11CPPExtension(
"transformer_engine_jax", "transformer_engine_jax",
sources=[str(path) for path in sources], sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs], include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags}, extra_compile_args=cxx_flags,
) )
...@@ -354,10 +354,3 @@ def copy_common_headers( ...@@ -354,10 +354,3 @@ def copy_common_headers(
new_path = dst_dir / path.relative_to(src_dir) new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True) new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path) shutil.copy(path, new_path)
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
...@@ -20,6 +20,9 @@ cd /TransformerEngine ...@@ -20,6 +20,9 @@ cd /TransformerEngine
git checkout $TARGET_BRANCH git checkout $TARGET_BRANCH
git submodule update --init --recursive git submodule update --init --recursive
# Install deps
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja
if $BUILD_METAPACKAGE ; then if $BUILD_METAPACKAGE ; then
cd /TransformerEngine cd /TransformerEngine
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
...@@ -31,15 +34,15 @@ if $BUILD_COMMON ; then ...@@ -31,15 +34,15 @@ if $BUILD_COMMON ; then
WHL_BASE="transformer_engine-${VERSION}" WHL_BASE="transformer_engine-${VERSION}"
# Create the wheel. # Create the wheel.
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
# Repack the wheel for cuda specific package, i.e. cu12. # Repack the wheel for cuda specific package, i.e. cu12.
/opt/python/cp38-cp38/bin/wheel unpack dist/* /opt/python/cp310-cp310/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}
# Rename the wheel to make it python version agnostic. # Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*) whl_name=$(basename dist/*)
...@@ -51,8 +54,8 @@ fi ...@@ -51,8 +54,8 @@ fi
if $BUILD_PYTORCH ; then if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch cd /TransformerEngine/transformer_engine/pytorch
/opt/python/cp38-cp38/bin/pip install torch /opt/python/cp310-cp310/bin/pip install torch
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/ cp dist/* /wheelhouse/
fi fi
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
...@@ -16,13 +16,8 @@ from build_tools.build_ext import CMakeExtension, get_build_ext ...@@ -16,13 +16,8 @@ from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.utils import ( from build_tools.utils import (
cuda_archs, cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
get_frameworks, get_frameworks,
install_and_import,
remove_dups, remove_dups,
cuda_toolkit_include_path,
) )
frameworks = get_frameworks() frameworks = get_frameworks()
...@@ -36,7 +31,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" ...@@ -36,7 +31,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks: if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks: elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
...@@ -82,26 +76,13 @@ def setup_common_extension() -> CMakeExtension: ...@@ -82,26 +76,13 @@ def setup_common_extension() -> CMakeExtension:
) )
def setup_requirements() -> Tuple[List[str], List[str], List[str]]: def setup_requirements() -> Tuple[List[str], List[str]]:
"""Setup Python dependencies """Setup Python dependencies
Returns dependencies for build, runtime, and testing. Returns dependencies for runtime and testing.
""" """
# Common requirements # Common requirements
setup_reqs: List[str] = []
if cuda_toolkit_include_path() is None:
setup_reqs.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)
install_reqs: List[str] = [ install_reqs: List[str] = [
"pydantic", "pydantic",
"importlib-metadata>=1.0", "importlib-metadata>=1.0",
...@@ -109,30 +90,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -109,30 +90,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
] ]
test_reqs: List[str] = ["pytest>=8.2.1"] test_reqs: List[str] = ["pytest>=8.2.1"]
# Requirements that may be installed outside of Python
if not found_cmake():
setup_reqs.append("cmake>=3.21")
if not found_ninja():
setup_reqs.append("ninja")
if not found_pybind11():
setup_reqs.append("pybind11")
# Framework-specific requirements # Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import install_requirements, test_requirements from build_tools.pytorch import install_requirements, test_requirements
setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(install_requirements()) install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements()) test_reqs.extend(test_requirements())
if "jax" in frameworks: if "jax" in frameworks:
from build_tools.jax import install_requirements, test_requirements from build_tools.jax import install_requirements, test_requirements
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(install_requirements()) install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements()) test_reqs.extend(test_requirements())
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
if __name__ == "__main__": if __name__ == "__main__":
...@@ -149,14 +120,13 @@ if __name__ == "__main__": ...@@ -149,14 +120,13 @@ if __name__ == "__main__":
ext_modules = [] ext_modules = []
package_data = {} package_data = {}
include_package_data = False include_package_data = False
setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],) install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = { extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"],
} }
else: else:
setup_requires, install_requires, test_requires = setup_requirements() install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()] ext_modules = [setup_common_extension()]
package_data = {"": ["VERSION.txt"]} package_data = {"": ["VERSION.txt"]}
include_package_data = True include_package_data = True
...@@ -203,7 +173,6 @@ if __name__ == "__main__": ...@@ -203,7 +173,6 @@ if __name__ == "__main__":
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8", python_requires=">=3.8",
classifiers=["Programming Language :: Python :: 3"], classifiers=["Programming Language :: Python :: 3"],
setup_requires=setup_requires,
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
include_package_data=include_package_data, include_package_data=include_package_data,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
...@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements from build_tools.jax import setup_jax_extension, install_requirements, test_requirements
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
...@@ -94,20 +93,6 @@ if __name__ == "__main__": ...@@ -94,20 +93,6 @@ if __name__ == "__main__":
) )
] ]
setup_requires = ["jax[cuda12]", "flax>=0.7.1"]
if cuda_toolkit_include_path() is None:
setup_requires.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine_jax", name="transformer_engine_jax",
...@@ -115,7 +100,6 @@ if __name__ == "__main__": ...@@ -115,7 +100,6 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires,
install_requires=install_requirements(), install_requires=install_requirements(),
tests_require=test_requirements(), tests_require=test_requirements(),
) )
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
...@@ -29,7 +29,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -29,7 +29,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, cuda_toolkit_include_path from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
...@@ -48,20 +48,6 @@ if __name__ == "__main__": ...@@ -48,20 +48,6 @@ if __name__ == "__main__":
) )
] ]
setup_requires = ["torch>=2.1"]
if cuda_toolkit_include_path() is None:
setup_requires.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine_torch", name="transformer_engine_torch",
...@@ -69,7 +55,6 @@ if __name__ == "__main__": ...@@ -69,7 +55,6 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib", description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires,
install_requires=install_requirements(), install_requires=install_requirements(),
tests_require=test_requirements(), tests_require=test_requirements(),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment