Commit e35337f0 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Now determining cuda version via libcudart.so call.

parent 8f84674d
...@@ -28,10 +28,40 @@ def check_cuda_result(cuda, result_val): ...@@ -28,10 +28,40 @@ def check_cuda_result(cuda, result_val):
if result_val != 0: if result_val != 0:
error_str = ctypes.c_char_p() error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
raise Exception(f"CUDA exception! ERROR: {error_str}") raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
# TODO: shouldn't we error or at least warn here?
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
return None
version = ctypes.c_int()
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
version = int(version.value)
major = version//1000
minor = (version-(major*1000))//10
return f'{major}{minor}'
def get_cuda_lib_handle():
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
print('ERROR: libcuda.so not found!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
def get_compute_capabilities(): return cuda
def get_compute_capabilities(cuda):
""" """
1. find libcuda.so library (GPU driver) (/usr/lib) 1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference init_device -> init variables -> call function by reference
...@@ -42,13 +72,6 @@ def get_compute_capabilities(): ...@@ -42,13 +72,6 @@ def get_compute_capabilities():
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
""" """
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
print('ERROR: libcuda.so not found!')
return None
nGpus = ctypes.c_int() nGpus = ctypes.c_int()
cc_major = ctypes.c_int() cc_major = ctypes.c_int()
...@@ -57,8 +80,6 @@ def get_compute_capabilities(): ...@@ -57,8 +80,6 @@ def get_compute_capabilities():
result = ctypes.c_int() result = ctypes.c_int()
device = ctypes.c_int() device = ctypes.c_int()
check_cuda_result(cuda, cuda.cuInit(0))
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = [] ccs = []
for i in range(nGpus.value): for i in range(nGpus.value):
...@@ -75,13 +96,13 @@ def get_compute_capabilities(): ...@@ -75,13 +96,13 @@ def get_compute_capabilities():
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(): def get_compute_capability(cuda):
""" """
Extracts the highest compute capbility from all available GPUs, as compute Extracts the highest compute capbility from all available GPUs, as compute
capabilities are downwards compatible. If no GPUs are detected, it returns capabilities are downwards compatible. If no GPUs are detected, it returns
None. None.
""" """
ccs = get_compute_capabilities() ccs = get_compute_capabilities(cuda)
if ccs is not None: if ccs is not None:
# TODO: handle different compute capabilities; for now, take the max # TODO: handle different compute capabilities; for now, take the max
return ccs[-1] return ccs[-1]
...@@ -89,10 +110,19 @@ def get_compute_capability(): ...@@ -89,10 +110,19 @@ def get_compute_capability():
def evaluate_cuda_setup(): def evaluate_cuda_setup():
cuda_path = determine_cuda_runtime_lib_path()
print(f"CUDA SETUP: CUDA path found: {cuda_path}")
cc = get_compute_capability()
binary_name = "libbitsandbytes_cpu.so" binary_name = "libbitsandbytes_cpu.so"
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
)
return binary_name
print(f"CUDA SETUP: CUDA path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_version_string = get_cuda_version(cuda, cudart_path)
if cc == '': if cc == '':
print( print(
...@@ -107,15 +137,8 @@ def evaluate_cuda_setup(): ...@@ -107,15 +137,8 @@ def evaluate_cuda_setup():
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed # (2) Multiple CUDA versions installed
# FIXME: cuda_home is still unused
cuda_home = str(Path(cuda_path).parent.parent)
# we use ls -l instead of nvcc to determine the cuda version # we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler # since most installations will have the libcudart.so installed, but not the compiler
ls_output, err = execute_and_return(f"ls -l {cuda_path}")
major, minor, revision = (
ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")
)
cuda_version_string = f"{major}{minor}"
print(f'CUDA_SETUP: Detected CUDA version {cuda_version_string}') print(f'CUDA_SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name(): def get_binary_name():
......
...@@ -123,4 +123,4 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: ...@@ -123,4 +123,4 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(cuda_runtime_libs) warn_in_case_of_duplicates(cuda_runtime_libs)
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else set() return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment