Unverified commit 8ef3a33d, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix runtime lib loading logic (#2297)



Fix the runtime library loading logic and add missing dependencies.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c09411d8
......@@ -14,7 +14,7 @@ from typing import List
def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript", "onnx"]
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic"]
def test_requirements() -> List[str]:
......
......@@ -241,13 +241,9 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
cuda_root = Path(nvidia.__file__).parent
return [
cuda_root / "cuda_nvcc" / "include",
cuda_root / "cublas" / "include",
cuda_root / "cuda_runtime" / "include",
cuda_root / "cudnn" / "include",
cuda_root / "cuda_cccl" / "include",
cuda_root / "nvtx" / "include",
cuda_root / "cuda_nvrtc" / "include",
subdir / "include"
for subdir in cuda_root.iterdir()
if subdir.is_dir() and (subdir / "include").is_dir()
]
......
......@@ -235,31 +235,6 @@ def _get_sys_extension() -> str:
raise RuntimeError(f"Unsupported operating system ({system})")
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
    """
    Try to load the shared objects of a pip-installed NVIDIA package.

    `lib_name`: Name of package as found in the `nvidia` dir in python environment.

    Returns a `(found, handles)` tuple: `found` is `True` when at least one
    versioned shared object matched, and `handles` holds one `ctypes.CDLL`
    handle (opened with `RTLD_GLOBAL`) per matched file. Results are cached
    per `lib_name`.
    """
    # Versioned shared objects live under <purelib>/nvidia/<lib_name>/lib.
    pattern = os.path.join(
        sysconfig.get_path("purelib"),
        f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
    )
    matches = glob.glob(pattern)
    handles = [ctypes.CDLL(match, mode=ctypes.RTLD_GLOBAL) for match in matches]
    return bool(matches), handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
......@@ -279,101 +254,102 @@ def _nvidia_cudart_include_dir() -> str:
@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
def _load_cuda_library_from_python(lib_name: str, strict: bool = False):
"""
Attempts to load shared object file installed via python packages.
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
`lib_name` : Name of package as found in the `nvidia` dir in python environment.
`strict` : If set to `True`, throw an error if lib is not found.
"""
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
ext = _get_sys_extension()
nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# Libraries shipped in NVIDIA's PyPI packages can live
# in 4 possible locations inside `nvidia`.
# Check them in order of priority.
path_found = False
if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
if not path_found:
so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
if strict and not path_found:
raise RuntimeError(f"{lib_name} shared object not found.")
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
def _load_cuda_library_from_system(lib_name: str):
"""
Attempts to load shared object file installed via system/cuda-toolkit.
`lib_name`: Name of library to load without extension or `lib` prefix.
"""
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Where to look for the shared lib in decreasing order of preference.
paths = (
os.environ.get(f"{lib_name.upper()}_HOME"),
os.environ.get(f"{lib_name.upper()}_PATH"),
os.environ.get("CUDA_HOME"),
os.environ.get("CUDA_PATH"),
"/usr/local/cuda",
)
for path in paths:
if path is None:
continue
libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
libs = [lib for lib in libs if "stub" not in lib]
libs.sort(reverse=True, key=os.path.basename)
if libs:
return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
# Search in LD_LIBRARY_PATH.
try:
_lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return True, _lib_handle
except OSError:
return False, None
@functools.lru_cache(maxsize=None)
def _load_curand():
"""Load cuRAND shared library."""
# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuRAND in Python dist-packages
found, handle = _load_nvidia_cuda_library("curand")
def _load_cuda_library(lib_name: str):
"""
Load given shared library.
Prioritize loading from system/toolkit
before checking python packages.
"""
# Attempt to locate library in system.
found, handle = _load_cuda_library_from_system(lib_name)
if found:
return handle
return True, handle
# Attempt to locate cuRAND via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcurand" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate library in Python dist-packages.
found, handle = _load_cuda_library_from_python(lib_name)
if found:
return False, handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
raise RuntimeError(f"{lib_name} shared object not found.")
@functools.lru_cache(maxsize=None)
......@@ -384,11 +360,23 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# `_load_cuda_library` is used for packages that must be loaded
# during runtime. Both system and pypi packages are searched
# and an error is thrown if not found.
_, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand")
# This additional step is necessary to be able to install TE wheels
# and import TE (without any guards) in an environment where the CUDA
# toolkit might be absent.
load_libs_for_no_ctk = not system_nvrtc and not system_curand
if load_libs_for_no_ctk:
_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True)
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True)
_CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True)
_TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment