import os
import io
import sys
from setuptools import setup, find_packages
from pkg_resources import parse_version, get_distribution, DistributionNotFound
import subprocess
import distutils.command.clean
import distutils.spawn
import glob
import shutil

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from torch.utils.hipify import hipify_python


def read(*names, **kwargs):
    """Return the text content of a file addressed relative to this script.

    ``names`` are path components joined onto ``os.path.dirname(__file__)``.
    The optional ``encoding`` keyword selects the codec (default ``"utf8"``).
    """
    codec = kwargs.get("encoding", "utf8")
    target = os.path.join(os.path.dirname(__file__), *names)
    with io.open(target, encoding=codec) as handle:
        text = handle.read()
    return text

def get_dist(pkgname):
    """Look up an installed distribution by project name.

    Returns the object produced by ``get_distribution(pkgname)``, or ``None``
    when that call raises ``DistributionNotFound``.
    """
    try:
        dist = get_distribution(pkgname)
    except DistributionNotFound:
        dist = None
    return dist


cwd = os.path.dirname(os.path.abspath(__file__))

# The first line of version.txt holds the base release version. Decode it
# explicitly as UTF-8 so the result does not depend on the build locale.
version_txt = os.path.join(cwd, 'version.txt')
with open(version_txt, 'r', encoding='utf-8') as f:
    version = f.readline().strip()
sha = 'Unknown'  # placeholder until the git commit hash is looked up
package_name = 'torchvision'

# dcu_version is the wheel version handed to setup(); it starts from the
# version.txt value and later gets local-version suffixes appended.
# NOTE(review): this snapshot is taken before any BUILD_VERSION override of
# ``version``, so BUILD_VERSION does not affect dcu_version — confirm intended.
dcu_version = version

def get_abi():
    """Return the libstdc++ C++11-ABI tag used in the DCU version label.

    Runs the C++ preprocessor on ``#include <string>`` and reads the value of
    the ``_GLIBCXX_USE_CXX11_ABI`` macro, returning ``'abi' + <value>``.
    Returns ``'abiUnknown'`` when gcc cannot be executed or the macro is not
    present in the preprocessor output (previously this produced a bare
    ``'abi'``).
    """
    try:
        # argv list + stdin input: no shell, no dependency on echo/fgrep.
        result = subprocess.run(
            ['gcc', '-x', 'c++', '-E', '-dM', '-'],
            input='#include <string>\n',
            capture_output=True,
            text=True,
        )
    except OSError:
        # gcc is missing or not executable.
        return 'abiUnknown'
    for line in result.stdout.splitlines():
        fields = line.split()
        if len(fields) == 3 and fields[:2] == ['#define', '_GLIBCXX_USE_CXX11_ABI']:
            return 'abi' + fields[2]
    # Macro not defined, or preprocessing failed (empty stdout).
    return 'abiUnknown'

# Best-effort lookup of the commit being built; any failure (git missing,
# not a checkout, ...) leaves sha at its placeholder value.
try:
    sha = (
        subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd)
        .decode('ascii')
        .strip()
    )
except Exception:
    pass

# A non-empty BUILD_VERSION replaces the version read from version.txt.
if os.getenv('BUILD_VERSION'):
    version = os.getenv('BUILD_VERSION')

# Grow dcu_version into a local version label shaped like
#   <base>[+git<sha7>].<abi tag>[.dtk<rocm>].torch<torch>
# where the git part needs a known commit and the dtk part needs ROCM_PATH.
dcu_version += ('+git' + sha[:7] if sha != 'Unknown' else '') + '.' + get_abi()
if os.getenv("ROCM_PATH"):
    rocm_path = os.getenv('ROCM_PATH', "")
    rocm_version_path = os.path.join(rocm_path, '.info', "version-dev")
    with open(rocm_version_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    # NOTE(review): drops the last two characters of the first line, then all
    # dots; presumably strips the trailing newline plus one more character of
    # the DTK version string — confirm against the version-dev file format.
    rocm_version = lines[0][:-2].replace(".", "")
    dcu_version += ".dtk" + rocm_version
# NOTE(review): [:-2] drops the last two characters of torch.__version__
# (e.g. '2.1.0' -> '2.1'); assumes a single-digit patch and no local
# version suffix — confirm.
dcu_version += ".torch" + torch.__version__[:-2]

def write_version_file():
    """Generate ``torchvision/version.py`` from the module-level metadata.

    The generated module defines ``__version__`` (from ``version``),
    ``__dcu_version__`` (from ``dcu_version``) and ``git_version`` (from
    ``sha``). It also sets ``cuda = _check_cuda_version()`` when
    ``torchvision.extension._check_cuda_version()`` returns a positive value.
    """
    version_path = os.path.join(cwd, 'torchvision', 'version.py')
    # repr() always yields a valid Python string literal, even when a value
    # (e.g. taken from BUILD_VERSION) contains quotes or backslashes; UTF-8
    # matches the default encoding Python uses to read source files.
    with open(version_path, 'w', encoding='utf-8') as f:
        f.write("__version__ = {}\n".format(repr(version)))
        f.write("__dcu_version__ = {}\n".format(repr(dcu_version)))
        f.write("git_version = {}\n".format(repr(sha)))
        f.write("from torchvision.extension import _check_cuda_version\n")
        f.write("if _check_cuda_version() > 0:\n")
        f.write("    cuda = _check_cuda_version()\n")


# torch is pinned to an exact release when PYTORCH_VERSION is provided.
pytorch_dep = 'torch'
if os.getenv('PYTORCH_VERSION'):
    pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')

# Keep an already-installed pillow-simd instead of requiring plain pillow.
pillow_ver = ' >= 5.3.0'
pillow_req = 'pillow' if get_dist('pillow-simd') is None else 'pillow-simd'

requirements = ['numpy', pytorch_dep, pillow_req + pillow_ver]

def find_library(name, vision_include):
    """Search for the header ``<name>.h`` of an optional native library.

    Search order:
      1. every directory in ``vision_include``, then the ``torchvision``
         directory next to this file;
      2. if ``BUILD_PREFIX`` is set (conda-build): ``$BUILD_PREFIX/include``
         (``$BUILD_PREFIX/Library/include`` on Windows); otherwise, if a
         ``conda`` executable is on PATH, the include directory of the
         environment that owns the running interpreter;
      3. on Linux only: ``/usr/include`` and ``/usr/local/include``.

    Returns:
        ``(library_found, conda_installed, include_folder, lib_folder)``.
        ``conda_installed`` is True only when the header was found in step 2.
        The two folders are the include/lib directories probed in step 2, or
        ``None`` when step 2 did not probe any location.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    build_prefix = os.environ.get('BUILD_PREFIX', None)
    is_conda_build = build_prefix is not None
    header = '{0}.h'.format(name)

    include_folder = None
    lib_folder = None
    conda_installed = False

    # Step 1: user-supplied include dirs, then the package directory.
    candidates = vision_include + [os.path.join(here, 'torchvision')]
    library_found = any(
        os.path.exists(os.path.join(folder, header)) for folder in candidates
    )
    if library_found:
        return library_found, conda_installed, include_folder, lib_folder

    # Step 2: conda-build prefix or the active conda environment.
    print('Running build on conda-build: {0}'.format(is_conda_build))
    if is_conda_build:
        prefix = (os.path.join(build_prefix, 'Library') if os.name == 'nt'
                  else build_prefix)
        include_folder = os.path.join(prefix, 'include')
        lib_folder = os.path.join(prefix, 'lib')
        library_found = os.path.isfile(os.path.join(include_folder, header))
        conda_installed = library_found
    else:
        # Check if using Anaconda to produce wheels
        is_conda = distutils.spawn.find_executable('conda') is not None
        print('Running build on conda: {0}'.format(is_conda))
        if is_conda:
            py_folder = os.path.dirname(sys.executable)
            env_path = (os.path.join(py_folder, 'Library') if os.name == 'nt'
                        else os.path.dirname(py_folder))
            lib_folder = os.path.join(env_path, 'lib')
            include_folder = os.path.join(env_path, 'include')
            library_found = os.path.isfile(os.path.join(include_folder, header))
            conda_installed = library_found

    # Step 3: conventional system include directories on Linux.
    if not library_found and sys.platform == 'linux':
        library_found = any(
            os.path.exists(os.path.join(root, header))
            for root in ('/usr/include', '/usr/local/include')
        )

    return library_found, conda_installed, include_folder, lib_folder


def get_extensions():
    """Collect the C++/CUDA (or HIP) extension modules to build.

    Always builds ``torchvision._C`` from ``torchvision/csrc``. Optionally adds:

    * ``torchvision._C_tests`` when ``WITH_CPP_MODELS_TEST=1``;
    * ``torchvision.image`` when libpng (>= 1.6.0 via ``libpng-config``, or a
      ``pngfix`` executable on Windows) and/or the libjpeg header is found;
    * ``torchvision.video_reader`` when an ``ffmpeg`` executable is found and
      the headers of all required FFmpeg libraries are located.

    ``CUDAExtension`` is used when torch reports a usable CUDA/ROCm setup, or
    when ``FORCE_CUDA=1``. Other environment variables read here:
    ``NVCC_FLAGS``, ``DEBUG``, ``TORCHVISION_INCLUDE`` and
    ``TORCHVISION_LIBRARY`` (the last two are ``os.pathsep``-separated lists).

    Returns:
        List of extension objects for ``setup(ext_modules=...)``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

    main_file = (glob.glob(os.path.join(extensions_dir, '*.cpp')) +
                 glob.glob(os.path.join(extensions_dir, 'ops', '*.cpp')))
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) +
        glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
        glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp'))
    )

    # ROCM_HOME is only exported by newer torch releases; an ImportError just
    # means "not a ROCm build". This avoids comparing version strings.
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        ROCM_HOME = None
    is_rocm_pytorch = (getattr(torch.version, 'hip', None) is not None
                       and ROCM_HOME is not None)

    if is_rocm_pytorch:
        # Translate the CUDA kernels into HIP sources under csrc/ops/hip.
        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/ops/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'hip', '*.hip'))
        # Copy over additional headers. Paths are anchored at this file's
        # directory rather than the current working directory, so the build
        # also works when setup.py is invoked from elsewhere.
        hip_dir = os.path.join(extensions_dir, 'ops', 'hip')
        os.makedirs(hip_dir, exist_ok=True)
        for header in glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.h')):
            shutil.copy(header, hip_dir)
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu'))

    source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp'))

    # this_dir is absolute, so every glob() result above is already an
    # absolute path and needs no further joining.
    sources = main_file + source_cpu
    extension = CppExtension

    compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, 'test')
        models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
        tests = (glob.glob(os.path.join(test_dir, '*.cpp')) +
                 glob.glob(os.path.join(models_dir, '*.cpp')))
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []
    extra_compile_args = {'cxx': []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) \
            or os.getenv('FORCE_CUDA', '0') == '1':
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [('WITH_CUDA', None)]
            nvcc_flags = os.getenv('NVCC_FLAGS', '')
            nvcc_flags = nvcc_flags.split(' ') if nvcc_flags else []
        else:
            define_macros += [('WITH_HIP', None)]
            nvcc_flags = []
        extra_compile_args["nvcc"] = nvcc_flags

    if sys.platform == 'win32':
        define_macros += [('torchvision_EXPORTS', None)]
        extra_compile_args['cxx'].append('/MP')

    debug_mode = os.getenv('DEBUG', '0') == '1'
    if debug_mode:
        print("Compile in debug mode")
        extra_compile_args['cxx'].append("-g")
        extra_compile_args['cxx'].append("-O0")
        if "nvcc" in extra_compile_args:
            # Replace any optimisation level ("-O<n>") and "-g" with our own.
            # Match whole flags: a substring test for "-g" would also discard
            # unrelated flags such as "-gencode=..." or "--generate-line-info".
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [
                f for f in nvcc_flags if not (f.startswith("-O") or f == "-g")
            ]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            'torchvision._C',
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    if compile_cpp_tests:
        ext_modules.append(
            extension(
                'torchvision._C_tests',
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # ------------------- Torchvision extra extensions ------------------------
    vision_include = os.environ.get('TORCHVISION_INCLUDE', None)
    vision_library = os.environ.get('TORCHVISION_LIBRARY', None)
    vision_include = (vision_include.split(os.pathsep)
                      if vision_include is not None else [])
    vision_library = (vision_library.split(os.pathsep)
                      if vision_library is not None else [])
    include_dirs += vision_include
    library_dirs = vision_library

    # Image reading extension
    image_macros = []
    image_include = [extensions_dir]
    image_library = []
    image_link_flags = []

    # Locating libPNG
    libpng = distutils.spawn.find_executable('libpng-config')
    pngfix = distutils.spawn.find_executable('pngfix')
    png_found = libpng is not None or pngfix is not None
    print('PNG found: {0}'.format(png_found))
    if png_found:
        if libpng is not None:
            # Linux / Mac
            png_version = subprocess.run([libpng, '--version'],
                                         stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode('utf-8')
            print('libpng version: {0}'.format(png_version))
            png_version = parse_version(png_version)
            if png_version >= parse_version("1.6.0"):
                print('Building torchvision with PNG image support')
                png_lib = subprocess.run([libpng, '--libdir'],
                                         stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode('utf-8')
                if 'disabled' not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, '--I_opts'],
                                             stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode('utf-8')
                # --I_opts prints "-I<dir>" flags; keep every directory rather
                # than assuming exactly one (which crashed on 0 or 2+ flags).
                png_include = [p.strip() for p in png_include.split('-I') if p.strip()]
                print('libpng include path: {0}'.format(png_include))
                image_include += png_include
                image_link_flags.append('png')
            else:
                print('libpng installed version is less than 1.6.0, '
                      'disabling PNG support')
                png_found = False
        else:
            # Windows
            png_lib = os.path.join(
                os.path.dirname(os.path.dirname(pngfix)), 'lib')
            png_include = os.path.join(os.path.dirname(
                os.path.dirname(pngfix)), 'include', 'libpng16')
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append('libpng')

    # Locating libjpeg
    (jpeg_found, jpeg_conda,
     jpeg_include, jpeg_lib) = find_library('jpeglib', vision_include)

    print('JPEG found: {0}'.format(jpeg_found))
    image_macros += [('PNG_FOUND', str(int(png_found)))]
    image_macros += [('JPEG_FOUND', str(int(jpeg_found)))]
    if jpeg_found:
        print('Building torchvision with JPEG image support')
        image_link_flags.append('jpeg')
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension and
        CUDA_HOME is not None and
        os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
    )

    print('NVJPEG found: {0}'.format(nvjpeg_found))
    image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
    if nvjpeg_found:
        print('Building torchvision with NVJPEG image support')
        image_link_flags.append('nvjpeg')

    image_path = os.path.join(extensions_dir, 'io', 'image')
    image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
                 + glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))

    if png_found or jpeg_found:
        ext_modules.append(extension(
            'torchvision.image',
            image_src,
            include_dirs=image_include + include_dirs + [image_path],
            library_dirs=image_library + library_dirs,
            define_macros=image_macros,
            libraries=image_link_flags,
            extra_compile_args=extra_compile_args
        ))

    ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
    has_ffmpeg = ffmpeg_exe is not None
    print("FFmpeg found: {}".format(has_ffmpeg))

    if has_ffmpeg:
        ffmpeg_libraries = {
            'libavcodec',
            'libavformat',
            'libavutil',
            'libswresample',
            'libswscale'
        }

        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include')
        ffmpeg_library_dir = os.path.join(ffmpeg_root, 'lib')

        # Ask gcc for its multiarch directory name; treat it as empty instead
        # of crashing (subprocess.run([None, ...])) when gcc is not on PATH.
        gcc = distutils.spawn.find_executable('gcc')
        platform_tag = ''
        if gcc is not None:
            platform_tag = subprocess.run(
                [gcc, '-print-multiarch'], stdout=subprocess.PIPE
            ).stdout.strip().decode('utf-8')

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [
                ffmpeg_include_dir,
                os.path.join(ffmpeg_include_dir, platform_tag)
            ]
            ffmpeg_library_dir = [
                ffmpeg_library_dir,
                os.path.join(ffmpeg_library_dir, platform_tag)
            ]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        # Every required FFmpeg library needs headers in some search path.
        for library in ffmpeg_libraries:
            library_found = any(
                glob.glob(os.path.join(search_path, library, '*.h'))
                for search_path in ffmpeg_include_dir + include_dirs
            )
            if not library_found:
                print(f'{library} header files were not found, disabling ffmpeg support')
                has_ffmpeg = False

    if has_ffmpeg:
        print("ffmpeg include path: {}".format(ffmpeg_include_dir))
        print("ffmpeg library_dir: {}".format(ffmpeg_library_dir))

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video_reader')
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'decoder')
        base_decoder_src = glob.glob(
            os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video')
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                'torchvision.video_reader',
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
                    video_reader_src_dir,
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    'avcodec',
                    'avformat',
                    'avutil',
                    'swresample',
                    'swscale',
                ],
                extra_compile_args=["-std=c++14"] if os.name != 'nt' else ['/std:c++14', '/MP'],
                extra_link_args=["-std=c++14" if os.name != 'nt' else '/std:c++14'],
            )
        )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes paths matched by ``.gitignore``.

    ``.gitignore`` is opened and its patterns are globbed relative to the
    current working directory (assumes setup.py is run from the repository
    root).
    """

    def run(self):
        with open('.gitignore', 'r') as f:
            ignores = f.read()
        for wildcard in ignores.split('\n'):
            wildcard = wildcard.strip()
            # Blank lines and comments match nothing; "!" patterns mark paths
            # to keep, so they must never be treated as deletion targets.
            if not wildcard or wildcard.startswith(('#', '!')):
                continue
            # A leading "/" anchors a gitignore pattern to the repository
            # root. Strip it so glob() stays relative instead of matching at
            # the filesystem root (a bare "/" would otherwise rmtree('/')).
            wildcard = wildcard.lstrip('/')
            if not wildcard:
                continue
            for filename in glob.glob(wildcard):
                try:
                    os.remove(filename)
                except OSError:
                    # Directories (and anything else os.remove rejects).
                    shutil.rmtree(filename, ignore_errors=True)

        super().run()


if __name__ == "__main__":
    # Report the version actually passed to setup() below (dcu_version), not
    # the bare base version.
    print("Building wheel {}-{}".format(package_name, dcu_version))

    write_version_file()

    # read() decodes README.rst as UTF-8 and resolves it against this file's
    # directory, so the result does not depend on the build locale.
    readme = read('README.rst')

    setup(
        # Metadata
        name=package_name,
        version=dcu_version,
        author='PyTorch Core Team',
        author_email='soumith@pytorch.org',
        url='https://github.com/pytorch/vision',
        description='image and video datasets and models for torch deep learning',
        long_description=readme,
        license='BSD',

        # Package info
        packages=find_packages(exclude=('test',)),
        package_data={
            package_name: ['*.dll', '*.dylib', '*.so']
        },
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
        cmdclass={
            'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
            'clean': clean,
        }
    )