Commit 5dfa4c37 authored by Michael Carilli's avatar Michael Carilli
Browse files

Editing setup.py to locate nvcc and detect cuda major version more strictly

parent b0d7d60d
......@@ -71,11 +71,18 @@ def findcuda():
raise RuntimeError("Error: Could not find cuda on this system."+
" Please set your CUDA_HOME enviornment variable to the CUDA base directory.")
NVCC = find(CUDA_HOME, re.compile('nvcc').search)
CUDA_LIB = find(CUDA_HOME, re.compile('libcudart.so.*.*.*').search)
NVCC = find(CUDA_HOME+os.sep+"bin",
re.compile('nvcc$').search)
print("Found NVCC = ", NVCC)
# Parse output of nvcc to get cuda major version
nvcc_output = subprocess.check_output([NVCC, '--version']).decode("utf-8")
CUDA_LIB = re.compile(', V[0-9]+\.[0-9]+\.[0-9]+').search(nvcc_output).group(0).split('V')[1]
print("Found CUDA_LIB = ", CUDA_LIB)
if CUDA_LIB:
try:
CUDA_VERSION = int(CUDA_LIB.split('.')[2])
CUDA_VERSION = int(CUDA_LIB.split('.')[0])
except (ValueError, TypeError):
CUDA_VERSION = 9
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment