Commit 70b33770 authored by Michael Carilli's avatar Michael Carilli
Browse files

Updating setup.py to enable cross-compilation

parent af431d28
......@@ -4,8 +4,7 @@ import re
import subprocess
from setuptools import setup, find_packages
from distutils.command.clean import clean
from torch.utils.cpp_extension import CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CUDAExtension
# TODO: multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file?
......@@ -37,6 +36,48 @@ def find(path, regex_func, collect=False):
return os.path.join(root, file)
return list(set(collection))
# The means of finding CUDA_HOME in cpp_extension does not allow cross-compilation
# if torch.cuda.is_available() is False.
def find_cuda_home():
cuda_path = None
CUDA_HOME = None
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
# We use nvcc path on Linux and cudart path on macOS
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
cuda_path = os.path.dirname(cudart_path)
if cuda_path is not None:
CUDA_HOME = os.path.dirname(cuda_path)
if not cuda_path and not CUDA_HOME:
nvcc_path = find('/usr/local/', re.compile("nvcc").search, False)
if nvcc_path:
CUDA_HOME = os.path.dirname(nvcc_path)
if CUDA_HOME:
os.path.dirname(CUDA_HOME)
if (not os.path.exists(CUDA_HOME+os.sep+"lib64")
or not os.path.exists(CUDA_HOME+os.sep+"include") ):
raise RuntimeError("Error: found NVCC at ",
nvcc_path,
" but could not locate CUDA libraries"+
" or include directories.")
raise RuntimeError("Error: Could not find cuda on this system. " +
"Please set your CUDA_HOME enviornment variable "
"to the CUDA base directory.")
print("Found CUDA_HOME = ", CUDA_HOME)
return CUDA_HOME
CUDA_HOME = find_cuda_home()
# Patch the extension's view of CUDA_HOME to allow cross-compilation
torch.utils.cpp_extension.CUDA_HOME = CUDA_HOME
def get_cuda_version():
NVCC = find(CUDA_HOME+os.sep+"bin",
re.compile('nvcc$').search)
......@@ -55,32 +96,27 @@ def get_cuda_version():
return CUDA_MAJOR
if CUDA_HOME is not None:
print("Found CUDA_HOME = ", CUDA_HOME)
CUDA_MAJOR = get_cuda_version()
gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',]
if CUDA_MAJOR > 8:
gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70',]
ext_modules = []
extension = CUDAExtension(
'apex_C', [
'csrc/interface.cpp',
'csrc/weight_norm_fwd_cuda.cu',
'csrc/weight_norm_bwd_cuda.cu',
'csrc/scale_cuda.cu',
],
extra_compile_args={'cxx': ['-g'] + version_le_04,
'nvcc': ['-O3'] + version_le_04 + gencodes})
ext_modules.append(extension)
else:
raise RuntimeError("Could not find Cuda install directory")
CUDA_MAJOR = get_cuda_version()
gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',]
if CUDA_MAJOR > 8:
gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70',]
ext_modules = []
extension = CUDAExtension(
'apex_C', [
'csrc/interface.cpp',
'csrc/weight_norm_fwd_cuda.cu',
'csrc/weight_norm_bwd_cuda.cu',
'csrc/scale_cuda.cu',
],
extra_compile_args={'cxx': ['-g'] + version_le_04,
'nvcc': ['-O3'] + version_le_04 + gencodes})
ext_modules.append(extension)
setup(
name='apex',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment