Unverified Commit 2b0fb534 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[distributed][misc] be consistent with pytorch for libcudart.so (#6346)

[distributed][misc] keep consistent with how pytorch finds libcudart.so (#6346)
parent d6ab5289
...@@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions. ...@@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions.
""" """
import ctypes import ctypes
import glob
import os
import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
...@@ -33,6 +36,26 @@ class Function: ...@@ -33,6 +36,26 @@ class Function:
argtypes: List[Any] argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
lib_folder = "cuda_runtime"
lib_name = "libcudart.so.*[0-9]"
lib_path = None
for path in sys.path:
nvidia_path = os.path.join(path, "nvidia")
if not os.path.exists(nvidia_path):
continue
candidate_lib_paths = glob.glob(
os.path.join(nvidia_path, lib_folder, "lib", lib_name))
if candidate_lib_paths and not lib_path:
lib_path = candidate_lib_paths[0]
if lib_path:
break
if not lib_path:
raise ValueError(f"{lib_name} not found in the system path {sys.path}")
return lib_path
class CudaRTLibrary: class CudaRTLibrary:
exported_functions = [ exported_functions = [
# ​cudaError_t cudaSetDevice ( int device ) # ​cudaError_t cudaSetDevice ( int device )
...@@ -77,9 +100,7 @@ class CudaRTLibrary: ...@@ -77,9 +100,7 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
if so_file is None: if so_file is None:
assert torch.version.cuda is not None so_file = get_pytorch_default_cudart_library_path()
major_version = torch.version.cuda.split(".")[0]
so_file = f"libcudart.so.{major_version}"
if so_file not in CudaRTLibrary.path_to_library_cache: if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file) lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib CudaRTLibrary.path_to_library_cache[so_file] = lib
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment