Unverified commit 44fd316f, authored by Alp Dener and committed by GitHub
Browse files

[Common] Default CUDA_HOME to /usr/local/cuda when dynamically loading cuDNN and NVRTC (#1183)



Defaulted CUDA_HOME/CUDA_PATH to /usr/local/cuda when attempting to dynamically load cuDNN and NVRTC
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 9101a78f
...@@ -47,20 +47,20 @@ def _get_sys_extension(): ...@@ -47,20 +47,20 @@ def _get_sys_extension():
def _load_cudnn(): def _load_cudnn():
"""Load CUDNN shared library.""" """Load CUDNN shared library."""
# Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob( lib_path = glob.glob(
os.path.join( os.path.join(
sysconfig.get_path("purelib"), sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]", f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
) )
) )
if lib_path: if lib_path:
assert ( assert (
len(lib_path) == 1 len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX." ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home: if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
...@@ -68,13 +68,14 @@ def _load_cudnn(): ...@@ -68,13 +68,14 @@ def _load_cudnn():
if libs: if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
if cuda_home: cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename) libs.sort(reverse=True, key=os.path.basename)
if libs: if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
...@@ -91,14 +92,15 @@ def _load_library(): ...@@ -91,14 +92,15 @@ def _load_library():
def _load_nvrtc(): def _load_nvrtc():
"""Load NVRTC shared library.""" """Load NVRTC shared library."""
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
if cuda_home: cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename) libs.sort(reverse=True, key=os.path.basename)
if libs: if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n") libs = libs.decode("utf-8").split("\n")
sos = [] sos = []
...@@ -109,6 +111,8 @@ def _load_nvrtc(): ...@@ -109,6 +111,8 @@ def _load_nvrtc():
sos.append(lib.split(">")[1].strip()) sos.append(lib.split(">")[1].strip())
if sos: if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment