Unverified commit 09813578, authored by Kirthi Shankar Sivamani, committed by GitHub

Build scripts for pip wheels (#1036)



* Specify python version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add classifiers for python
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add utils to build wheels
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make wheel scripts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add aarch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle wheel
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* PaddlePaddle only builds for x86
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add optional fwk deps
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python3.8; catch install error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] cudnn9 compile with paddle support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] don't link cudnn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* dlopen cudnn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* dynamically load nvrtc
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove residual packages; exclude stub from nvrtc .so search
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Exclude builtins from nvrtc .so search
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* properly include files for sdist
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* paddle wheel tie to python version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle build from src [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix workflow paddle build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix paddle
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint from pr986
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity wheel test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity import to wheel test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove upper limit on paddlepaddle version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove pybind11 dependency
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Search .sos in cuda home
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup, remove residual code
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6ae584dd
@@ -78,7 +78,10 @@ jobs:
       with:
         submodules: recursive
     - name: 'Build'
-      run: pip install . -v
+      run: |
+        apt-get update
+        apt-get install -y libgoogle-glog-dev
+        pip install . -v
       env:
         NVTE_FRAMEWORK: paddle
     - name: 'Sanity check'
@@ -135,8 +135,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
             search_paths = list(Path(__file__).resolve().parent.parent.iterdir())
             # Source compilation from top-level
             search_paths.extend(list(Path(self.build_lib).iterdir()))
+            # Dynamically load required libs.
+            from transformer_engine.common import _load_cudnn, _load_nvrtc
+
+            _load_cudnn()
+            _load_nvrtc()
         else:
-            # Only during release sdist build.
+            # Only during release bdist build for paddlepaddle.
             import transformer_engine

             search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
@@ -11,6 +11,7 @@ import re
 import shutil
 import subprocess
 import sys
+import importlib
 from pathlib import Path
 from subprocess import CalledProcessError
 from typing import List, Optional, Tuple
@@ -253,15 +254,6 @@ def get_frameworks() -> List[str]:
     return _frameworks


-def package_files(directory):
-    paths = []
-    for path, _, filenames in os.walk(directory):
-        path = Path(path)
-        for filename in filenames:
-            paths.append(str(path / filename).replace(f"{directory}/", ""))
-    return paths
-
-
 def copy_common_headers(te_src, dst):
     headers = te_src / "common"
     for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
@@ -272,11 +264,21 @@ def copy_common_headers(te_src, dst):
 def install_and_import(package):
     """Install a package via pip (if not already installed) and import into globals."""
-    import importlib
-    try:
-        importlib.import_module(package)
-    except ImportError:
-        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
-    finally:
-        globals()[package] = importlib.import_module(package)
+    main_package = package.split("[")[0]
+    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+    globals()[main_package] = importlib.import_module(main_package)
+
+
+def uninstall_te_fw_packages():
+    subprocess.check_call(
+        [
+            sys.executable,
+            "-m",
+            "pip",
+            "uninstall",
+            "-y",
+            "transformer_engine_torch",
+            "transformer_engine_paddle",
+            "transformer_engine_jax",
+        ]
+    )
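The revised `install_and_import` strips the extras suffix before importing, because `pip install "pybind11[global]"` installs a distribution whose importable module is plain `pybind11`. The name handling alone (no pip call) can be sketched as:

```python
def importable_name(requirement: str) -> str:
    """Strip a PEP 508 extras suffix: 'pybind11[global]' -> 'pybind11'."""
    return requirement.split("[")[0]

print(importable_name("pybind11[global]"))  # pybind11
print(importable_name("packaging"))         # packaging
```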
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_aarch64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="aarch64"
RUN dnf -y install vim
# CUDA toolkit, cuDNN, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"]
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_x86_64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="x86_64"
RUN dnf -y install vim
# CUDA toolkit, cuDNN, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"]
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_COMMON=${2:-true}
BUILD_JAX=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_PADDLE=${5:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-wheels}
mkdir /wheelhouse
mkdir /wheelhouse/logs
# Generate wheels for common library.
git config --global --add safe.directory /TransformerEngine
cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive
if $BUILD_COMMON ; then
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
mv dist/"$whl_name" /wheelhouse/"$whl_name_target"
fi
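The `IFS='-' read -ra whl_parts` step above rewrites the CPython-specific wheel tag to `py3-none`, so a single common wheel serves every interpreter. The same split-and-rejoin can be sketched in Python (the filename is a hypothetical example, not an actual artifact of this build):

```python
def retag_wheel(name: str) -> str:
    """Rewrite '<dist>-<ver>-<pytag>-<abitag>-<plat>.whl' to a py3-none wheel."""
    parts = name.split("-")
    # parts: [dist, version, python tag, abi tag, platform tag + '.whl']
    return f"{parts[0]}-{parts[1]}-py3-none-{parts[4]}"

print(retag_wheel("transformer_engine-1.8.0-cp38-cp38-manylinux_2_28_x86_64.whl"))
# transformer_engine-1.8.0-py3-none-manylinux_2_28_x86_64.whl
```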
if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch
/opt/python/cp38-cp38/bin/pip install torch
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/
fi
if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
/opt/python/cp38-cp38/bin/pip install jax jaxlib
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
if $BUILD_PADDLE ; then
if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then
dnf -y remove --allowerasing cudnn9-cuda-12
dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
cd /TransformerEngine/transformer_engine/paddle
/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl
/opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl
/opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl
/opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl
/opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl
/opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
mv dist/* /wheelhouse/
fi
fi
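The Paddle branch above unrolls the same three-step cycle (install wheels, install `paddlepaddle-gpu`, build, uninstall) for each CPython from 3.8 through 3.12, since the Paddle wheel is tied to the Python version. A hedged sketch of generating those command lines from a loop (interpreter paths as laid out in manylinux images):

```python
VERSIONS = ["cp38-cp38", "cp39-cp39", "cp310-cp310", "cp311-cp311", "cp312-cp312"]

def paddle_build_cmds(tag: str, platform: str = "manylinux_2_28_x86_64"):
    """Per-interpreter commands mirroring the unrolled script above."""
    py = f"/opt/python/{tag}/bin"
    log = f"/wheelhouse/logs/paddle_{tag.split('-')[0]}.txt"
    return [
        f"{py}/pip install /wheelhouse/*.whl",
        f"{py}/pip install paddlepaddle-gpu==2.6.1",
        f"{py}/python setup.py bdist_wheel --verbose --plat-name={platform} 2>&1 | tee {log}",
        f"{py}/pip uninstall -y transformer-engine paddlepaddle-gpu",
    ]

for tag in VERSIONS:
    for cmd in paddle_build_cmds(tag):
        print(cmd)
```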
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
rm -rf aarch_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
rm -rf x86_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
cd transformer_engine/jax
python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/jax/test_sanity_import.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip install dist/*
cd transformer_engine/paddle
python setup.py bdist_wheel
export NVTE_RELEASE_BUILD=0
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/paddle/test_sanity_import.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
cd transformer_engine/pytorch
python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/pytorch/test_sanity_import.py
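Each of the scripts above finishes by running a sanity-import test. The essential shape of such a test is just "import the package and fail loudly if it can't load"; a generic sketch using `importlib` (the repository's actual test files may differ):

```python
import importlib
import sys

def sanity_import(module_name: str) -> bool:
    """Return True if the module imports cleanly, False otherwise."""
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        print(f"Sanity import failed for {module_name}: {exc}", file=sys.stderr)
        return False
    return True

print(sanity_import("json"))             # True: stdlib module imports fine
print(sanity_import("no_such_mod_xyz"))  # False
```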
@@ -18,6 +18,7 @@ from build_tools.utils import (
     remove_dups,
     get_frameworks,
     install_and_import,
+    uninstall_te_fw_packages,
 )

 from build_tools.te_version import te_version
@@ -28,12 +29,14 @@ current_file_path = Path(__file__).parent.resolve()

 from setuptools.command.build_ext import build_ext as BuildExtension

+os.environ["NVTE_PROJECT_BUILDING"] = "1"
+
 if "pytorch" in frameworks:
     from torch.utils.cpp_extension import BuildExtension
 elif "paddle" in frameworks:
     from paddle.utils.cpp_extension import BuildExtension
 elif "jax" in frameworks:
-    install_and_import("pybind11")
+    install_and_import("pybind11[global]")
     from pybind11.setup_helpers import build_ext as BuildExtension
@@ -61,7 +64,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     setup_reqs: List[str] = []
     install_reqs: List[str] = [
         "pydantic",
-        "importlib-metadata>=1.0; python_version<'3.8'",
+        "importlib-metadata>=1.0",
         "packaging",
     ]
     test_reqs: List[str] = ["pytest>=8.2.1"]
@@ -85,6 +88,9 @@ if __name__ == "__main__":
     ext_modules = [setup_common_extension()]
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
+        # Remove residual FW packages since compiling from source
+        # results in a single binary with FW extensions included.
+        uninstall_te_fw_packages()
         if "pytorch" in frameworks:
             from build_tools.pytorch import setup_pytorch_extension
@@ -129,10 +135,21 @@ if __name__ == "__main__":
         ),
         extras_require={
             "test": test_requires,
+            "pytorch": [f"transformer_engine_torch=={__version__}"],
+            "jax": [f"transformer_engine_jax=={__version__}"],
+            "paddle": [f"transformer_engine_paddle=={__version__}"],
         },
         description="Transformer acceleration library",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
+        python_requires=">=3.8, <3.13",
+        classifiers=[
+            "Programming Language :: Python :: 3.8",
+            "Programming Language :: Python :: 3.9",
+            "Programming Language :: Python :: 3.10",
+            "Programming Language :: Python :: 3.11",
+            "Programming Language :: Python :: 3.12",
+        ],
         setup_requires=setup_requires,
         install_requires=install_requires,
         license_files=("LICENSE",),
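With the metapackage split above, `pip install transformer_engine[pytorch]` pulls in the framework wheel pinned to the same release. The `extras_require` mapping is built from the version string; a sketch of that construction (the version value is a hypothetical example):

```python
def build_extras(version: str, test_requires=("pytest>=8.2.1",)) -> dict:
    """Mirror of the extras_require mapping: each framework extra pins
    its companion wheel to the same Transformer Engine version."""
    return {
        "test": list(test_requires),
        "pytorch": [f"transformer_engine_torch=={version}"],
        "jax": [f"transformer_engine_jax=={version}"],
        "paddle": [f"transformer_engine_paddle=={version}"],
    }

print(build_extras("1.8.0")["pytorch"])  # ['transformer_engine_torch==1.8.0']
```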
@@ -34,6 +34,7 @@ include_directories(../../transformer_engine/common)
 include_directories(${CMAKE_SOURCE_DIR})

 find_package(CUDAToolkit REQUIRED)
+include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
 add_subdirectory(operator)
 add_subdirectory(util)
...@@ -18,7 +18,7 @@ add_executable(test_operator ...@@ -18,7 +18,7 @@ add_executable(test_operator
test_causal_softmax.cu test_causal_softmax.cu
../test_common.cu) ../test_common.cu)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2) target_compile_options(test_operator PRIVATE -O2)
......
...@@ -7,7 +7,8 @@ add_executable(test_util ...@@ -7,7 +7,8 @@ add_executable(test_util
test_string.cpp test_string.cpp
../test_common.cu) ../test_common.cu)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_compile_options(test_util PRIVATE -O2) target_compile_options(test_util PRIVATE -O2)
include(GoogleTest) include(GoogleTest)
......
...@@ -32,7 +32,6 @@ endif() ...@@ -32,7 +32,6 @@ endif()
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..) include_directories(${PROJECT_SOURCE_DIR}/..)
# Configure Transformer Engine library # Configure Transformer Engine library
...@@ -77,9 +76,7 @@ target_include_directories(transformer_engine PUBLIC ...@@ -77,9 +76,7 @@ target_include_directories(transformer_engine PUBLIC
target_link_libraries(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC
CUDA::cublas CUDA::cublas
CUDA::cuda_driver CUDA::cuda_driver
CUDA::cudart CUDA::cudart)
CUDA::nvrtc
CUDNN::cudnn)
target_include_directories(transformer_engine PRIVATE target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
...@@ -125,3 +122,4 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") ...@@ -125,3 +122,4 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
# Install library # Install library
install(TARGETS transformer_engine DESTINATION .) install(TARGETS transformer_engine DESTINATION .)
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
"""FW agnostic user-end APIs""" """FW agnostic user-end APIs"""
import glob
import sysconfig
import subprocess
import ctypes import ctypes
import os import os
import platform import platform
...@@ -31,6 +34,39 @@ def _get_sys_extension(): ...@@ -31,6 +34,39 @@ def _get_sys_extension():
return extension return extension
def _load_cudnn():
"""Load CUDNN shared library."""
lib_path = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
)
)
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
...@@ -42,5 +78,30 @@ def _load_library(): ...@@ -42,5 +78,30 @@ def _load_library():
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ: def _load_nvrtc():
"""Load NVRTC shared library."""
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "stub" in lib or "libnvrtc-builtins" in lib:
continue
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_TE_LIB_CTYPES = _load_library() _TE_LIB_CTYPES = _load_library()
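`_load_nvrtc` falls back to parsing `ldconfig -p`, whose entries look like `libnvrtc.so.12 (libc6,x86-64) => /usr/lib/.../libnvrtc.so.12`, while stubs and `libnvrtc-builtins` must be skipped. The filtering and path extraction can be exercised on canned output (the sample lines are illustrative, not from a real machine):

```python
SAMPLE = """\
\tlibnvrtc.so.12 (libc6,x86-64) => /usr/local/cuda/lib64/libnvrtc.so.12
\tlibnvrtc-builtins.so.12 (libc6,x86-64) => /usr/local/cuda/lib64/libnvrtc-builtins.so.12
\tlibnvrtc.so (libc6,x86-64) => /usr/local/cuda/stubs/libnvrtc.so
"""

def candidate_nvrtc_paths(ldconfig_output: str):
    """Extract loadable libnvrtc paths, skipping stubs and builtins."""
    sos = []
    for line in ldconfig_output.split("\n"):
        if "stub" in line or "libnvrtc-builtins" in line:
            continue
        if "libnvrtc" in line and "=>" in line:
            sos.append(line.split("=>")[1].strip())
    return sos

print(candidate_nvrtc_paths(SAMPLE))  # ['/usr/local/cuda/lib64/libnvrtc.so.12']
```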
recursive-include build_tools *.*
recursive-include common_headers *.*
recursive-include csrc *.*
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "jax/csrc/extensions.h" #include "extensions.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "jax/csrc/extensions.h" #include "extensions.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
namespace transformer_engine { namespace transformer_engine {
......