Commit 5709cfb5 authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Try to improve robustness of finding cuda in build. Try to support building with CUDA 8.

parent 1cea1005
......@@ -17,10 +17,16 @@ import ctypes.util
import torch
#Takes a path to walk
#A function to decide whether to keep a file
#collection if we want a list of all occurrences
def find(path, regex_func, collect=False):
"""
Recursively searches through a directory with regex_func and
either collects all instances or returns the first instance.
Args:
path: Directory to search through
regex_func: A function to run on each file to decide if it should be returned/collected
collect (False): If True, collects all matching instances; otherwise returns only the first match
"""
collection = [] if collect else None
for root, dirs, files in os.walk(path):
for file in files:
......@@ -31,25 +37,55 @@ def find(path, regex_func, collect=False):
return os.path.join(root, file)
return list(set(collection))
def findcuda():
"""
Based on PyTorch build process. Will look for nvcc for compilation.
Either will set cuda home by environment variable CUDA_HOME or will search
for nvcc. Returns NVCC executable, cuda major version and cuda home directory.
"""
cuda_path = None
CUDA_HOME = None
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
# We use nvcc path on Linux and cudart path on macOS
osname = platform.system()
if osname == 'Linux':
cuda_path = find_nvcc()
else:
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
cuda_path = os.path.dirname(cudart_path)
else:
cuda_path = None
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
cuda_path = os.path.dirname(cudart_path)
if cuda_path is not None:
CUDA_HOME = os.path.dirname(cuda_path)
else:
CUDA_HOME = None
WITH_CUDA = CUDA_HOME is not None
return CUDA_HOME
if not cuda_path and not CUDA_HOME:
nvcc_path = find('/usr/local/', re.compile("nvcc").search, False)
if nvcc_path:
CUDA_HOME = os.path.dirname(nvcc_path)
if CUDA_HOME:
os.path.dirname(CUDA_HOME)
if (not os.path.exists(CUDA_HOME+os.sep+"lib64")
or not os.path.exists(CUDA_HOME+os.sep+"include") ):
raise RuntimeError("Error: found NVCC at ", nvcc_path ," but could not locate CUDA libraries"+
" or include directories.")
raise RuntimeError("Error: Could not find cuda on this system."+
" Please set your CUDA_HOME enviornment variable to the CUDA base directory.")
NVCC = find(CUDA_HOME, re.compile('nvcc').search)
CUDA_LIB = find(CUDA_HOME, re.compile('libcudart.so.*.*.*').search)
if CUDA_LIB:
try:
CUDA_VERSION = int(CUDA_LIB.split('.')[2])
except (ValueError, TypeError):
CUDA_VERSION = 9
else:
CUDA_VERSION = 9
if CUDA_VERSION < 8:
raise RuntimeError("Error: APEx requires CUDA 8 or newer")
return NVCC, CUDA_VERSION, CUDA_HOME
#Get some important paths
curdir = os.path.dirname(os.path.abspath(inspect.stack()[0][1]))
......@@ -87,7 +123,7 @@ extra_compile_args = ["--std=c++11",]
#findcuda returns root dir of CUDA
#include cuda/include and cuda/lib64 for python module build.
CUDA_HOME=findcuda()
NVCC, CUDA_VERSION, CUDA_HOME=findcuda()
library_dirs.append(os.path.join(CUDA_HOME, "lib64"))
include_dirs.append(os.path.join(CUDA_HOME, 'include'))
......@@ -107,22 +143,25 @@ class RMBuild(clean):
shutil.rmtree(eggdir)
clean.run(self)
def CompileCudaFiles():
def CompileCudaFiles(NVCC, CUDA_VERSION):
print()
print("Compiling cuda modules with nvcc:")
#Need arches to compile for. Compiles for 70 which requires CUDA9
nvcc_cmd = ['nvcc',
'-Xcompiler',
'-fPIC',
'-gencode', 'arch=compute_52,code=sm_52',
gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70',
'--std=c++11',
'-O3',
]
'-gencode', 'arch=compute_61,code=sm_61',]
if CUDA_VERSION > 8:
gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70',]
#Need arches to compile for. Compiles for 70 which requires CUDA9
nvcc_cmd = [NVCC,
'-Xcompiler',
'-fPIC'
] + gencodes + [
'--std=c++11',
'-O3',
]
for dir in include_dirs:
nvcc_cmd.append("-I"+dir)
......@@ -152,7 +191,7 @@ if 'clean' not in sys.argv:
print("library_dirs: ", library_dirs)
print("libraries: ", main_libraries)
print()
CompileCudaFiles()
CompileCudaFiles(NVCC, CUDA_VERSION)
print("Building CUDA extension.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment