# setup.py
import ctypes.util
import os
import re
import subprocess
from distutils.command.clean import clean

from setuptools import setup, find_packages

import torch.cuda
from torch.utils.cpp_extension import CUDAExtension

# TODO:  multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file?

# A missing GPU is tolerated here: the user may be building on a CPU-only
# machine to produce binaries for deployment elsewhere (cross-compilation).
cuda_available = torch.cuda.is_available()
if not cuda_available:
    print("Warning: Torch did not find available GPUs on this system.\n",
          "If your intention is to cross-compile, this is not an error.")
print("torch.__version__  = ", torch.__version__)

# Keep the first two dotted components of the version string as integers.
TORCH_MAJOR, TORCH_MINOR = (int(part) for part in torch.__version__.split('.')[:2])

if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("APEx requires Pytorch 0.4 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

# Torch 0.4 needs a compatibility define passed to the C++/CUDA sources.
version_le_04 = []
if TORCH_MAJOR == 0 and TORCH_MINOR == 4:
    version_le_04 = ['-DVERSION_LE_04']

def find(path, regex_func, collect=False):
    """Recursively search `path` for file names matched by `regex_func`.

    Args:
        path: Directory to walk.
        regex_func: Callable applied to each bare file name (e.g. a compiled
            regex's ``.search``); a truthy return value counts as a match.
        collect: If True, gather every match; if False, stop at the first.

    Returns:
        If ``collect`` is True, a deduplicated list of matching full paths
        (possibly empty).  If ``collect`` is False, the full path of the
        first match, or None when nothing matches.
    """
    collection = [] if collect else None
    for root, dirs, files in os.walk(path):
        for fname in files:
            if regex_func(fname):
                if collect:
                    collection.append(os.path.join(root, fname))
                else:
                    return os.path.join(root, fname)
    # Bug fix: the original unconditionally returned list(set(collection)),
    # which raised TypeError (set(None)) when collect=False and no file
    # matched.  Return None in that case instead.
    if collect:
        return list(set(collection))
    return None

# The means of finding CUDA_HOME in cpp_extension does not allow cross-compilation
# if torch.cuda.is_available() is False.
def find_cuda_home():
    """Locate the root directory of the CUDA toolkit.

    Search order:
      1. ``$CUDA_HOME`` (default ``/usr/local/cuda``), if the path exists.
      2. The cudart shared library (macOS-style discovery): the toolkit root
         is two directory levels above ``libcudart``.
      3. A recursive scan of ``/usr/local`` for an ``nvcc`` binary.

    Returns:
        The CUDA root directory as a string.

    Raises:
        RuntimeError: if no usable CUDA installation can be found, or if an
            nvcc binary is found but the matching lib64/include directories
            are missing.
    """
    cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda')
    if not os.path.exists(cuda_home):
        cuda_home = None
        # We use nvcc path on Linux and cudart path on macOS.
        cudart_path = ctypes.util.find_library('cudart')
        if cudart_path is not None:
            # libcudart lives in <root>/lib*, so strip two path components.
            cuda_home = os.path.dirname(os.path.dirname(cudart_path))

    # Bug fix: the original guard was `if not cuda_path and not CUDA_HOME:`,
    # which was always False because CUDA_HOME held the truthy getenv default
    # even when that path did not exist — the nvcc fallback was dead code and
    # a nonexistent default path could be returned.
    if cuda_home is None or not os.path.exists(cuda_home):
        nvcc_path = find('/usr/local/', re.compile(r'nvcc$').search, False)
        if not nvcc_path:
            raise RuntimeError("Error: Could not find cuda on this system. " +
                               "Please set your CUDA_HOME enviornment variable "
                               "to the CUDA base directory.")
        # nvcc sits at <root>/bin/nvcc; the original computed only one
        # dirname and discarded the second, leaving the path at <root>/bin.
        cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        if (not os.path.exists(cuda_home + os.sep + "lib64")
                or not os.path.exists(cuda_home + os.sep + "include")):
            raise RuntimeError("Error: found NVCC at " + str(nvcc_path) +
                               " but could not locate CUDA libraries" +
                               " or include directories.")
        # Bug fix: the original fell through to an unconditional
        # "Could not find cuda" raise here even after locating nvcc.

    print("Found CUDA_HOME = ", cuda_home)

    return cuda_home

# Resolve the CUDA toolkit root once; later steps (nvcc lookup, extension
# compilation) all read this module-level value.
CUDA_HOME = find_cuda_home()

# Patch the extension's view of CUDA_HOME to allow cross-compilation
# (cpp_extension's own detection fails when torch.cuda.is_available() is False).
torch.utils.cpp_extension.CUDA_HOME = CUDA_HOME

def get_cuda_version():
    """Return the CUDA toolkit major version by parsing ``nvcc --version``.

    Relies on the module-level ``CUDA_HOME`` having already been resolved.

    Returns:
        The CUDA major version as an int.

    Raises:
        RuntimeError: if the installed CUDA toolkit is older than 8.0.
    """
    NVCC = find(CUDA_HOME + os.sep + "bin",
                re.compile(r'nvcc$').search)
    print("Found NVCC = ", NVCC)

    # Parse output of nvcc to get the cuda version: the banner contains a
    # token like ", V9.0.176".  Raw string fixes the invalid '\.' escape
    # the original pattern produced.
    nvcc_output = subprocess.check_output([NVCC, '--version']).decode("utf-8")
    cuda_lib = (re.compile(r', V[0-9]+\.[0-9]+\.[0-9]+')
                .search(nvcc_output).group(0).split('V')[1])
    print("Found CUDA_LIB = ", cuda_lib)

    cuda_major = int(cuda_lib.split('.')[0])
    print("Found CUDA_MAJOR = ", cuda_major)

    if cuda_major < 8:
        raise RuntimeError("APex requires CUDA 8.0 or newer")

    return cuda_major

CUDA_MAJOR = get_cuda_version()

# Build the -gencode flags for the SM architectures we compile for.
# sm_70 (Volta) is only understood by CUDA 9+, hence the major-version gate.
_sm_archs = ['52', '60', '61']
if CUDA_MAJOR > 8:
    _sm_archs.append('70')

gencodes = []
for _arch in _sm_archs:
    gencodes += ['-gencode', 'arch=compute_%s,code=sm_%s' % (_arch, _arch)]
if CUDA_MAJOR > 8:
    # Also embed PTX for compute_70 so future GPUs can JIT-compile.
    gencodes += ['-gencode', 'arch=compute_70,code=compute_70']

# Single extension module routing all interfaces through interface.cpp.
ext_modules = [
    CUDAExtension(
        'apex_C',
        ['csrc/interface.cpp',
         'csrc/weight_norm_fwd_cuda.cu',
         'csrc/weight_norm_bwd_cuda.cu',
         'csrc/scale_cuda.cu'],
        extra_compile_args={'cxx': ['-g'] + version_le_04,
                            'nvcc': ['-O3'] + version_le_04 + gencodes},
    )
]
setup(
    name='apex',
    version='0.1',
    # Exclude build artifacts and non-package directories from the wheel.
    # Bug fix: 'tests' appeared twice in the original exclude tuple.
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    ext_modules=ext_modules,
    description='PyTorch Extensions written by NVIDIA',
    # Build the CUDA extensions with torch's compiler wrapper so torch's
    # include paths and ABI flags are applied automatically.
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)