Commit bf855389 authored by Michael Carilli's avatar Michael Carilli
Browse files

Reverting setup.py to rely on cpp_extension's native CUDA_HOME, because...

Reverting setup.py to rely on cpp_extension's native CUDA_HOME, because upstream fixed cross-compilation
parent 70b33770
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
import subprocess import subprocess
from setuptools import setup, find_packages from setuptools import setup, find_packages
from distutils.command.clean import clean from distutils.command.clean import clean
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
# TODO: multiple modules, so we don't have to route all interfaces through # TODO: multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file? # the same interface.cpp file?
...@@ -36,48 +36,6 @@ def find(path, regex_func, collect=False): ...@@ -36,48 +36,6 @@ def find(path, regex_func, collect=False):
return os.path.join(root, file) return os.path.join(root, file)
return list(set(collection)) return list(set(collection))
# The means of finding CUDA_HOME in cpp_extension does not allow cross-compilation
# if torch.cuda.is_available() is False.
def find_cuda_home():
cuda_path = None
CUDA_HOME = None
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')
if not os.path.exists(CUDA_HOME):
# We use nvcc path on Linux and cudart path on macOS
cudart_path = ctypes.util.find_library('cudart')
if cudart_path is not None:
cuda_path = os.path.dirname(cudart_path)
if cuda_path is not None:
CUDA_HOME = os.path.dirname(cuda_path)
if not cuda_path and not CUDA_HOME:
nvcc_path = find('/usr/local/', re.compile("nvcc").search, False)
if nvcc_path:
CUDA_HOME = os.path.dirname(nvcc_path)
if CUDA_HOME:
os.path.dirname(CUDA_HOME)
if (not os.path.exists(CUDA_HOME+os.sep+"lib64")
or not os.path.exists(CUDA_HOME+os.sep+"include") ):
raise RuntimeError("Error: found NVCC at ",
nvcc_path,
" but could not locate CUDA libraries"+
" or include directories.")
raise RuntimeError("Error: Could not find cuda on this system. " +
"Please set your CUDA_HOME enviornment variable "
"to the CUDA base directory.")
print("Found CUDA_HOME = ", CUDA_HOME)
return CUDA_HOME
CUDA_HOME = find_cuda_home()
# Patch the extension's view of CUDA_HOME to allow cross-compilation
torch.utils.cpp_extension.CUDA_HOME = CUDA_HOME
def get_cuda_version(): def get_cuda_version():
NVCC = find(CUDA_HOME+os.sep+"bin", NVCC = find(CUDA_HOME+os.sep+"bin",
re.compile('nvcc$').search) re.compile('nvcc$').search)
...@@ -96,27 +54,32 @@ def get_cuda_version(): ...@@ -96,27 +54,32 @@ def get_cuda_version():
return CUDA_MAJOR return CUDA_MAJOR
CUDA_MAJOR = get_cuda_version() if CUDA_HOME is not None:
print("Found CUDA_HOME = ", CUDA_HOME)
gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60', CUDA_MAJOR = get_cuda_version()
'-gencode', 'arch=compute_61,code=sm_61',]
gencodes = ['-gencode', 'arch=compute_52,code=sm_52',
if CUDA_MAJOR > 8: '-gencode', 'arch=compute_60,code=sm_60',
gencodes += ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_61,code=sm_61',]
'-gencode', 'arch=compute_70,code=compute_70',]
if CUDA_MAJOR > 8:
ext_modules = [] gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
extension = CUDAExtension( '-gencode', 'arch=compute_70,code=compute_70',]
'apex_C', [
'csrc/interface.cpp', ext_modules = []
'csrc/weight_norm_fwd_cuda.cu', extension = CUDAExtension(
'csrc/weight_norm_bwd_cuda.cu', 'apex_C', [
'csrc/scale_cuda.cu', 'csrc/interface.cpp',
], 'csrc/weight_norm_fwd_cuda.cu',
extra_compile_args={'cxx': ['-g'] + version_le_04, 'csrc/weight_norm_bwd_cuda.cu',
'nvcc': ['-O3'] + version_le_04 + gencodes}) 'csrc/scale_cuda.cu',
ext_modules.append(extension) ],
extra_compile_args={'cxx': ['-g'] + version_le_04,
'nvcc': ['-O3'] + version_le_04 + gencodes})
ext_modules.append(extension)
else:
raise RuntimeError("Could not find Cuda install directory")
setup( setup(
name='apex', name='apex',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment