Unverified Commit 643fb0a0 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Support `nvidia-cu*` wheels for core lib compilation; miscellaneous build improvements (#1717)



* Add support for nvidia cu* lib wheels
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Small cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm unused import
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm req
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Specify exact package versions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm debug ms
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuda_path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add frameworks and nvidia-libs to setup requirements. Add alternates to version finding
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Loose
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix jax wheel install in no toolkit env [wip]
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add missing headers via pip
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Load SOs, revert CMake
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm unused function
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Proper fix for get_te_path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX exec without cudatk
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint and typo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent edcfc284
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
"""Installation script.""" """Installation script."""
import ctypes
import os import os
import subprocess import subprocess
import sys import sys
...@@ -23,7 +22,7 @@ from .utils import ( ...@@ -23,7 +22,7 @@ from .utils import (
debug_build_enabled, debug_build_enabled,
found_ninja, found_ninja,
get_frameworks, get_frameworks,
cuda_path, nvcc_path,
get_max_jobs_for_parallel_build, get_max_jobs_for_parallel_build,
) )
...@@ -94,7 +93,9 @@ class CMakeExtension(setuptools.Extension): ...@@ -94,7 +93,9 @@ class CMakeExtension(setuptools.Extension):
print(f"Time for build_ext: {total_time:.2f} seconds") print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False): def get_build_ext(
extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
):
class _CMakeBuildExtension(extension_cls): class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules""" """Setuptools command with support for CMake extension modules"""
...@@ -132,7 +133,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel ...@@ -132,7 +133,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Ensure that binaries are not in global package space. # Ensure that binaries are not in global package space.
lib_dir = ( lib_dir = (
"wheel_lib" "wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
else "" else ""
) )
target_dir = install_dir / "transformer_engine" / lib_dir target_dir = install_dir / "transformer_engine" / lib_dir
...@@ -143,8 +144,8 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel ...@@ -143,8 +144,8 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
os.remove(ext) os.remove(ext)
def build_extensions(self): def build_extensions(self):
# BuildExtensions from PyTorch already handle CUDA files correctly # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed. # to handle CUDA files correctly.
if "pytorch" not in get_frameworks(): if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict. # extra_compile_args is a dict.
...@@ -156,6 +157,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel ...@@ -156,6 +157,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Define new _compile method that redirects to NVCC for .cu and .cuh files. # Define new _compile method that redirects to NVCC for .cu and .cuh files.
original_compile_fn = self.compiler._compile original_compile_fn = self.compiler._compile
if not framework_extension_only:
self.compiler.src_extensions += [".cu", ".cuh"] self.compiler.src_extensions += [".cu", ".cuh"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None: def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
...@@ -163,10 +165,13 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel ...@@ -163,10 +165,13 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
cflags = copy.deepcopy(extra_postargs) cflags = copy.deepcopy(extra_postargs)
original_compiler = self.compiler.compiler_so original_compiler = self.compiler.compiler_so
try: try:
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so original_compiler = self.compiler.compiler_so
if os.path.splitext(src)[1] in [".cu", ".cuh"]: if (
os.path.splitext(src)[1] in [".cu", ".cuh"]
and not framework_extension_only
):
nvcc_bin = nvcc_path()
self.compiler.set_executable("compiler_so", str(nvcc_bin)) self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict): if isinstance(cflags, dict):
cflags = cflags["nvcc"] cflags = cflags["nvcc"]
...@@ -178,7 +183,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel ...@@ -178,7 +183,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Forward unknown options # Forward unknown options
if not any("--forward-unknown-opts" in flag for flag in cflags): if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts") cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict): elif isinstance(cflags, dict):
cflags = cflags["cxx"] cflags = cflags["cxx"]
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
"""JAX related extensions.""" """JAX related extensions."""
import os import os
import shutil
from pathlib import Path from pathlib import Path
import setuptools import setuptools
from .utils import cuda_path, all_files_in_dir from .utils import get_cuda_include_dirs, all_files_in_dir
from typing import List from typing import List
...@@ -43,16 +44,16 @@ def setup_jax_extension( ...@@ -43,16 +44,16 @@ def setup_jax_extension(
sources = all_files_in_dir(extensions_dir, ".cpp") sources = all_files_in_dir(extensions_dir, ".cpp")
# Header files # Header files
cuda_home, _ = cuda_path() include_dirs = get_cuda_include_dirs()
xla_home = xla_path() include_dirs.extend(
include_dirs = [ [
cuda_home / "include",
common_header_files, common_header_files,
common_header_files / "common", common_header_files / "common",
common_header_files / "common" / "include", common_header_files / "common" / "include",
csrc_header_files, csrc_header_files,
xla_home, xla_path(),
] ]
)
# Compile flags # Compile flags
cxx_flags = ["-O3"] cxx_flags = ["-O3"]
......
...@@ -13,6 +13,7 @@ import shutil ...@@ -13,6 +13,7 @@ import shutil
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from importlib.metadata import version
from subprocess import CalledProcessError from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -162,8 +163,30 @@ def found_pybind11() -> bool: ...@@ -162,8 +163,30 @@ def found_pybind11() -> bool:
@functools.lru_cache(maxsize=None)
def cuda_toolkit_include_path() -> Optional[Path]:
    """Return the include directory of a locally installed CUDA toolkit.

    Searches, in order: the ``CUDA_HOME`` environment variable, the
    location of the ``nvcc`` binary on ``PATH``, and the conventional
    ``/usr/local/cuda`` install prefix.

    Returns ``None`` if no CUDA toolkit is found.
    """
    # Check CUDA_HOME first — explicit user configuration wins.
    cuda_home = os.getenv("CUDA_HOME")
    if cuda_home:
        return Path(cuda_home) / "include"
    # Infer the toolkit root from the nvcc binary on PATH.
    # NOTE: the previous code used nvcc_bin.rstrip("/bin/nvcc"), but
    # str.rstrip strips a *character set*, not a suffix, so any toolkit
    # root ending in one of "/binvc" (e.g. "/opt/nvhpc") was corrupted.
    # Walking up two path components ("<root>/bin/nvcc" -> "<root>") is
    # the correct transformation.
    nvcc_bin = shutil.which("nvcc")
    if nvcc_bin is not None:
        return Path(nvcc_bin).parent.parent / "include"
    # Last-ditch guess in the conventional install location.
    if Path("/usr/local/cuda").is_dir():
        return Path("/usr/local/cuda") / "include"
    return None
@functools.lru_cache(maxsize=None)
def nvcc_path() -> Tuple[str, str]:
"""Returns the NVCC binary path.
Throws FileNotFoundError if NVCC is not found.""" Throws FileNotFoundError if NVCC is not found."""
# Try finding NVCC # Try finding NVCC
...@@ -185,7 +208,34 @@ def cuda_path() -> Tuple[str, str]: ...@@ -185,7 +208,34 @@ def cuda_path() -> Tuple[str, str]:
if not nvcc_bin.is_file(): if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}") raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
return cuda_home, nvcc_bin return nvcc_bin
@functools.lru_cache(maxsize=None)
def get_cuda_include_dirs() -> List[Path]:
    """Return the list of CUDA header directories.

    Prefers a locally installed CUDA toolkit, whose top-level include
    directory bundles all necessary headers. Otherwise falls back to the
    per-component ``include`` directories shipped inside the
    ``nvidia-*-cu12`` pip wheels.

    Raises:
        RuntimeError: if neither a CUDA toolkit nor the ``nvidia`` pip
            wheels can be found.
    """
    # If cuda is installed via toolkit, all necessary headers
    # are bundled inside the top level cuda directory.
    toolkit_include = cuda_toolkit_include_path()
    if toolkit_include is not None:
        return [toolkit_include]

    # Use pip wheels to include all headers. The `nvidia` namespace
    # package is only present when the nvidia-*-cu12 wheels are installed.
    try:
        import nvidia
    except ModuleNotFoundError as e:
        raise RuntimeError("CUDA not found.") from e

    cuda_root = Path(nvidia.__file__).parent
    return [
        cuda_root / "cuda_nvcc" / "include",
        cuda_root / "cublas" / "include",
        cuda_root / "cuda_runtime" / "include",
        cuda_root / "cudnn" / "include",
        cuda_root / "cuda_cccl" / "include",
        cuda_root / "nvtx" / "include",
        cuda_root / "cuda_nvrtc" / "include",
    ]
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -199,9 +249,18 @@ def cuda_archs() -> str: ...@@ -199,9 +249,18 @@ def cuda_archs() -> str:
def cuda_version() -> Tuple[int, ...]: def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple.""" """CUDA Toolkit version as a (major, minor) tuple.
# Query NVCC for version info
_, nvcc_bin = cuda_path() Try to get cuda version by locating the nvcc executable and running nvcc --version. If
nvcc is not found, look for the cuda runtime package pip `nvidia-cuda-runtime-cu12`
and check pip version.
"""
try:
nvcc_bin = nvcc_path()
except FileNotFoundError as e:
pass
else:
output = subprocess.run( output = subprocess.run(
[nvcc_bin, "-V"], [nvcc_bin, "-V"],
capture_output=True, capture_output=True,
...@@ -212,6 +271,13 @@ def cuda_version() -> Tuple[int, ...]: ...@@ -212,6 +271,13 @@ def cuda_version() -> Tuple[int, ...]:
version = match.group(1).split(".") version = match.group(1).split(".")
return tuple(int(v) for v in version) return tuple(int(v) for v in version)
try:
version_str = version("nvidia-cuda-runtime-cu12")
version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit())
return version_tuple
except importlib.metadata.PackageNotFoundError:
raise RuntimeError("Could neither find NVCC executable nor CUDA runtime Python package.")
def get_frameworks() -> List[str]: def get_frameworks() -> List[str]:
"""DL frameworks to build support for""" """DL frameworks to build support for"""
...@@ -298,18 +364,3 @@ def install_and_import(package): ...@@ -298,18 +364,3 @@ def install_and_import(package):
main_package = package.split("[")[0] main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package]) subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package) globals()[main_package] = importlib.import_module(main_package)
def uninstall_te_wheel_packages():
    """Uninstall any previously pip-installed Transformer Engine wheel packages."""
    te_wheel_names = (
        "transformer_engine_cu12",
        "transformer_engine_torch",
        "transformer_engine_jax",
    )
    # Run pip via the current interpreter so the active environment is targeted.
    pip_cmd = [sys.executable, "-m", "pip", "uninstall", "-y", *te_wheel_names]
    subprocess.check_call(pip_cmd)
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"""Installation script.""" """Installation script."""
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
...@@ -23,7 +22,6 @@ from build_tools.utils import ( ...@@ -23,7 +22,6 @@ from build_tools.utils import (
get_frameworks, get_frameworks,
install_and_import, install_and_import,
remove_dups, remove_dups,
uninstall_te_wheel_packages,
) )
frameworks = get_frameworks() frameworks = get_frameworks()
...@@ -90,7 +88,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -90,7 +88,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
""" """
# Common requirements # Common requirements
setup_reqs: List[str] = [] setup_reqs: List[str] = [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
install_reqs: List[str] = [ install_reqs: List[str] = [
"pydantic", "pydantic",
"importlib-metadata>=1.0", "importlib-metadata>=1.0",
...@@ -109,6 +115,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -109,6 +115,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements # Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(["torch>=2.1"]) install_reqs.extend(["torch>=2.1"])
install_reqs.append( install_reqs.append(
"nvdlfw-inspect @" "nvdlfw-inspect @"
...@@ -118,6 +125,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -118,6 +125,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# install_reqs.append("triton") # install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
if "jax" in frameworks: if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"]) install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy"]) test_reqs.extend(["numpy"])
...@@ -154,9 +162,6 @@ if __name__ == "__main__": ...@@ -154,9 +162,6 @@ if __name__ == "__main__":
extras_require = {"test": test_requires} extras_require = {"test": test_requires}
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_wheel_packages()
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension
......
...@@ -11,10 +11,10 @@ import subprocess ...@@ -11,10 +11,10 @@ import subprocess
import ctypes import ctypes
import os import os
import platform import platform
import importlib
import functools
from pathlib import Path from pathlib import Path
import transformer_engine
def is_package_installed(package): def is_package_installed(package):
"""Checks if a pip package is installed.""" """Checks if a pip package is installed."""
...@@ -26,9 +26,9 @@ def is_package_installed(package): ...@@ -26,9 +26,9 @@ def is_package_installed(package):
) )
def get_te_path() -> Path:
    """Find Transformer Engine install path using pip"""
    # Import the submodule explicitly: a bare `import importlib` does not
    # guarantee that the `importlib.metadata` attribute is populated, so
    # relying on it is a latent AttributeError.
    from importlib.metadata import distribution

    # locate_file("") resolves to the site-packages root that owns the
    # "transformer_engine" distribution.
    return Path(distribution("transformer_engine").locate_file("").resolve())
def _get_sys_extension(): def _get_sys_extension():
...@@ -45,20 +45,45 @@ def _get_sys_extension(): ...@@ -45,20 +45,45 @@ def _get_sys_extension():
return extension return extension
def _load_nvidia_cuda_library(lib_name: str):
    """
    Attempts to load shared object file installed via pip.

    `lib_name`: Name of package as found in the `nvidia` dir in python environment.
    """
    # Versioned shared objects shipped by the nvidia-*-cu12 wheels live
    # under <purelib>/nvidia/<lib_name>/lib/.
    pattern = os.path.join(
        sysconfig.get_path("purelib"),
        f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
    )
    so_paths = glob.glob(pattern)
    # Load every matching library globally so later dlopen calls resolve
    # their symbols.
    handles = [ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) for so_path in so_paths]
    return bool(so_paths), handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
import nvidia
except ModuleNotFoundError:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
...@@ -75,6 +100,11 @@ def _load_cudnn(): ...@@ -75,6 +100,11 @@ def _load_cudnn():
if libs: if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
...@@ -107,6 +137,11 @@ def _load_nvrtc(): ...@@ -107,6 +137,11 @@ def _load_nvrtc():
if libs: if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
# Attempt to locate NVRTC via ldconfig # Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n") libs = libs.decode("utf-8").split("\n")
...@@ -126,4 +161,10 @@ def _load_nvrtc(): ...@@ -126,4 +161,10 @@ def _load_nvrtc():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn() _CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc() _NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_library() _TE_LIB_CTYPES = _load_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
...@@ -84,6 +84,7 @@ if __name__ == "__main__": ...@@ -84,6 +84,7 @@ if __name__ == "__main__":
The script requires JAX to be installed for building. The script requires JAX to be installed for building.
It will raise a RuntimeError if JAX is not available. It will raise a RuntimeError if JAX is not available.
""" """
# Extensions # Extensions
common_headers_dir = "common_headers" common_headers_dir = "common_headers"
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
...@@ -100,6 +101,17 @@ if __name__ == "__main__": ...@@ -100,6 +101,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[
"jax[cuda12]",
"flax>=0.7.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["jax", "flax>=0.7.1"], install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy"], tests_require=["numpy"],
) )
......
...@@ -55,7 +55,17 @@ if __name__ == "__main__": ...@@ -55,7 +55,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib", description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch"], setup_requires=[
"torch>=2.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["torch>=2.1"],
tests_require=["numpy", "torchvision"], tests_require=["numpy", "torchvision"],
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment