Unverified Commit 8ef3a33d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix runtime lib loading logic (#2297)



Fixes to runtime loading logic and add missing deps
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c09411d8
...@@ -14,7 +14,7 @@ from typing import List ...@@ -14,7 +14,7 @@ from typing import List
def install_requirements() -> List[str]: def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions.""" """Install dependencies for TE/PyTorch extensions."""
return ["torch>=2.1", "einops", "onnxscript", "onnx"] return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic"]
def test_requirements() -> List[str]: def test_requirements() -> List[str]:
......
...@@ -241,13 +241,9 @@ def get_cuda_include_dirs() -> Tuple[str, str]: ...@@ -241,13 +241,9 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
cuda_root = Path(nvidia.__file__).parent cuda_root = Path(nvidia.__file__).parent
return [ return [
cuda_root / "cuda_nvcc" / "include", subdir / "include"
cuda_root / "cublas" / "include", for subdir in cuda_root.iterdir()
cuda_root / "cuda_runtime" / "include", if subdir.is_dir() and (subdir / "include").is_dir()
cuda_root / "cudnn" / "include",
cuda_root / "cuda_cccl" / "include",
cuda_root / "nvtx" / "include",
cuda_root / "cuda_nvrtc" / "include",
] ]
......
...@@ -235,31 +235,6 @@ def _get_sys_extension() -> str: ...@@ -235,31 +235,6 @@ def _get_sys_extension() -> str:
raise RuntimeError(f"Unsupported operating system ({system})") raise RuntimeError(f"Unsupported operating system ({system})")
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.
`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
)
)
path_found = len(so_paths) > 0
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str: def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment.""" """Returns the include directory for cuda_runtime.h if exists in python environment."""
...@@ -279,101 +254,102 @@ def _nvidia_cudart_include_dir() -> str: ...@@ -279,101 +254,102 @@ def _nvidia_cudart_include_dir() -> str:
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _load_cudnn(): def _load_cuda_library_from_python(lib_name: str, strict: bool = False):
"""Load CUDNN shared library.""" """
Attempts to load shared object file installed via python packages.
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set `lib_name` : Name of package as found in the `nvidia` dir in python environment.
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") `strict` : If set to `True`, throw an error if lib is not found.
if cudnn_home: """
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda ext = _get_sys_extension()
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in Python dist-packages # PyPI packages provided by nvidia libs exist
found, handle = _load_nvidia_cuda_library("cudnn") # in 4 possible locations inside `nvidia`.
if found: # Check by order of priority.
return handle path_found = False
if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
# Attempt to locate libcudnn via ldconfig if not path_found:
libs = subprocess.check_output(["ldconfig", "-p"]) so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
libs = libs.decode("utf-8").split("\n") path_found = len(so_paths) > 0
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise ctypes_handles = []
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
if strict and not path_found:
raise RuntimeError(f"{lib_name} shared object not found.")
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _load_nvrtc(): def _load_cuda_library_from_system(lib_name: str):
"""Load NVRTC shared library.""" """
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda Attempts to load shared object file installed via system/cuda-toolkit.
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True) `lib_name`: Name of library to load without extension or `lib` prefix.
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) """
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
# Attempt to locate NVRTC via ldconfig # Where to look for the shared lib in decreasing order of preference.
libs = subprocess.check_output(["ldconfig", "-p"]) paths = (
libs = libs.decode("utf-8").split("\n") os.environ.get(f"{lib_name.upper()}_HOME"),
sos = [] os.environ.get(f"{lib_name.upper()}_PATH"),
for lib in libs: os.environ.get("CUDA_HOME"),
if "libnvrtc" in lib and "=>" in lib: os.environ.get("CUDA_PATH"),
sos.append(lib.split(">")[1].strip()) "/usr/local/cuda",
if sos: )
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
for path in paths:
if path is None:
continue
libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
libs = [lib for lib in libs if "stub" not in lib]
libs.sort(reverse=True, key=os.path.basename)
if libs:
return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise # Search in LD_LIBRARY_PATH.
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) try:
_lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return True, _lib_handle
except OSError:
return False, None
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _load_curand(): def _load_cuda_library(lib_name: str):
"""Load cuRAND shared library.""" """
# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda Load given shared library.
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" Prioritize loading from system/toolkit
libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) before checking python packages.
libs = list(filter(lambda x: not ("stub" in x), libs)) """
libs.sort(reverse=True, key=os.path.basename)
if libs: # Attempt to locate library in system.
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) found, handle = _load_cuda_library_from_system(lib_name)
# Attempt to locate cuRAND in Python dist-packages
found, handle = _load_nvidia_cuda_library("curand")
if found: if found:
return handle return True, handle
# Attempt to locate cuRAND via ldconfig # Attempt to locate library in Python dist-packages.
libs = subprocess.check_output(["ldconfig", "-p"]) found, handle = _load_cuda_library_from_python(lib_name)
libs = libs.decode("utf-8").split("\n") if found:
sos = [] return False, handle
for lib in libs:
if "libcurand" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise raise RuntimeError(f"{lib_name} shared object not found.")
return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -384,11 +360,23 @@ def _load_core_library(): ...@@ -384,11 +360,23 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
sanity_checks_for_pypi_installation() sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc() # `_load_cuda_library` is used for packages that must be loaded
_CURAND_LIB_CTYPES = _load_curand() # during runtime. Both system and pypi packages are searched
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") # and an error is thrown if not found.
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand")
# This additional step is necessary to be able to install TE wheels
# and import TE (without any guards) in an environment where the cuda
# toolkit might be absent without being guarded
load_libs_for_no_ctk = not system_nvrtc and not system_curand
if load_libs_for_no_ctk:
_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True)
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True)
_CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True)
_TE_LIB_CTYPES = _load_core_library() _TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels. # Needed to find the correct headers for NVRTC kernels.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment