Unverified Commit 09813578 authored by Kirthi Shankar Sivamani, committed by GitHub

Build scripts for pip wheels (#1036)

* Specify python version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add classifiers for python
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add utils to build wheels
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make wheel scripts
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add aarch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle wheel
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* PaddlePaddle only builds for x86
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add optional fwk deps
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python3.8; catch install error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] cudnn9 compile with paddle support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] don't link cudnn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* dlopen cudnn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* dynamically load nvrtc
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove residual packages; exclude stub from nvrtc .so search
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Exclude builtins from nvrtc .so search
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Properly include files for sdist
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Tie paddle wheel to python version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle build from src [wip]
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix workflow paddle build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Fix paddle
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint from pr986
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity wheel test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity import to wheel test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove upper limit on paddlepaddle version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove pybind11 dependency
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Search .sos in cuda home
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup, remove residual code
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6ae584dd
......@@ -78,7 +78,10 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install . -v
run: |
apt-get update
apt-get install -y libgoogle-glog-dev
pip install . -v
env:
NVTE_FRAMEWORK: paddle
- name: 'Sanity check'
......
......@@ -135,8 +135,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
search_paths = list(Path(__file__).resolve().parent.parent.iterdir())
# Source compilation from top-level
search_paths.extend(list(Path(self.build_lib).iterdir()))
# Dynamically load required_libs.
from transformer_engine.common import _load_cudnn, _load_nvrtc
_load_cudnn()
_load_nvrtc()
else:
# Only during release sdist build.
# Only during release bdist build for paddlepaddle.
import transformer_engine
search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
......
......@@ -11,6 +11,7 @@ import re
import shutil
import subprocess
import sys
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
......@@ -253,15 +254,6 @@ def get_frameworks() -> List[str]:
return _frameworks
def package_files(directory):
paths = []
for path, _, filenames in os.walk(directory):
path = Path(path)
for filename in filenames:
paths.append(str(path / filename).replace(f"{directory}/", ""))
return paths
def copy_common_headers(te_src, dst):
headers = te_src / "common"
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
......@@ -272,11 +264,21 @@ def copy_common_headers(te_src, dst):
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
import importlib
try:
importlib.import_module(package)
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
finally:
globals()[package] = importlib.import_module(package)
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
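The updated helper must handle pip "extras" specs such as `pybind11[global]`: the extras suffix is valid for `pip install` but not for `importlib`. A minimal sketch of the name handling, assuming the spec format used in the diff:

```python
import importlib


def importable_name(spec: str) -> str:
    """Map a pip requirement spec to its importable module name.

    "pybind11[global]" is a valid install target, but only the base
    name before the extras bracket can be passed to import_module.
    """
    return spec.split("[")[0]


# Demo with a stdlib module so no installation is needed:
mod = importlib.import_module(importable_name("json"))
```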
def uninstall_te_fw_packages():
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_torch",
"transformer_engine_paddle",
"transformer_engine_jax",
]
)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_aarch64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="aarch64"
RUN dnf -y install vim
# Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"]
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_x86_64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="x86_64"
RUN dnf -y install vim
# Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"]
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_COMMON=${2:-true}
BUILD_JAX=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_PADDLE=${5:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-wheels}
mkdir /wheelhouse
mkdir /wheelhouse/logs
# Generate wheels for common library.
git config --global --add safe.directory /TransformerEngine
cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive
if $BUILD_COMMON ; then
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
mv dist/"$whl_name" /wheelhouse/"$whl_name_target"
fi
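The block above rewrites the CPython-specific tags of the common-library wheel to `py3-none`, since that library carries no Python-ABI-specific code. The filename surgery done by `IFS='-' read` can be sketched in Python like this (the version number in the example is hypothetical, and the version string is assumed to contain no hyphens):

```python
def retag_wheel(whl_name: str) -> str:
    """Rewrite a wheel filename's python/abi tags to py3-none.

    Mirrors the IFS='-' read in build_wheels.sh; assumes the version
    component itself contains no hyphens.
    """
    name, version, _py_tag, _abi_tag, plat = whl_name.split("-")
    return f"{name}-{version}-py3-none-{plat}"


# Hypothetical example:
# retag_wheel("transformer_engine-1.8.0-cp38-cp38-manylinux_2_28_x86_64.whl")
#   -> "transformer_engine-1.8.0-py3-none-manylinux_2_28_x86_64.whl"
```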
if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch
/opt/python/cp38-cp38/bin/pip install torch
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/
fi
if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
/opt/python/cp38-cp38/bin/pip install jax jaxlib
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
if $BUILD_PADDLE ; then
if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then
dnf -y remove --allowerasing cudnn9-cuda-12
dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
cd /TransformerEngine/transformer_engine/paddle
/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl
/opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl
/opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl
/opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl
/opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl
/opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu
mv dist/* /wheelhouse/
fi
fi
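The PaddlePaddle branch repeats the same install/build/uninstall cycle once per manylinux interpreter because, unlike the sdist-only torch and jax packages, the paddle extension is built against a specific CPython ABI. The repetition can be read as a loop; this hypothetical helper (not part of the repo) returns the command lists instead of running them, just to show the shape:

```python
PY_TAGS = ["cp38-cp38", "cp39-cp39", "cp310-cp310", "cp311-cp311", "cp312-cp312"]


def paddle_build_commands(platform: str) -> list:
    """Restate the per-interpreter block of build_wheels.sh as data."""
    cmds = []
    for tag in PY_TAGS:
        py = f"/opt/python/{tag}/bin/python"
        pip = f"/opt/python/{tag}/bin/pip"
        cmds += [
            [pip, "install", "paddlepaddle-gpu==2.6.1"],
            [py, "setup.py", "bdist_wheel", "--verbose", f"--plat-name={platform}"],
            [pip, "uninstall", "-y", "transformer-engine", "paddlepaddle-gpu"],
        ]
    return cmds
```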
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
rm -rf aarch_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
rm -rf x86_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
cd transformer_engine/jax
python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/jax/test_sanity_import.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip install dist/*
cd transformer_engine/paddle
python setup.py bdist_wheel
export NVTE_RELEASE_BUILD=0
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/paddle/test_sanity_import.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
cd $TE_PATH
pip uninstall -y transformer-engine
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
cd transformer_engine/pytorch
python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/*
cd $TE_PATH
pip install dist/*
python $TE_PATH/tests/pytorch/test_sanity_import.py
......@@ -18,6 +18,7 @@ from build_tools.utils import (
remove_dups,
get_frameworks,
install_and_import,
uninstall_te_fw_packages,
)
from build_tools.te_version import te_version
......@@ -28,12 +29,14 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11")
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension
......@@ -61,7 +64,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
setup_reqs: List[str] = []
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0; python_version<'3.8'",
"importlib-metadata>=1.0",
"packaging",
]
test_reqs: List[str] = ["pytest>=8.2.1"]
......@@ -85,6 +88,9 @@ if __name__ == "__main__":
ext_modules = [setup_common_extension()]
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_fw_packages()
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
......@@ -129,10 +135,21 @@ if __name__ == "__main__":
),
extras_require={
"test": test_requires,
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
},
description="Transformer acceleration library",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
python_requires=">=3.8, <3.13",
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
setup_requires=setup_requires,
install_requires=install_requires,
license_files=("LICENSE",),
......
......@@ -34,6 +34,7 @@ include_directories(../../transformer_engine/common)
include_directories(${CMAKE_SOURCE_DIR})
find_package(CUDAToolkit REQUIRED)
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
add_subdirectory(operator)
add_subdirectory(util)
......@@ -18,7 +18,7 @@ add_executable(test_operator
test_causal_softmax.cu
../test_common.cu)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2)
......
......@@ -7,7 +7,8 @@ add_executable(test_util
test_string.cpp
../test_common.cu)
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_compile_options(test_util PRIVATE -O2)
include(GoogleTest)
......
......@@ -32,7 +32,6 @@ endif()
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..)
# Configure Transformer Engine library
......@@ -77,9 +76,7 @@ target_include_directories(transformer_engine PUBLIC
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cuda_driver
CUDA::cudart
CUDA::nvrtc
CUDNN::cudnn)
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......@@ -125,3 +122,4 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
# Install library
install(TARGETS transformer_engine DESTINATION .)
......@@ -4,6 +4,9 @@
"""FW agnostic user-end APIs"""
import glob
import sysconfig
import subprocess
import ctypes
import os
import platform
......@@ -31,6 +34,39 @@ def _get_sys_extension():
return extension
def _load_cudnn():
"""Load CUDNN shared library."""
lib_path = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
)
)
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
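`_load_cudnn` tries candidates in priority order: the pip-installed `nvidia-cudnn-cuXX` wheel under site-packages, then `CUDNN_HOME`/`CUDNN_PATH`, then `CUDA_HOME`/`CUDA_PATH`, and finally the system loader's default search. Within a directory it sorts matches descending by basename, which prefers a fully versioned soname over a bare symlink. A small sketch of that sort, with hypothetical paths:

```python
import os

# Candidate matches as the CUDNN_HOME branch might collect them;
# paths are hypothetical. Sorting descending by basename puts the
# fully-versioned soname first.
libs = [
    "/opt/cudnn/lib/libcudnn.so",
    "/opt/cudnn/lib/libcudnn.so.9",
    "/opt/cudnn/lib/libcudnn.so.9.1.0",
]
libs.sort(reverse=True, key=os.path.basename)
# libs[0] is now "/opt/cudnn/lib/libcudnn.so.9.1.0"
```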
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
......@@ -42,5 +78,30 @@ def _load_library():
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ:
def _load_nvrtc():
"""Load NVRTC shared library."""
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "stub" in lib or "libnvrtc-builtins" in lib:
continue
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
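When `CUDA_HOME` yields nothing, `_load_nvrtc` falls back to parsing `ldconfig -p`, whose cache lines have the form `libnvrtc.so.12 (libc6,x86-64) => /path/to/libnvrtc.so.12`; stub libraries and `libnvrtc-builtins` are skipped. A sketch of that line parsing, using hypothetical cache lines:

```python
# Hypothetical `ldconfig -p` output lines; the loader keeps real libnvrtc
# entries, skips stubs and libnvrtc-builtins, and takes the path after "=>".
lines = [
    "\tlibnvrtc-builtins.so.12.3 (libc6,x86-64) => /usr/local/cuda/lib64/libnvrtc-builtins.so.12.3",
    "\tlibnvrtc.so.12 (libc6,x86-64) => /usr/local/cuda/lib64/libnvrtc.so.12",
]
sos = []
for lib in lines:
    if "stub" in lib or "libnvrtc-builtins" in lib:
        continue
    if "libnvrtc" in lib and "=>" in lib:
        sos.append(lib.split(">")[1].strip())
# sos holds only the real libnvrtc path
```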
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_TE_LIB_CTYPES = _load_library()
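The new guard preloads cuDNN and NVRTC on a normal user import and during release (wheel) builds, but skips the preload while the project itself is being compiled from source. The condition can be sketched as a small predicate (hypothetical helper name, taking the environment as a dict for clarity):

```python
def should_preload(env: dict) -> bool:
    """Mirror the module-level gate: preload cuDNN/NVRTC unless the
    project is being built, except when that build is a release build."""
    return "NVTE_PROJECT_BUILDING" not in env or bool(
        int(env.get("NVTE_RELEASE_BUILD", "0"))
    )


# Normal import: preload.
# Source build (NVTE_PROJECT_BUILDING set): skip.
# Release wheel build: preload even while building.
```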
recursive-include build_tools *.*
recursive-include common_headers *.*
recursive-include csrc *.*
......@@ -6,7 +6,7 @@
#include "transformer_engine/activation.h"
#include "jax/csrc/extensions.h"
#include "extensions.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
......