"fair_dev/testing/testing.py" did not exist on "290afecd0c04cb28e885537afd5e7990da139844"
setup.py 7.55 KB
Newer Older
1
from __future__ import print_function
soumith's avatar
soumith committed
2
import os
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
3
4
import io
import re
soumith's avatar
soumith committed
5
6
import sys
from setuptools import setup, find_packages
7
from pkg_resources import get_distribution, DistributionNotFound
8
import subprocess
9
import distutils.command.clean
10
import distutils.spawn
11
12
13
14
import glob
import shutil

import torch
15
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
16
from torch.utils.hipify import hipify_python
soumith's avatar
soumith committed
17
18


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
19
20
21
22
23
24
25
def read(*names, **kwargs):
    """Return the text of a file located relative to this script's directory.

    ``names`` are path components joined onto the directory containing this
    file. The optional ``encoding`` keyword selects the text encoding and
    defaults to ``"utf8"``.
    """
    encoding = kwargs.get("encoding", "utf8")
    script_dir = os.path.dirname(__file__)
    target = os.path.join(script_dir, *names)
    with io.open(target, encoding=encoding) as handle:
        contents = handle.read()
    return contents

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
26

27
28
29
30
31
32
33
def get_dist(pkgname):
    """Look up the installed distribution called ``pkgname``.

    Returns whatever ``get_distribution`` yields, or ``None`` when it raises
    ``DistributionNotFound``.
    """
    try:
        found = get_distribution(pkgname)
    except DistributionNotFound:
        found = None
    return found


34
# Base version string; it is either replaced by BUILD_VERSION or suffixed
# with the short git SHA further down.
version = '0.6.0a0'
sha = 'Unknown'
package_name = 'torchvision'

cwd = os.path.dirname(os.path.abspath(__file__))

try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except Exception:
    # Deliberately best effort: if the git call fails for any reason the
    # placeholder 'Unknown' is kept and the build carries on.
    pass

if os.getenv('BUILD_VERSION'):
    # An explicit BUILD_VERSION takes precedence over the computed version.
    version = os.getenv('BUILD_VERSION')
elif sha != 'Unknown':
    # Otherwise tag the base version with the first 7 characters of the SHA.
    version += '+' + sha[:7]
print("Building wheel {}-{}".format(package_name, version))


def write_version_file():
    """Generate ``torchvision/version.py`` under the directory of this script.

    The generated module records the module-level ``version`` and ``sha``
    values as ``__version__`` and ``git_version``, followed by code that sets
    ``cuda`` from ``_check_cuda_version()`` when that call returns a positive
    value. Any existing file at that path is overwritten.
    """
    version_path = os.path.join(cwd, 'torchvision', 'version.py')
    with open(version_path, 'w') as f:
        f.write("__version__ = '{}'\n".format(version))
        f.write("git_version = {}\n".format(repr(sha)))
        # The remaining lines are emitted verbatim into the generated module.
        f.write("from torchvision.extension import _check_cuda_version\n")
        f.write("if _check_cuda_version() > 0:\n")
        f.write("    cuda = _check_cuda_version()\n")
60
61
62


write_version_file()
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
63

Thomas Grainger's avatar
Thomas Grainger committed
64
65
# Load the README through read() so that the path is resolved relative to
# this script, the text is decoded as UTF-8 and the file handle is closed.
readme = read('README.rst')

# Pin torch to an exact version only when PYTORCH_VERSION is set.
pytorch_dep = 'torch'
if os.getenv('PYTORCH_VERSION'):
    pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')

requirements = [
    'numpy',
    'six',
    pytorch_dep,
]

# Depend on pillow-simd when it is already installed, otherwise on pillow.
pillow_ver = ' >= 4.1.1'
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)

80

81
82
83
84
85
86
def _torch_version_at_least(major, minor):
    """Return True when the installed torch version is at least ``major.minor``."""
    # Compare numerically: as plain strings '1.10' sorts before '1.5'.
    match = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
    if match is None:
        return False
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


def get_extensions():
    """Build the list of extension modules to compile.

    The list always contains ``torchvision._C``. ``torchvision._C_tests`` is
    added when the environment variable ``WITH_CPP_MODELS_TEST`` is ``'1'``,
    and ``torchvision.video_reader`` is added when an ``ffmpeg`` executable
    is found.

    ``torchvision._C`` is built with ``CUDAExtension`` (including the
    CUDA/HIP sources) when torch reports CUDA as available together with a
    CUDA or ROCm toolchain, or when ``FORCE_CUDA`` is ``'1'``; otherwise
    ``CppExtension`` is used. ``NVCC_FLAGS`` supplies whitespace-separated
    extra nvcc flags for the non-ROCm CUDA build.

    Returns:
        list: the extension objects, in the order described above.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    # this_dir is absolute, so every glob result below is an absolute path.
    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))

    is_rocm_pytorch = False
    if _torch_version_at_least(1, 5):
        # ROCM_HOME is only imported for torch >= 1.5.
        from torch.utils.cpp_extension import ROCM_HOME
        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        # Run hipify over the CUDA sources; its output is picked up from
        # csrc/hip right below.
        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, 'hip', '*.hip'))
        # Copy over additional headers. The paths are anchored on
        # extensions_dir so the copy does not depend on the working directory.
        for header in ('cuda_helpers.h', 'vision_cuda.h'):
            shutil.copy(
                os.path.join(extensions_dir, 'cuda', header),
                os.path.join(extensions_dir, 'hip', header),
            )
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

    sources = main_file + source_cpu
    extension = CppExtension

    compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, 'test')
        models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
        test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
        source_models = glob.glob(os.path.join(models_dir, '*.cpp'))

        tests = test_file + source_models
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []
    extra_compile_args = {}

    cuda_usable = torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)
    force_cuda = os.getenv('FORCE_CUDA', '0') == '1'
    if cuda_usable or force_cuda:
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [('WITH_CUDA', None)]
            # split() without an argument drops the empty strings that
            # repeated or surrounding spaces would otherwise produce.
            nvcc_flags = os.getenv('NVCC_FLAGS', '').split()
        else:
            define_macros += [('WITH_HIP', None)]
            nvcc_flags = []
        extra_compile_args = {
            'cxx': [],
            'nvcc': nvcc_flags,
        }

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]

        extra_compile_args.setdefault('cxx', [])
        extra_compile_args['cxx'].append('/MP')

    include_dirs = [extensions_dir]

    ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
    has_ffmpeg = ffmpeg_exe is not None
    if has_ffmpeg:
        # The include directory is taken to be <prefix>/include, where the
        # ffmpeg executable lives in <prefix>/<bin dir>.
        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include')

        # TorchVision video reader
        video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video_reader')
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))

    ext_modules = [
        extension(
            'torchvision._C',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    if compile_cpp_tests:
        ext_modules.append(
            extension(
                'torchvision._C_tests',
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )
    if has_ffmpeg:
        ext_modules.append(
            CppExtension(
                'torchvision.video_reader',
                video_reader_src,
                include_dirs=[
                    video_reader_src_dir,
                    ffmpeg_include_dir,
                    extensions_dir,
                ],
                libraries=[
                    'avcodec',
                    'avformat',
                    'avutil',
                    'swresample',
                    'swscale',
                ],
                extra_compile_args=["-std=c++14"],
                extra_link_args=["-std=c++14"],
            )
        )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``clean`` command that first deletes whatever ``.gitignore`` lists."""

    def run(self):
        """Remove paths matched by ``.gitignore`` lines, then run the stock clean.

        Every non-empty line of ``.gitignore`` in the current working
        directory is used as a glob pattern. Each match is removed as a file;
        if that raises ``OSError`` it is removed as a directory tree instead.
        """
        with open('.gitignore', 'r') as gitignore:
            patterns = [line for line in gitignore.read().split('\n') if line]
            for pattern in patterns:
                for path in glob.glob(pattern):
                    try:
                        os.remove(path)
                    except OSError:
                        shutil.rmtree(path, ignore_errors=True)

        # Explicit base-class call: on Python 2.7 this is an old-style class.
        distutils.command.clean.clean.run(self)


Thomas Grainger's avatar
Thomas Grainger committed
225
# Package definition: metadata, runtime requirements, the compiled extension
# modules and the custom build_ext/clean commands.
setup(
    # Metadata
    name=package_name,
    version=version,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=readme,
    license='BSD',

    # Package info
    packages=find_packages(exclude=('test',)),

    zip_safe=False,
    install_requires=requirements,
    extras_require={
        "scipy": ["scipy"],
    },
    ext_modules=get_extensions(),
    cmdclass={
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
        'clean': clean,
    }
)