from __future__ import print_function
import os
import io
import re
import sys
# Keep setuptools ahead of the distutils imports below: setuptools expects to
# be imported before distutils.
from setuptools import setup, find_packages
from pkg_resources import get_distribution, DistributionNotFound
import subprocess
import distutils.command.clean
import distutils.spawn
import glob
import shutil

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
    """Return the decoded contents of a file located next to this script.

    ``names`` are path components joined onto the directory that contains
    this file.  The optional ``encoding`` keyword (default ``"utf8"``) picks
    the codec used to decode the file.  The file is closed before returning.
    """
    target = os.path.join(os.path.dirname(__file__), *names)
    codec = kwargs.get("encoding", "utf8")
    with io.open(target, encoding=codec) as handle:
        return handle.read()


def get_dist(pkgname):
    """Return the installed distribution called ``pkgname``, or ``None``.

    Thin wrapper around ``pkg_resources.get_distribution`` that turns
    ``DistributionNotFound`` into a ``None`` result; any other exception
    propagates unchanged.
    """
    try:
        found = get_distribution(pkgname)
    except DistributionNotFound:
        found = None
    return found


version = '0.6.0a0'
sha = 'Unknown'
package_name = 'torchvision'

cwd = os.path.dirname(os.path.abspath(__file__))

# Record the commit being built.  This is best-effort: when git is not on
# PATH (OSError) or the tree is not a git checkout (non-zero exit status),
# sha keeps its 'Unknown' placeholder.
try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except (OSError, subprocess.CalledProcessError):
    pass

# A non-empty BUILD_VERSION replaces the version outright; otherwise, when
# the commit is known, append its first 7 hex digits as a "+<sha>" suffix.
if os.getenv('BUILD_VERSION'):
    version = os.getenv('BUILD_VERSION')
elif sha != 'Unknown':
    version += '+' + sha[:7]
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` inside the source tree.

    Reads the module-level ``cwd``, ``version`` and ``sha`` values.  The
    generated module defines ``__version__`` and ``git_version``.  When that
    module is imported, it also sets ``cuda`` to the value of
    ``torchvision.extension._check_cuda_version()``, but only if that value
    is positive.
    """
    version_path = os.path.join(cwd, 'torchvision', 'version.py')
    lines = [
        "__version__ = '{}'\n".format(version),
        "git_version = {}\n".format(repr(sha)),
        "from torchvision.extension import _check_cuda_version\n",
        "if _check_cuda_version() > 0:\n",
        "    cuda = _check_cuda_version()\n",
    ]
    with open(version_path, 'w') as f:
        f.writelines(lines)


write_version_file()

# Use the read() helper rather than a bare open(): it resolves README.rst
# relative to this script (not the current directory), decodes it as UTF-8
# regardless of locale, and closes the file handle.
readme = read('README.rst')

# When PYTORCH_VERSION is set, pin torch to exactly that version.
pytorch_dep = 'torch'
if os.getenv('PYTORCH_VERSION'):
    pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')

requirements = [
    'numpy',
    'six',
    pytorch_dep,
]

# Require pillow-simd if it is already installed, plain pillow otherwise;
# either way with the same minimum version.
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)


def get_extensions():
    """Collect the native extension modules to build for torchvision.

    Returns a list whose first entry is always ``torchvision._C``.  That
    entry is a ``CUDAExtension`` when CUDA is usable (or ``FORCE_CUDA=1``)
    and a ``CppExtension`` otherwise.  ``torchvision._C_tests`` is appended
    when ``WITH_CPP_MODELS_TEST=1``.  ``torchvision.video_reader`` is
    appended when an ``ffmpeg`` executable is found on ``PATH``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    # glob() already returns paths rooted at the (absolute) search directory.
    # Its order depends on the filesystem, so sort for reproducible builds.
    main_file = sorted(glob.glob(os.path.join(extensions_dir, '*.cpp')))
    source_cpu = sorted(glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp')))
    source_cuda = sorted(glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu')))

    sources = main_file + source_cpu
    extension = CppExtension

    # Optional C++ model tests, built as a separate extension module.
    compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, 'test')
        models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
        test_file = sorted(glob.glob(os.path.join(test_dir, '*.cpp')))
        source_models = sorted(glob.glob(os.path.join(models_dir, '*.cpp')))
        tests = test_file + source_models
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []
    extra_compile_args = {}
    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        # str.split() with no separator ignores repeated, leading and trailing
        # whitespace, so nvcc never receives empty '' arguments.  The old
        # split(' ') produced them for input like "-a  -b".
        # An unset or empty NVCC_FLAGS yields [].
        nvcc_flags = os.getenv('NVCC_FLAGS', '').split()
        extra_compile_args = {
            # NOTE(review): host C++ is compiled with -O0 in CUDA builds.
            # Kept unchanged; presumably deliberate -- confirm before changing.
            'cxx': ['-O0'],
            'nvcc': nvcc_flags,
        }

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]
        # /MP asks MSVC to compile multiple source files in parallel.
        extra_compile_args.setdefault('cxx', [])
        extra_compile_args['cxx'].append('/MP')

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    if compile_cpp_tests:
        ext_modules.append(
            extension(
                'torchvision._C_tests',
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # The video reader is only built when an ffmpeg executable is on PATH.
    # FFmpeg's headers and libraries are looked for under that executable's
    # prefix: <prefix>/bin/ffmpeg -> <prefix>/include and <prefix>/lib.
    ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
    if ffmpeg_exe is not None:
        ffmpeg_root = os.path.dirname(os.path.dirname(ffmpeg_exe))
        ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include')
        ffmpeg_library_dir = os.path.join(ffmpeg_root, 'lib')

        video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video_reader')
        video_reader_src = sorted(glob.glob(os.path.join(video_reader_src_dir, "*.cpp")))

        ext_modules.append(
            CppExtension(
                'torchvision.video_reader',
                video_reader_src,
                include_dirs=[
                    video_reader_src_dir,
                    ffmpeg_include_dir,
                    extensions_dir,
                ],
                # Link against the libraries from the same prefix the headers
                # come from.  Otherwise a non-default prefix such as a virtual
                # env only works if it is already on the linker's search path.
                library_dirs=[ffmpeg_library_dir],
                libraries=[
                    'avcodec',
                    'avformat',
                    'avutil',
                    'swresample',
                    'swscale',
                ],
                extra_compile_args=["-std=c++14"],
                extra_link_args=["-std=c++14"],
            )
        )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``setup.py clean``: delete .gitignore'd artefacts, then run the stock
    distutils ``clean`` command.

    Patterns are read from the ``.gitignore`` next to this setup script and
    globbed relative to that same directory.  Running the command from a
    different working directory therefore cannot delete unrelated files.
    Blank lines, ``#`` comments and ``!`` negations are skipped.  A leading
    ``/`` (a gitignore anchor) is stripped so that no pattern can resolve to
    the filesystem root.  A missing ``.gitignore`` is tolerated.
    """

    def run(self):
        root = os.path.dirname(os.path.abspath(__file__))
        try:
            with open(os.path.join(root, '.gitignore'), 'r') as f:
                ignores = f.read()
        except IOError:
            # No .gitignore next to setup.py: nothing extra to remove, but
            # still run the standard clean below.
            ignores = ''

        for line in ignores.splitlines():
            wildcard = line.strip()
            if not wildcard or wildcard.startswith(('#', '!')):
                continue
            # In gitignore syntax "/build" is anchored at the repo root.
            # Globbing it verbatim would target "/build" on the filesystem
            # root instead.
            wildcard = wildcard.lstrip('/')
            if not wildcard:
                continue
            for filename in glob.glob(os.path.join(root, wildcard)):
                try:
                    os.remove(filename)
                except OSError:
                    # Directories (or anything os.remove refuses) are removed
                    # recursively; failures are deliberately ignored.
                    shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


setup(
    # Metadata
    name=package_name,
    version=version,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=readme,
    license='BSD',

    # Package info
    packages=find_packages(exclude=('test',)),

    zip_safe=False,
    install_requires=requirements,
    extras_require={
        "scipy": ["scipy"],
    },

    # Native code: torchvision._C plus the optional extensions selected by
    # get_extensions().
    ext_modules=get_extensions(),
    cmdclass={
        # no_python_abi_suffix=True: built libraries are named without the
        # Python ABI tag.
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
        'clean': clean,
    }
)