Unverified Commit 643fb0a0 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Support `nvidia-cu*` wheels for core lib compilation; miscellaneous build improvements (#1717)



* Add support for nvidia cu* lib wheels
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Small cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm unused import
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm req
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Specify exact package versions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm debug ms
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuda_path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add frameworks and nvidia-libs to setup requirements. Add alternates to version finding
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Loose
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix jax wheel install in no toolkit env [wip]
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add missing headers via pip
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Load SOs, revert CMake
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm unused function
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Proper fix for get_te_path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX exec without cudatk
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint and typo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent edcfc284
......@@ -4,7 +4,6 @@
"""Installation script."""
import ctypes
import os
import subprocess
import sys
......@@ -23,7 +22,7 @@ from .utils import (
debug_build_enabled,
found_ninja,
get_frameworks,
cuda_path,
nvcc_path,
get_max_jobs_for_parallel_build,
)
......@@ -94,7 +93,9 @@ class CMakeExtension(setuptools.Extension):
print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
def get_build_ext(
extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
):
class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules"""
......@@ -132,7 +133,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Ensure that binaries are not in global package space.
lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
......@@ -143,8 +144,8 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
os.remove(ext)
def build_extensions(self):
# BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers
# to handle CUDA files correctly.
if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
......@@ -156,6 +157,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
original_compile_fn = self.compiler._compile
if not framework_extension_only:
self.compiler.src_extensions += [".cu", ".cuh"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
......@@ -163,10 +165,13 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
cflags = copy.deepcopy(extra_postargs)
original_compiler = self.compiler.compiler_so
try:
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so
if os.path.splitext(src)[1] in [".cu", ".cuh"]:
if (
os.path.splitext(src)[1] in [".cu", ".cuh"]
and not framework_extension_only
):
nvcc_bin = nvcc_path()
self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
cflags = cflags["nvcc"]
......@@ -178,7 +183,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Forward unknown options
if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict):
cflags = cflags["cxx"]
......
......@@ -4,11 +4,12 @@
"""JAX related extensions."""
import os
import shutil
from pathlib import Path
import setuptools
from .utils import cuda_path, all_files_in_dir
from .utils import get_cuda_include_dirs, all_files_in_dir
from typing import List
......@@ -43,16 +44,16 @@ def setup_jax_extension(
sources = all_files_in_dir(extensions_dir, ".cpp")
# Header files
cuda_home, _ = cuda_path()
xla_home = xla_path()
include_dirs = [
cuda_home / "include",
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
xla_home,
xla_path(),
]
)
# Compile flags
cxx_flags = ["-O3"]
......
......@@ -13,6 +13,7 @@ import shutil
import subprocess
import sys
from pathlib import Path
from importlib.metadata import version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
......@@ -162,8 +163,30 @@ def found_pybind11() -> bool:
@functools.lru_cache(maxsize=None)
def cuda_path() -> Tuple[str, str]:
"""CUDA root path and NVCC binary path as a tuple.
def cuda_toolkit_include_path() -> Optional[Path]:
    """Return the include directory of a locally installed CUDA toolkit.

    Search order: the ``CUDA_HOME`` environment variable, the toolkit root
    derived from the ``nvcc`` binary on ``PATH``, then the conventional
    ``/usr/local/cuda`` location.

    Returns ``None`` if no CUDA toolkit is found.
    """
    # Check in CUDA_HOME first.
    cuda_home_env = os.getenv("CUDA_HOME")
    if cuda_home_env:
        return Path(cuda_home_env) / "include"
    # Derive the toolkit root from the nvcc binary (<root>/bin/nvcc).
    # NOTE: the previous str.rstrip("/bin/nvcc") stripped a *character set*,
    # not the "/bin/nvcc" suffix, and could corrupt the path.
    nvcc_bin = shutil.which("nvcc")
    if nvcc_bin is not None:
        return Path(nvcc_bin).parent.parent / "include"
    # Last-ditch guess in /usr/local/cuda.
    if Path("/usr/local/cuda").is_dir():
        return Path("/usr/local/cuda") / "include"
    return None
@functools.lru_cache(maxsize=None)
def nvcc_path() -> Tuple[str, str]:
"""Returns the NVCC binary path.
Throws FileNotFoundError if NVCC is not found."""
# Try finding NVCC
......@@ -185,7 +208,34 @@ def cuda_path() -> Tuple[str, str]:
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
return cuda_home, nvcc_bin
return nvcc_bin
@functools.lru_cache(maxsize=None)
def get_cuda_include_dirs() -> List[Path]:
    """Return the list of CUDA header directories.

    Prefers a locally installed CUDA toolkit, where all necessary headers
    are bundled inside a single top-level include directory.  Otherwise
    falls back to the ``nvidia-*-cu12`` pip wheels, which scatter headers
    across one include directory per component package.

    Raises:
        RuntimeError: if neither a CUDA toolkit nor the pip wheels are found.
    """
    # If cuda is installed via toolkit, all necessary headers
    # are bundled inside the top level cuda directory.
    toolkit_include = cuda_toolkit_include_path()
    if toolkit_include is not None:
        return [toolkit_include]

    # Use pip wheels to include all headers.
    try:
        import nvidia
    except ModuleNotFoundError as e:
        # Chain the original failure so the root cause is visible.
        raise RuntimeError("CUDA not found.") from e

    cuda_root = Path(nvidia.__file__).parent
    return [
        cuda_root / component / "include"
        for component in (
            "cuda_nvcc",
            "cublas",
            "cuda_runtime",
            "cudnn",
            "cuda_cccl",
            "nvtx",
            "cuda_nvrtc",
        )
    ]
@functools.lru_cache(maxsize=None)
......@@ -199,9 +249,18 @@ def cuda_archs() -> str:
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
_, nvcc_bin = cuda_path()
"""CUDA Toolkit version as a (major, minor) tuple.
Try to get cuda version by locating the nvcc executable and running nvcc --version. If
nvcc is not found, look for the cuda runtime package pip `nvidia-cuda-runtime-cu12`
and check pip version.
"""
try:
nvcc_bin = nvcc_path()
except FileNotFoundError as e:
pass
else:
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
......@@ -212,6 +271,13 @@ def cuda_version() -> Tuple[int, ...]:
version = match.group(1).split(".")
return tuple(int(v) for v in version)
try:
version_str = version("nvidia-cuda-runtime-cu12")
version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit())
return version_tuple
except importlib.metadata.PackageNotFoundError:
raise RuntimeError("Could neither find NVCC executable nor CUDA runtime Python package.")
def get_frameworks() -> List[str]:
"""DL frameworks to build support for"""
......@@ -298,18 +364,3 @@ def install_and_import(package):
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
def uninstall_te_wheel_packages():
    """Uninstall residual Transformer Engine wheel packages via pip.

    A source build produces a single binary with framework extensions
    included, so previously installed TE wheels must be removed first.
    """
    te_wheels = (
        "transformer_engine_cu12",
        "transformer_engine_torch",
        "transformer_engine_jax",
    )
    cmd = [sys.executable, "-m", "pip", "uninstall", "-y", *te_wheels]
    subprocess.check_call(cmd)
......@@ -5,7 +5,6 @@
"""Installation script."""
import os
import sys
import time
from pathlib import Path
from typing import List, Tuple
......@@ -23,7 +22,6 @@ from build_tools.utils import (
get_frameworks,
install_and_import,
remove_dups,
uninstall_te_wheel_packages,
)
frameworks = get_frameworks()
......@@ -90,7 +88,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"""
# Common requirements
setup_reqs: List[str] = []
setup_reqs: List[str] = [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0",
......@@ -109,6 +115,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(["torch>=2.1"])
install_reqs.append(
"nvdlfw-inspect @"
......@@ -118,6 +125,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy"])
......@@ -154,9 +162,6 @@ if __name__ == "__main__":
extras_require = {"test": test_requires}
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_wheel_packages()
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
......
......@@ -11,10 +11,10 @@ import subprocess
import ctypes
import os
import platform
import importlib
import functools
from pathlib import Path
import transformer_engine
def is_package_installed(package):
"""Checks if a pip package is installed."""
......@@ -26,9 +26,9 @@ def is_package_installed(package):
)
def get_te_path():
def get_te_path() -> Path:
"""Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
return Path(importlib.metadata.distribution("transformer_engine").locate_file("").resolve())
def _get_sys_extension():
......@@ -45,20 +45,45 @@ def _get_sys_extension():
return extension
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob(
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.
`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
)
)
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
path_found = len(so_paths) > 0
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
import nvidia
except ModuleNotFoundError:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
......@@ -75,6 +100,11 @@ def _load_cudnn():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
......@@ -107,6 +137,11 @@ def _load_nvrtc():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
......@@ -126,4 +161,10 @@ def _load_nvrtc():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
......@@ -84,6 +84,7 @@ if __name__ == "__main__":
The script requires JAX to be installed for building.
It will raise a RuntimeError if JAX is not available.
"""
# Extensions
common_headers_dir = "common_headers"
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
......@@ -100,6 +101,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[
"jax[cuda12]",
"flax>=0.7.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy"],
)
......
......@@ -55,7 +55,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch"],
setup_requires=[
"torch>=2.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["torch>=2.1"],
tests_require=["numpy", "torchvision"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment