setup.py 5.98 KB
Newer Older
1
from __future__ import print_function
soumith's avatar
soumith committed
2
import os
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
3
4
import io
import re
soumith's avatar
soumith committed
5
6
import sys
from setuptools import setup, find_packages
7
from pkg_resources import get_distribution, DistributionNotFound
8
import subprocess
9
10
11
12
13
14
import distutils.command.clean
import glob
import shutil

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
soumith's avatar
soumith committed
15
16


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
17
18
19
20
21
22
23
def read(*names, **kwargs):
    """Return the text of a file addressed relative to this script's directory.

    ``names`` are path components appended to the directory containing this
    file. The optional ``encoding`` keyword selects the codec used to decode
    it and defaults to ``"utf8"``.
    """
    base_dir = os.path.dirname(__file__)
    target = os.path.join(base_dir, *names)
    codec = kwargs.get("encoding", "utf8")
    with io.open(target, encoding=codec) as handle:
        contents = handle.read()
    return contents

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
24

25
26
27
28
29
30
31
def get_dist(pkgname):
    """Look up the installed distribution named ``pkgname``.

    Returns the ``pkg_resources`` distribution object, or ``None`` when no
    distribution by that name is installed.
    """
    try:
        found = get_distribution(pkgname)
    except DistributionNotFound:
        found = None
    return found


32
version = '0.5.0a0'
sha = 'Unknown'
package_name = 'torchvision'

cwd = os.path.dirname(os.path.abspath(__file__))

# Record the git commit being built, when one is available. Only the expected
# failures are tolerated: git is not installed (OSError) or the source tree is
# not a git checkout, e.g. an unpacked sdist (CalledProcessError). Any other
# exception is a real problem and is not silently swallowed.
try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
    pass

# An explicit BUILD_VERSION wins. Otherwise, builds from a git checkout get the
# short commit hash appended as a PEP 440 local version label ("+<sha7>").
build_version = os.getenv('BUILD_VERSION')
if build_version:
    version = build_version
elif sha != 'Unknown':
    version += '+' + sha[:7]
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` with the build metadata.

    The generated module defines:

    * ``__version__``, from the module-level ``version`` global;
    * ``git_version``, from the module-level ``sha`` global;
    * ``cuda``, which is ``None`` unless the compiled ``torchvision._C``
      extension exposes a ``CUDA_VERSION`` attribute.

    The target directory is derived from the module-level ``cwd`` global.
    """
    version_path = os.path.join(cwd, 'torchvision', 'version.py')
    with open(version_path, 'w') as f:
        # repr() always produces a valid Python string literal, even when
        # BUILD_VERSION contains quotes or backslashes. It also matches how
        # git_version is written on the next line.
        f.write("__version__ = {}\n".format(repr(version)))
        f.write("git_version = {}\n".format(repr(sha)))
        f.write("from torchvision import _C\n")
        # Always bind ``cuda`` so torchvision.version.cuda exists in CPU-only
        # builds too, rather than raising AttributeError.
        f.write("cuda = None\n")
        f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
        f.write("    cuda = _C.CUDA_VERSION\n")
58
59
60


write_version_file()

# Go through the read() helper: it closes the file, decodes as UTF-8 whatever
# the build machine's locale is, and resolves the path relative to this
# script instead of the current working directory.
readme = read('README.rst')

# Pin torch to an exact version when the build environment requests one.
pytorch_dep = 'torch'
pytorch_version = os.getenv('PYTORCH_VERSION')
if pytorch_version:
    pytorch_dep += "==" + pytorch_version

requirements = [
    'numpy',
    'six',
    pytorch_dep,
]

# Require pillow-simd when it is already installed, plain pillow otherwise.
# Presumably this avoids pulling stock pillow in on top of pillow-simd
# (confirm).
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)

78

79
80
81
82
83
84
85
86
87
88
89
def get_extensions():
    """Return the extension modules to build for torchvision.

    Three extensions are produced. All of them share ``define_macros`` and
    ``extra_compile_args``:

    * ``torchvision._C``: every ``*.cpp`` in ``torchvision/csrc`` and
      ``torchvision/csrc/cpu``, plus ``torchvision/csrc/cuda/*.cu`` when
      building with CUDA.
    * ``torchvision._C_tests``: ``test/*.cpp`` plus
      ``torchvision/csrc/models/*.cpp``.
    * ``torchvision._custom_ops``: ``csrc/custom_ops/custom_ops.cpp`` plus the
      NMS / ROIAlign / ROIPool kernels (CPU, and CUDA when enabled).

    CUDA is enabled (``CUDAExtension``, ``WITH_CUDA`` macro, ``NVCC_FLAGS``
    forwarded to nvcc) in two cases:

    * torch reports an available GPU and ``CUDA_HOME`` is set, or
    * the environment sets ``FORCE_CUDA=1``.

    Otherwise every extension is a ``CppExtension``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    def _sorted_glob(*parts):
        # glob() already returns paths that include the pattern's (absolute)
        # directory, so no further os.path.join is needed. Sorting makes the
        # source order independent of the filesystem's directory order.
        return sorted(glob.glob(os.path.join(*parts)))

    main_file = _sorted_glob(extensions_dir, '*.cpp')
    source_cpu = _sorted_glob(extensions_dir, 'cpu', '*.cpp')
    source_cuda = _sorted_glob(extensions_dir, 'cuda', '*.cu')

    sources = main_file + source_cpu
    extension = CppExtension

    test_dir = os.path.join(this_dir, 'test')
    models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
    tests = _sorted_glob(test_dir, '*.cpp') + _sorted_glob(models_dir, '*.cpp')

    custom_ops_sources = [os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
                          os.path.join(extensions_dir, "cpu", "nms_cpu.cpp"),
                          os.path.join(extensions_dir, "cpu", "ROIAlign_cpu.cpp"),
                          os.path.join(extensions_dir, "cpu", "ROIPool_cpu.cpp")]
    custom_ops_sources_cuda = [os.path.join(extensions_dir, "cuda", "nms_cuda.cu"),
                               os.path.join(extensions_dir, "cuda", "ROIAlign_cuda.cu"),
                               os.path.join(extensions_dir, "cuda", "ROIPool_cuda.cu")]

    define_macros = []
    extra_compile_args = {}

    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
        extension = CUDAExtension
        sources += source_cuda
        custom_ops_sources += custom_ops_sources_cuda
        define_macros += [('WITH_CUDA', None)]
        # str.split() with no argument collapses runs of whitespace and maps
        # an empty/unset NVCC_FLAGS to [], so nvcc never receives '' args.
        nvcc_flags = os.getenv('NVCC_FLAGS', '').split()
        # No '-O0' for host code. It would be appended after the default
        # optimisation flags and leave the C++ sources unoptimised, but only
        # in CUDA builds.
        extra_compile_args = {
            'cxx': [],
            'nvcc': nvcc_flags,
        }

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]
        # /MP lets MSVC compile several source files in parallel.
        extra_compile_args.setdefault('cxx', [])
        extra_compile_args['cxx'].append('/MP')

    include_dirs = [extensions_dir]
    tests_include_dirs = [test_dir, models_dir]

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        ),
        extension(
            'torchvision._C_tests',
            tests,
            include_dirs=tests_include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        ),
        extension(
            "torchvision._custom_ops",
            sources=custom_ops_sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        ),
    ]

    return ext_modules


class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes paths matched by ``.gitignore``.

    Each ``.gitignore`` entry is used as a glob pattern relative to the
    current working directory. Matching files are removed and matching
    directories are deleted recursively. The standard distutils clean then
    runs.
    """

    def run(self):
        with open('.gitignore', 'r') as f:
            ignores = f.read()
        for line in ignores.splitlines():
            wildcard = line.strip()
            # Skip blank lines, comments and negated ("!pattern") entries.
            # A negation marks paths to keep, not paths to delete.
            if not wildcard or wildcard.startswith(('#', '!')):
                continue
            # In .gitignore a leading '/' anchors the pattern at the repository
            # root. Passed to glob() unchanged it would match from the
            # filesystem root instead, so strip it to stay in the working dir.
            wildcard = wildcard.lstrip('/')
            if not wildcard:
                continue
            for filename in glob.glob(wildcard):
                try:
                    os.remove(filename)
                except OSError:
                    shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


Thomas Grainger's avatar
Thomas Grainger committed
178
setup(
    # Metadata
    name=package_name,
    version=version,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=readme,
    license='BSD',

    # Package info
    packages=find_packages(exclude=('test',)),

    # The package ships compiled extension modules (see ext_modules below).
    # Those cannot be imported from inside a zip archive, so the installed
    # package is not zip-safe.
    zip_safe=False,
    install_requires=requirements,
    extras_require={
        "scipy": ["scipy"],
    },
    ext_modules=get_extensions(),
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension,
              'clean': clean}
)