Unverified Commit fd234d80 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Wheels for cuda 13 (#2278)



* Support wheel build for cuda 13
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for cu13 runtime, format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add documentation
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better error handling
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix jax sdist
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Modify function names
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ee384ab5
...@@ -205,7 +205,7 @@ pip Installation ...@@ -205,7 +205,7 @@ pip Installation
**Prerequisites for pip installation:** **Prerequisites for pip installation:**
* A compatible C++ compiler * A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed * CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source.
To install the latest stable version with pip: To install the latest stable version with pip:
......
...@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_aarch64 ...@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_aarch64
WORKDIR /TransformerEngine/ WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/ COPY ../.. /TransformerEngine/
ARG VER="12-3" ARG CUDA_MAJOR="12"
ARG ARCH="aarch64" ARG CUDA_MINOR="3"
RUN dnf -y install vim
# Args for build_wheels.sh
ARG BUILD_METAPACKAGE=true
ARG BUILD_COMMON=true
ARG BUILD_PYTORCH=true
ARG BUILD_JAX=true
ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
ENV BUILD_COMMON=${BUILD_COMMON}
ENV BUILD_PYTORCH=${BUILD_PYTORCH}
ENV BUILD_JAX=${BUILD_JAX}
ENV CUDA_MAJOR=${CUDA_MAJOR}
# Cuda toolkit, cudnn, driver. # Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
cuda-libraries-${VER}.${ARCH} \ cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
cuda-libraries-devel-${VER}.${ARCH} cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64
RUN dnf -y install --allowerasing cudnn9-cuda-12 RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR}
RUN dnf clean all RUN dnf clean all
RUN rm -rf /var/cache/dnf/* RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit RUN dnf -y install cuda-toolkit-${CUDA_MAJOR}
RUN dnf clean all RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64 RUN dnf -y install glog.aarch64 glog-devel.aarch64
RUN dnf -y install libnccl libnccl-devel libnccl-static
ENV PATH="/usr/local/cuda/bin:${PATH}" ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
...@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ...@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1 ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_aarch64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]
...@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 ...@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_x86_64
WORKDIR /TransformerEngine/ WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/ COPY ../.. /TransformerEngine/
ARG VER="12-3" ARG CUDA_MAJOR="12"
ARG ARCH="x86_64" ARG CUDA_MINOR="3"
RUN dnf -y install vim
# Args for build_wheels.sh
ARG BUILD_METAPACKAGE=true
ARG BUILD_COMMON=true
ARG BUILD_PYTORCH=true
ARG BUILD_JAX=true
ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
ENV BUILD_COMMON=${BUILD_COMMON}
ENV BUILD_PYTORCH=${BUILD_PYTORCH}
ENV BUILD_JAX=${BUILD_JAX}
ENV CUDA_MAJOR=${CUDA_MAJOR}
# Cuda toolkit, cudnn, driver. # Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
cuda-libraries-${VER}.${ARCH} \ cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
cuda-libraries-devel-${VER}.${ARCH} cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64
RUN dnf -y install --allowerasing cudnn9-cuda-12 RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR}
RUN dnf clean all RUN dnf clean all
RUN rm -rf /var/cache/dnf/* RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit RUN dnf -y install cuda-toolkit-${CUDA_MAJOR}
RUN dnf clean all RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64 RUN dnf -y install glog.x86_64 glog-devel.x86_64
RUN dnf -y install libnccl libnccl-devel libnccl-static
ENV PATH="/usr/local/cuda/bin:${PATH}" ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
...@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ...@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1 ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]
\ No newline at end of file
...@@ -9,8 +9,10 @@ BUILD_METAPACKAGE=${2:-true} ...@@ -9,8 +9,10 @@ BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true} BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true} BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true} BUILD_JAX=${5:-true}
CUDA_MAJOR=${6:-12}
export NVTE_RELEASE_BUILD=1 export NVTE_RELEASE_BUILD=1
export PIP_CONSTRAINT=""
export TARGET_BRANCH=${TARGET_BRANCH:-} export TARGET_BRANCH=${TARGET_BRANCH:-}
mkdir -p /wheelhouse/logs mkdir -p /wheelhouse/logs
...@@ -21,7 +23,7 @@ git checkout $TARGET_BRANCH ...@@ -21,7 +23,7 @@ git checkout $TARGET_BRANCH
git submodule update --init --recursive git submodule update --init --recursive
# Install deps # Install deps
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja /opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1
if $BUILD_METAPACKAGE ; then if $BUILD_METAPACKAGE ; then
cd /TransformerEngine cd /TransformerEngine
...@@ -36,18 +38,18 @@ if $BUILD_COMMON ; then ...@@ -36,18 +38,18 @@ if $BUILD_COMMON ; then
# Create the wheel. # Create the wheel.
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
# Repack the wheel for cuda specific package, i.e. cu12. # Repack the wheel for specific cuda version.
/opt/python/cp310-cp310/bin/wheel unpack dist/* /opt/python/cp310-cp310/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info"
/opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}
# Rename the wheel to make it python version agnostic. # Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*) whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name" IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" whl_name_target="${whl_parts[0]}_cu${CUDA_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
rm -rf $WHL_BASE dist rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target" mv *.whl /wheelhouse/"$whl_name_target"
fi fi
...@@ -61,7 +63,7 @@ fi ...@@ -61,7 +63,7 @@ fi
if $BUILD_JAX ; then if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax cd /TransformerEngine/transformer_engine/jax
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib /opt/python/cp310-cp310/bin/pip install "jax[cuda${CUDA_MAJOR}_local]" jaxlib
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/ cp dist/* /wheelhouse/
fi fi
...@@ -2,7 +2,29 @@ ...@@ -2,7 +2,29 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . # Remove leftovers.
rm -rf aarch_wheelhouse_cu12 aarch_wheelhouse_cu13
# CUDA 12.
docker build --no-cache \
--build-arg CUDA_MAJOR=12 \
--build-arg CUDA_MINOR=3 \
--build-arg BUILD_METAPACKAGE=false \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=false \
--build-arg BUILD_JAX=false \
-t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu12
# CUDA 13.
docker build --no-cache \
--build-arg CUDA_MAJOR=13 \
--build-arg CUDA_MINOR=0 \
--build-arg BUILD_METAPACKAGE=false \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=false \
--build-arg BUILD_JAX=false \
-t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
rm -rf aarch_wheelhouse docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu13
docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
...@@ -2,7 +2,29 @@ ...@@ -2,7 +2,29 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . # Remove leftovers.
rm -rf x86_wheelhouse_cu12 x86_wheelhouse_cu13
# CUDA 12.
docker build --no-cache \
--build-arg CUDA_MAJOR=12 \
--build-arg CUDA_MINOR=3 \
--build-arg BUILD_METAPACKAGE=true \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=true \
--build-arg BUILD_JAX=true \
-t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu12
# CUDA 13.
docker build --no-cache \
--build-arg CUDA_MAJOR=13 \
--build-arg CUDA_MINOR=0 \
--build-arg BUILD_METAPACKAGE=false \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=false \
--build-arg BUILD_JAX=false \
-t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
rm -rf x86_wheelhouse docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu13
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
...@@ -38,6 +38,14 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr ...@@ -38,6 +38,14 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions. To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
The core package from Transformer Engine (without any framework extensions) can be installed via:
.. code-block:: bash
pip3 install transformer_engine[core]
By default, this will install the core library compiled for CUDA 12. The CUDA major version can be specified by modifying the extra dependency to `core_cu12` or `core_cu13`.
pip - from GitHub pip - from GitHub
----------------------- -----------------------
......
...@@ -140,8 +140,11 @@ if __name__ == "__main__": ...@@ -140,8 +140,11 @@ if __name__ == "__main__":
ext_modules = [] ext_modules = []
package_data = {} package_data = {}
include_package_data = False include_package_data = False
install_requires = ([f"transformer_engine_cu12=={__version__}"],) install_requires = []
extras_require = { extras_require = {
"core": [f"transformer_engine_cu12=={__version__}"],
"core_cu12": [f"transformer_engine_cu12=={__version__}"],
"core_cu13": [f"transformer_engine_cu13=={__version__}"],
"pytorch": [f"transformer_engine_torch=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"],
} }
......
...@@ -8,22 +8,18 @@ import ctypes ...@@ -8,22 +8,18 @@ import ctypes
import functools import functools
import glob import glob
import importlib import importlib
from importlib.metadata import version, metadata, PackageNotFoundError from importlib.metadata import version, distribution, PackageNotFoundError
import logging
import os import os
from pathlib import Path from pathlib import Path
import platform import platform
import subprocess import subprocess
import sys import sys
import sysconfig import sysconfig
from typing import Optional from typing import Optional, Tuple
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package) -> bool: def _is_package_installed(package) -> bool:
"""Check if the given package is installed via pip.""" """Check if the given package is installed via pip."""
# This is needed because we only want to return true # This is needed because we only want to return true
...@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool: ...@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to # if it's importable in the current directory due to
# the presence of the shared library module. # the presence of the shared library module.
try: try:
metadata(package) distribution(package)
except PackageNotFoundError: except PackageNotFoundError:
return False return False
return True return True
@functools.lru_cache(maxsize=None)
def _is_package_installed_from_wheel(package) -> bool:
"""Check if the given package is installed via PyPI."""
if not _is_package_installed(package):
return False
te_dist = distribution(package)
te_wheel_file = ""
for file_path in te_dist.files:
if file_path.name == "WHEEL":
te_wheel_file = te_dist.locate_file("") / file_path
if not te_wheel_file:
return False
with te_wheel_file.open("r") as f:
for line in f:
if line.startswith("Root-Is-Purelib:"):
return line.strip().split(":")[1].strip().lower() == "true"
return False
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
""" """
...@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path: ...@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
) )
def get_te_core_package_info() -> Tuple[bool, str, str]:
"""
    Check if Transformer Engine core package is installed.
Returns the module name and version if found.
"""
te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
for package in te_core_packages:
if _is_package_installed(package):
return True, package, version(package)
return False, "", ""
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None: def load_framework_extension(framework: str) -> None:
""" """
...@@ -130,37 +161,28 @@ def load_framework_extension(framework: str) -> None: ...@@ -130,37 +161,28 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch": if framework == "torch":
extra_dep_name = "pytorch" extra_dep_name = "pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
    # For the `transformer-engine` package, we need to check explicitly.
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_framework_installed = _is_package_installed(module_name)
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via # If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version. # extension are all installed via PyPI and have matching versions.
if _is_pip_package_installed(module_name): if te_framework_installed:
assert _is_pip_package_installed( assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
"transformer_engine" assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
# If the core package is installed via PyPI, log if assert version(module_name) == version("transformer-engine") == te_core_version, (
# the framework extension is not found from PyPI. "Transformer Engine package version mismatch. Found"
# Note: Should we error? This is a rare use case. f" {module_name} v{version(module_name)}, transformer-engine"
if _is_pip_package_installed("transformer-engine-cu12"): f" v{version('transformer-engine')}, and {te_core_package_name}"
if not _is_pip_package_installed(module_name): f" v{te_core_version}. Install transformer-engine using "
_logger.info( f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
) )
# After all checks are completed, load the shared object file. # After all checks are completed, load the shared object file.
...@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None: ...@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
spec.loader.exec_module(solib) spec.loader.exec_module(solib)
def sanity_checks_for_pypi_installation() -> None:
"""Ensure that package is installed correctly if using PyPI."""
te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
te_installed = _is_package_installed("transformer_engine")
te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
assert te_installed, "Could not find `transformer-engine`."
# If the core package is installed via PyPI.
if te_core_installed:
assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found "
f"transformer-engine v{version('transformer-engine')} "
f"and {te_core_package_name} v{te_core_version}."
)
# Only the metapackage is found, invalid usecase.
elif te_installed_via_pypi:
raise RuntimeError(
"Found empty `transformer-engine` meta package installed. "
"Install `transformer-engine` with framework extensions via"
"'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
" or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
" or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str: def _get_sys_extension() -> str:
"""File extension for shared objects.""" """File extension for shared objects."""
...@@ -332,6 +383,7 @@ def _load_core_library(): ...@@ -332,6 +383,7 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn() _CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc() _NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand() _CURAND_LIB_CTYPES = _load_curand()
......
...@@ -54,6 +54,26 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" ...@@ -54,6 +54,26 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension, True) CMakeBuildExtension = get_build_ext(BuildExtension, True)
def get_cuda_major_version() -> int:
"""Get CUDA major version using Jax backend."""
assert (
jax._src.lib.cuda_versions is not None
), "GPU backend is required to build TE jax extensions."
# Jax currently does not have any stable/public method to get cuda version.
# Try using internal function and default to cuda12 if not found.
try:
cuda_version = jax._src.lib.cuda_versions.cuda_runtime_get_version()
cuda_major_version = cuda_version // 1000
except AttributeError:
cuda_version = os.getenv("CUDA_VERSION", "12")
cuda_major_version = int(cuda_version.split(".")[0])
assert cuda_major_version in (12, 13), f"Unsupported cuda version {cuda_version}."
return cuda_major_version
if __name__ == "__main__": if __name__ == "__main__":
"""Main entry point for JAX extension installation. """Main entry point for JAX extension installation.
...@@ -93,15 +113,23 @@ if __name__ == "__main__": ...@@ -93,15 +113,23 @@ if __name__ == "__main__":
) )
] ]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__ = te_version()
te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine_jax", name="transformer_engine_jax",
version=te_version(), version=__version__,
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
python_requires=f">={min_python_version_str()}", python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(), install_requires=install_requires,
tests_require=test_requirements(), tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
...@@ -145,15 +145,25 @@ if __name__ == "__main__": ...@@ -145,15 +145,25 @@ if __name__ == "__main__":
) )
] ]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__ = te_version()
cuda_major_version = parse(torch.version.cuda).major
assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}."
te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name=PACKAGE_NAME, name=PACKAGE_NAME,
version=te_version(), version=__version__,
description="Transformer acceleration library - Torch Lib", description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
python_requires=f">={min_python_version_str()}", python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(), install_requires=install_requires,
tests_require=test_requirements(), tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment