setup.py 7.82 KB
Newer Older
soumith's avatar
soumith committed
1
import os
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
2
3
import io
import re
soumith's avatar
soumith committed
4
5
import sys
from setuptools import setup, find_packages
6
from pkg_resources import get_distribution, DistributionNotFound
7
import subprocess
8
import distutils.command.clean
9
import distutils.spawn
10
11
12
13
import glob
import shutil

import torch
14
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
15
from torch.utils.hipify import hipify_python
soumith's avatar
soumith committed
16
17


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
18
19
20
21
22
23
24
def read(*names, **kwargs):
    """Return the text of a file addressed relative to this setup.py.

    ``names`` are path components joined onto the directory containing this
    file. The optional ``encoding`` keyword (default ``"utf8"``) selects the
    codec used to decode the file.
    """
    encoding = kwargs.get("encoding", "utf8")
    path = os.path.join(os.path.dirname(__file__), *names)
    with io.open(path, encoding=encoding) as handle:
        contents = handle.read()
    return contents

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
25

26
27
28
29
30
31
32
def get_dist(pkgname):
    """Look up an installed distribution by name.

    Returns the ``pkg_resources`` distribution for ``pkgname``, or ``None``
    when no such distribution is installed.
    """
    try:
        dist = get_distribution(pkgname)
    except DistributionNotFound:
        dist = None
    return dist


33
package_name = 'torchvision'
version = '0.6.0a0'
sha = 'Unknown'

# Absolute directory of this setup.py: git is queried from here and
# torchvision/version.py is written beneath it.
cwd = os.path.dirname(os.path.abspath(__file__))

# Best effort: without git, or outside a checkout, sha stays 'Unknown'.
try:
    git_head = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd)
    sha = git_head.decode('ascii').strip()
except Exception:
    pass

# A non-empty BUILD_VERSION replaces the version verbatim; otherwise a known
# sha is appended as a "+<short sha>" local-version suffix.
build_version = os.getenv('BUILD_VERSION')
if build_version:
    version = build_version
elif sha != 'Unknown':
    version = '{}+{}'.format(version, sha[:7])
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` under ``cwd``.

    The generated module records ``__version__`` and ``git_version`` from the
    module-level ``version`` and ``sha``. It also defines ``cuda`` whenever
    ``torchvision.extension._check_cuda_version()`` returns a positive value.
    """
    target = os.path.join(cwd, 'torchvision', 'version.py')
    with open(target, 'w') as out:
        out.writelines([
            "__version__ = '{}'\n".format(version),
            "git_version = {}\n".format(repr(sha)),
            "from torchvision.extension import _check_cuda_version\n",
            "if _check_cuda_version() > 0:\n",
            "    cuda = _check_cuda_version()\n",
        ])
60
61


# Regenerate torchvision/version.py before the package is built.
write_version_file()

# Read the long description relative to this file (not the current working
# directory) with an explicit UTF-8 encoding. Going through read() also closes
# the file handle, which the previous bare open(...).read() leaked.
readme = read('README.rst')

pytorch_dep = 'torch'
pytorch_version = os.getenv('PYTORCH_VERSION')
if pytorch_version:
    # Pin torch to exactly the requested release.
    pytorch_dep += "==" + pytorch_version

requirements = [
    'numpy',
    pytorch_dep,
]

# Prefer pillow-simd when it is already installed; otherwise require pillow.
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)

78

79
80
81
82
83
84
def _is_rocm_pytorch():
    """Return True when PyTorch is a HIP build and a ROCm toolkit was found.

    ``ROCM_HOME`` is imported lazily, and its absence is treated as "not ROCm".
    This replaces a lexicographic version-string test
    (``torch.__version__ >= '1.5'``), which is False for e.g. '1.10.0' and
    silently disabled ROCm detection on newer PyTorch releases.
    """
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        return False
    return getattr(torch.version, 'hip', None) is not None and ROCM_HOME is not None


def _exclude_test_sources(paths):
    """Drop C++ unit-test sources (files named ``*_test.cpp``) from ``paths``."""
    return [p for p in paths if not os.path.basename(p).endswith('_test.cpp')]


def _video_reader_extension(this_dir, extensions_dir, ffmpeg_exe):
    """Build the ``torchvision.video_reader`` extension linked against FFmpeg.

    Assumes an FFmpeg install laid out as ``<prefix>/bin/ffmpeg`` with headers
    in ``<prefix>/include``, where ``ffmpeg_exe`` is the executable's path.
    """
    ffmpeg_bin = os.path.dirname(ffmpeg_exe)
    ffmpeg_root = os.path.dirname(ffmpeg_bin)
    ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include')

    # TorchVision base decoder + video reader
    video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video_reader')
    video_reader_src = glob.glob(os.path.join(video_reader_src_dir, '*.cpp'))
    base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'decoder')
    # The decoder directory also contains *_test.cpp unit tests, which must not
    # be compiled into the module. An explicit filter is used because a glob
    # "[!...]" bracket matches a single character, not a file name.
    base_decoder_src = _exclude_test_sources(
        glob.glob(os.path.join(base_decoder_src_dir, '*.cpp')))

    return CppExtension(
        'torchvision.video_reader',
        video_reader_src + base_decoder_src,
        include_dirs=[
            base_decoder_src_dir,
            video_reader_src_dir,
            ffmpeg_include_dir,
            extensions_dir,
        ],
        libraries=[
            'avcodec',
            'avformat',
            'avutil',
            'swresample',
            'swscale',
        ],
        extra_compile_args=["-std=c++14"],
        extra_link_args=["-std=c++14"],
    )


def get_extensions():
    """Collect the native extension modules to build for torchvision.

    ``torchvision._C`` is always included. Two more are added conditionally:
      - ``torchvision._C_tests`` when ``WITH_CPP_MODELS_TEST=1``;
      - ``torchvision.video_reader`` when an ``ffmpeg`` executable is on PATH.

    ``_C`` becomes a CUDA extension when ``torch.cuda.is_available()`` and
    either ``CUDA_HOME`` is set or PyTorch is a ROCm build, or when
    ``FORCE_CUDA=1``. ``NVCC_FLAGS`` supplies extra whitespace-separated nvcc
    flags for non-ROCm CUDA builds.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    # this_dir is absolute, so every glob() result below is already absolute.
    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))

    is_rocm_pytorch = _is_rocm_pytorch()

    if is_rocm_pytorch:
        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, 'hip', '*.hip'))
        # Copy the shared headers next to the hipified sources. Paths are
        # anchored at extensions_dir so this works from any working directory.
        for header in ('cuda_helpers.h', 'vision_cuda.h'):
            shutil.copy(os.path.join(extensions_dir, 'cuda', header),
                        os.path.join(extensions_dir, 'hip', header))
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

    sources = main_file + source_cpu
    extension = CppExtension

    compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, 'test')
        models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
        tests = (glob.glob(os.path.join(test_dir, '*.cpp'))
                 + glob.glob(os.path.join(models_dir, '*.cpp')))
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []
    extra_compile_args = {}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) \
            or os.getenv('FORCE_CUDA', '0') == '1':
        extension = CUDAExtension
        sources += source_cuda
        if is_rocm_pytorch:
            define_macros += [('WITH_HIP', None)]
            nvcc_flags = []
        else:
            define_macros += [('WITH_CUDA', None)]
            # split() with no argument discards empty fields, so repeated or
            # surrounding spaces never turn into '' arguments for nvcc.
            nvcc_flags = os.getenv('NVCC_FLAGS', '').split()
        extra_compile_args = {
            'cxx': [],
            'nvcc': nvcc_flags,
        }

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]
        # /MP: let MSVC compile multiple source files in parallel.
        extra_compile_args.setdefault('cxx', [])
        extra_compile_args['cxx'].append('/MP')

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    if compile_cpp_tests:
        ext_modules.append(
            extension(
                'torchvision._C_tests',
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # shutil.which is the stdlib replacement for the deprecated
    # distutils.spawn.find_executable (PEP 632).
    ffmpeg_exe = shutil.which('ffmpeg')
    if ffmpeg_exe is not None:
        ext_modules.append(_video_reader_extension(this_dir, extensions_dir, ffmpeg_exe))

    return ext_modules


def _gitignore_globs(text):
    """Convert the contents of a .gitignore file into glob patterns for ``clean``.

    Handling of special lines:
      - Line endings are handled by ``splitlines`` (so CRLF files work).
      - Trailing whitespace is dropped.
      - Blank lines, comments (``#``) and negation rules (``!``) are skipped.
      - A leading ``/`` (gitignore's "anchored to this directory") is removed,
        so the pattern is globbed relative to the current directory rather
        than the filesystem root.
    """
    patterns = []
    for line in text.splitlines():
        pattern = line.rstrip()
        if not pattern or pattern.startswith(('#', '!')):
            continue
        pattern = pattern.lstrip('/')
        if pattern:
            patterns.append(pattern)
    return patterns


class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes every path matched by .gitignore."""

    def run(self):
        with open('.gitignore', 'r') as f:
            ignores = f.read()
        for wildcard in _gitignore_globs(ignores):
            for filename in glob.glob(wildcard):
                try:
                    os.remove(filename)
                except OSError:
                    # os.remove rejects directories; remove those as whole trees.
                    shutil.rmtree(filename, ignore_errors=True)

        # Chain to the stock clean explicitly rather than via super(): the
        # distutils base is an old-style class in Python 2.7.
        distutils.command.clean.clean.run(self)


Thomas Grainger's avatar
Thomas Grainger committed
231
setup(
    # -- Metadata ------------------------------------------------------------
    name=package_name,
    version=version,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=readme,
    license='BSD',

    # -- Package contents and dependencies -----------------------------------
    packages=find_packages(exclude=('test',)),
    zip_safe=False,
    install_requires=requirements,
    extras_require=dict(scipy=["scipy"]),

    # -- Native extensions and custom commands -------------------------------
    ext_modules=get_extensions(),
    cmdclass=dict(
        build_ext=BuildExtension.with_options(no_python_abi_suffix=True),
        clean=clean,
    ),
)