import torch.cuda
import ctypes.util
import os
import re
import subprocess
from setuptools import setup, find_packages
from distutils.command.clean import clean
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME

# TODO:  multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file?

# A visible GPU is not required to build the extension, so only warn here;
# this keeps cross-compilation (e.g. in a Docker build) working.
if not torch.cuda.is_available():
    print("Warning: Torch did not find available GPUs on this system.\n",
          "If your intention is to cross-compile, this is not an error.")

print("torch.__version__  = ", torch.__version__)
# Only the leading "major.minor" components of the version string are used.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("APEx requires Pytorch 0.4 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

# Extra preprocessor flag for both the C++ and nvcc compiles when building
# against PyTorch 0.4 (anything older was rejected above). Presumably the
# sources under csrc/ branch on VERSION_LE_04 -- confirm there.
version_le_04 = []
if TORCH_MAJOR == 0 and TORCH_MINOR == 4:
    version_le_04 = ['-DVERSION_LE_04']
def find(path, regex_func, collect=False):
    """Recursively search ``path`` for files whose basename satisfies ``regex_func``.

    Args:
        path: Directory to walk with ``os.walk``.
        regex_func: Predicate called with each file's basename (for example a
            compiled pattern's ``.search``); a truthy result counts as a match.
        collect: If False, return the first match. If True, return all matches.

    Returns:
        With ``collect=False``: the full path of the first matching file in
        ``os.walk`` order, or ``None`` if nothing matches.
        With ``collect=True``: a sorted list of the unique matching full paths,
        which is empty if nothing matches.
    """
    matches = set()
    for root, _dirs, files in os.walk(path):
        for name in files:
            if not regex_func(name):
                continue
            full_path = os.path.join(root, name)
            if not collect:
                return full_path
            matches.add(full_path)
    # Previously this fell through to list(set(None)) and raised TypeError when
    # collect=False found nothing; callers test the result for truthiness instead.
    return sorted(matches) if collect else None
# Due to https://github.com/pytorch/pytorch/issues/8223, for Pytorch <= 0.4
# torch.utils.cpp_extension's check for CUDA_HOME fails if there are no GPUs
# available on the system, which prevents cross-compiling and building via Dockerfiles.
# Workaround:  manually search for CUDA_HOME if Pytorch <= 0.4.
def find_cuda_home():
    """Locate the CUDA toolkit root directory without needing a visible GPU.

    Candidates are tried in order:
      1. ``$CUDA_HOME`` (default ``/usr/local/cuda``), if that path exists.
      2. Two directory levels above the ``cudart`` library reported by
         ``ctypes.util.find_library``. This works only when that call returns
         a full path, as it does on macOS. On Linux it typically returns just
         the soname, so this step is skipped there.
      3. Two directory levels above an ``nvcc`` binary found under
         ``/usr/local/``. That root must contain ``lib64`` and ``include``.

    Returns:
        The CUDA root directory.

    Raises:
        RuntimeError: if nvcc is found but its root lacks ``lib64`` or
            ``include``, or if no CUDA location can be found at all.
    """
    cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda')
    if os.path.exists(cuda_home):
        return cuda_home

    cudart_path = ctypes.util.find_library('cudart')
    if cudart_path:
        cudart_dir = os.path.dirname(cudart_path)
        # An empty dirname means find_library returned a bare soname.
        if cudart_dir:
            # <root>/lib/libcudart.* -> <root>
            return os.path.dirname(cudart_dir)

    # collect=True always yields a list (possibly empty), so "not found" is
    # handled here explicitly instead of relying on find's single-match mode.
    nvcc_candidates = find('/usr/local/', re.compile(r'nvcc$|nvcc\.exe$').search, True)
    if nvcc_candidates:
        nvcc_path = sorted(nvcc_candidates)[0]
        # <root>/bin/nvcc -> <root>
        cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        if (not os.path.exists(os.path.join(cuda_home, "lib64"))
                or not os.path.exists(os.path.join(cuda_home, "include"))):
            raise RuntimeError("Error: found NVCC at {} but could not locate CUDA "
                               "libraries or include directories.".format(nvcc_path))
        return cuda_home

    raise RuntimeError("Error: Could not find cuda on this system. " +
                       "Please set your CUDA_HOME environment variable "
                       "to the CUDA base directory.")

# With PyTorch 0.4, torch.utils.cpp_extension can leave CUDA_HOME as None on
# machines without GPUs. In that case, run our own search and write the result
# back so cpp_extension's module-level CUDA_HOME agrees with ours.
if (TORCH_MAJOR, TORCH_MINOR) == (0, 4) and CUDA_HOME is None:
    CUDA_HOME = find_cuda_home()
    torch.utils.cpp_extension.CUDA_HOME = CUDA_HOME
def get_cuda_version():
    """Return the CUDA major version reported by ``nvcc --version``.

    Looks for the nvcc executable under ``<CUDA_HOME>/bin`` (the module-level
    ``CUDA_HOME``), runs it, and parses the ``, V<major>.<minor>.<patch>``
    token from its output.

    Returns:
        The CUDA major version as an int.

    Raises:
        RuntimeError: if nvcc cannot be found, if its output contains no
            version token, or if the CUDA version is older than 8.
    """
    bin_dir = os.path.join(CUDA_HOME, "bin")
    nvcc_candidates = find(bin_dir, re.compile(r'nvcc$|nvcc\.exe$').search, True)
    if not nvcc_candidates:
        raise RuntimeError("Could not find nvcc under " + bin_dir)
    nvcc = sorted(nvcc_candidates)[0]
    print("Found NVCC = ", nvcc)

    # Parse output of nvcc to get the full "major.minor.patch" CUDA version.
    nvcc_output = subprocess.check_output([nvcc, '--version']).decode("utf-8")
    match = re.search(r', V([0-9]+\.[0-9]+\.[0-9]+)', nvcc_output)
    if match is None:
        raise RuntimeError("Could not parse the CUDA version from nvcc output:\n" + nvcc_output)
    cuda_lib = match.group(1)
    print("Found CUDA_LIB = ", cuda_lib)

    cuda_major = int(cuda_lib.split('.')[0])
    print("Found CUDA_MAJOR = ", cuda_major)

    if cuda_major < 8:
        raise RuntimeError("APex requires CUDA 8.0 or newer")

    return cuda_major
if CUDA_HOME is None:
    raise RuntimeError("Could not find Cuda install directory")

print("Found CUDA_HOME = ", CUDA_HOME)

CUDA_MAJOR = get_cuda_version()

# nvcc -gencode pairs: each compiles device code for the given virtual
# architecture (compute_XX) into a binary for the real GPU architecture (sm_XX).
gencodes = ['-gencode', 'arch=compute_50,code=sm_50',
            '-gencode', 'arch=compute_52,code=sm_52',
            '-gencode', 'arch=compute_60,code=sm_60',
            '-gencode', 'arch=compute_61,code=sm_61',]

# sm_70 targets are only added for CUDA 9+. The code=compute_70 entry also
# embeds PTX, which newer GPUs can JIT-compile.
if CUDA_MAJOR > 8:
    gencodes += ['-gencode', 'arch=compute_70,code=sm_70',
                 '-gencode', 'arch=compute_70,code=compute_70',]

ext_modules = []
extension = CUDAExtension(
    'apex_C', [
        'csrc/interface.cpp',
        'csrc/weight_norm_fwd_cuda.cu',
        'csrc/weight_norm_bwd_cuda.cu',
        'csrc/scale_cuda.cu',
    ],
    extra_compile_args={'cxx': ['-g'] + version_le_04,
                        'nvcc': ['-O3'] + version_le_04 + gencodes})
ext_modules.append(extension)

# Package metadata. The compiled apex_C extension is built by PyTorch's
# BuildExtension, which drives nvcc for the .cu sources.
setup(
    name='apex',
    version='0.1',
    # Non-package directories that find_packages must not pick up.
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    ext_modules=ext_modules,
    description='PyTorch Extensions written by NVIDIA',
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)