setup.py 1.45 KB
Newer Older
1
2
import torch.cuda
from setuptools import setup, find_packages
Christian Sarofeen's avatar
Christian Sarofeen committed
3
from distutils.command.clean import clean
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torch.utils.cpp_extension import CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME

# TODO:  multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file?
# The package consists of a compiled CUDA extension, so abort immediately
# unless PyTorch reports CUDA as available and a CUDA toolkit location
# (CUDA_HOME) was found.
if not torch.cuda.is_available() or CUDA_HOME is None:
    raise RuntimeError("Apex requires Cuda 9.0 or higher")

# A single extension module, 'apex._C', built from the C++ interface file
# plus the CUDA kernel sources listed below.
extension = CUDAExtension(
    'apex._C',
    [
        'csrc/interface.cpp',
        'csrc/weight_norm_fwd_cuda.cu',
        'csrc/weight_norm_bwd_cuda.cu',
        'csrc/scale_cuda.cu',
    ],
    extra_compile_args={
        # Host compiler: keep debug symbols.
        'cxx': ['-g'],
        # Device compiler: optimized build for one fixed architecture.
        # TODO:  compile for all arches.
        'nvcc': ['-O2', '-arch=sm_70'],
    },
)
ext_modules = [extension]
Christian Sarofeen's avatar
Christian Sarofeen committed
23
24

# Package metadata and build configuration for apex.  The extension list
# assembled earlier in this script is passed to setuptools, and PyTorch's
# BuildExtension command is registered in place of the default build_ext.
setup(
    name='apex',
    version='0.1',
    # Only real Python packages are installed; build output, native sources,
    # docs, tests and examples are excluded from package discovery.
    packages=find_packages(exclude=('build',
                                    'csrc',
                                    'include',
                                    'tests',
                                    'dist',
                                    'docs',
                                    'examples',
                                    'apex.egg-info',)),
    ext_modules=ext_modules,
    description='PyTorch Extensions written by NVIDIA',
    # `torch` is bound by the `import torch.cuda` at the top of this file and
    # the cpp_extension submodule is loaded by the from-imports there, so the
    # dotted lookup below resolves without an extra import.
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)