setup.py 20.2 KB
Newer Older
1
import distutils.command.clean
2
import distutils.spawn
3
import glob
4
import os
5
import shutil
6
7
import subprocess
import sys
8
9

import torch
10
11
from pkg_resources import parse_version, get_distribution, DistributionNotFound
from setuptools import setup, find_packages
12
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
soumith's avatar
soumith committed
13
14


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
15
def read(*names, **kwargs):
    """Return the contents of a text file addressed relative to this script's directory.

    ``names`` are path components joined onto the directory containing this file; an
    absolute component replaces everything before it (``os.path.join`` semantics).
    The text is decoded with ``kwargs["encoding"]``, defaulting to UTF-8.
    """
    base_dir = os.path.dirname(__file__)
    path = os.path.join(base_dir, *names)
    encoding = kwargs.get("encoding", "utf8")
    with open(path, encoding=encoding) as handle:
        contents = handle.read()
    return contents

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
19

20
21
22
23
24
25
26
def get_dist(pkgname):
    """Look up an installed distribution by name.

    Returns the ``pkg_resources`` distribution for *pkgname*, or ``None`` when
    ``get_distribution`` raises ``DistributionNotFound``.
    """
    try:
        dist = get_distribution(pkgname)
    except DistributionNotFound:
        dist = None
    return dist


27
28
cwd = os.path.dirname(os.path.abspath(__file__))

# The base release number is the first line of version.txt, which sits next to this script.
version_txt = os.path.join(cwd, "version.txt")
with open(version_txt) as f:
    version = f.readline().strip()

package_name = "torchvision"

# Best effort: look up the current git commit. It stays "Unknown" if the git
# command fails for any reason (e.g. not a checkout, git not installed).
sha = "Unknown"
try:
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
    pass

# A non-empty BUILD_VERSION overrides the version outright; otherwise a known
# commit is appended as a local version label, e.g. "<base>+abcdef0".
if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version = f"{version}+{sha[:7]}"
44
45
46


def write_version_file():
    """Write ``torchvision/version.py`` recording this build's version metadata.

    The generated module defines ``__version__`` and ``git_version`` from this
    script's module-level ``version`` and ``sha``. It also defines ``cuda``
    (the value of ``torchvision.extension._check_cuda_version()``) when that
    value is positive. Any existing file is overwritten.
    """
    generated = (
        f"__version__ = '{version}'\n"
        f"git_version = {sha!r}\n"
        "from torchvision.extension import _check_cuda_version\n"
        "if _check_cuda_version() > 0:\n"
        "    cuda = _check_cuda_version()\n"
    )
    target = os.path.join(cwd, "torchvision", "version.py")
    with open(target, "w") as out:
        out.write(generated)
54
55


56
57
58
# Pin torch to an exact release when PYTORCH_VERSION is set; otherwise torch is left unpinned.
pytorch_dep = "torch"
if os.getenv("PYTORCH_VERSION"):
    pytorch_dep = f"{pytorch_dep}=={os.getenv('PYTORCH_VERSION')}"

requirements = ["typing_extensions", "numpy", "requests", pytorch_dep]

# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
pillow_ver = " >= 5.3.0, !=8.3.*"
# Depend on pillow-simd instead of pillow when pillow-simd is already installed.
pillow_req = "pillow" if get_dist("pillow-simd") is None else "pillow-simd"
requirements.append(pillow_req + pillow_ver)

72

73
74
def find_library(name, vision_include):
    """Locate the C header ``<name>.h`` for an optional native dependency.

    Search order:
      1. each folder in ``vision_include``, then ``<this dir>/torchvision``;
      2. a conda environment: ``$BUILD_PREFIX`` when running under conda-build,
         otherwise the environment of the current interpreter if a ``conda``
         executable is on PATH (``Library`` sub-folder on Windows);
      3. on Linux only, ``/usr/include`` and ``/usr/local/include``.

    Args:
        name: header base name without the ``.h`` suffix (e.g. ``"jpeglib"``).
        vision_include: list of extra include directories to check first.

    Returns:
        ``(found, conda_installed, include_folder, lib_folder)``. The two folders
        are only set once step 2 is reached with a conda environment; they are
        ``None`` when the header is found in step 1.
    """
    header = f"{name}.h"
    here = os.path.dirname(os.path.abspath(__file__))
    search_dirs = vision_include + [os.path.join(here, "torchvision")]
    if any(os.path.exists(os.path.join(folder, header)) for folder in search_dirs):
        return True, False, None, None

    prefix = os.environ.get("BUILD_PREFIX", None)
    in_conda_build = prefix is not None
    print(f"Running build on conda-build: {in_conda_build}")

    found = False
    from_conda = False
    include_folder = None
    lib_folder = None

    # Root of the conda environment to probe, if any.
    env_root = None
    if in_conda_build:
        env_root = os.path.join(prefix, "Library") if os.name == "nt" else prefix
    else:
        # Check if using Anaconda to produce wheels
        in_conda = shutil.which("conda") is not None
        print(f"Running build on conda: {in_conda}")
        if in_conda:
            py_dir = os.path.dirname(sys.executable)
            env_root = os.path.join(py_dir, "Library") if os.name == "nt" else os.path.dirname(py_dir)

    if env_root is not None:
        include_folder = os.path.join(env_root, "include")
        lib_folder = os.path.join(env_root, "lib")
        found = os.path.isfile(os.path.join(include_folder, header))
        from_conda = found

    if not found and sys.platform == "linux":
        found = os.path.exists(f"/usr/include/{header}") or os.path.exists(f"/usr/local/include/{header}")

    return found, from_conda, include_folder, lib_folder


129
130
def get_extensions():
    """Assemble the list of native extension modules to build for torchvision.

    Always builds ``torchvision._C`` (ops). It becomes a ``CUDAExtension`` when
    CUDA/ROCm is available or ``FORCE_CUDA=1``, and a ``CppExtension`` otherwise.
    Optionally adds:
      * ``torchvision._C_tests`` when ``WITH_CPP_MODELS_TEST=1``;
      * ``torchvision.image`` when libpng (>= 1.6.0) or libjpeg is found;
      * ``torchvision.video_reader`` when a usable ffmpeg is found (Linux, not Python 3.9);
      * ``torchvision.Decoder`` (GPU video decoding) when the NVIDIA video codec
        headers/library, ffmpeg and ffmpeg's ``libavcodec/bsf.h`` are all available.

    Environment variables read here: WITH_CPP_MODELS_TEST, FORCE_CUDA, NVCC_FLAGS,
    DEBUG, TORCHVISION_INCLUDE, TORCHVISION_LIBRARY, CC (plus those read by
    ``find_library``).

    Returns:
        list of extension objects to pass to ``setup(ext_modules=...)``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "ops", "*.cpp")
    )
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
    )

    is_rocm_pytorch = False
    # Compare parsed versions: a plain string comparison orders "1.10" before "1.5".
    if parse_version(str(torch.__version__)) >= parse_version("1.5"):
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        from torch.utils.hipify import hipify_python

        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/ops/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip"))
        # Copy over additional headers; paths are anchored to the source tree rather
        # than the current working directory.
        hip_dir = os.path.join(extensions_dir, "ops", "hip")
        for header in glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.h")):
            shutil.copy(header, hip_dir)
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))

    # glob() already returns paths prefixed with the (absolute) search directory,
    # so the collected source paths need no further joining.
    sources = main_file + source_cpu
    extension = CppExtension

    compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1"
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, "test")
        models_dir = os.path.join(this_dir, "torchvision", "csrc", "models")
        test_file = glob.glob(os.path.join(test_dir, "*.cpp"))
        source_models = glob.glob(os.path.join(models_dir, "*.cpp"))
        tests = test_file + source_models
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []

    extra_compile_args = {"cxx": []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1":
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            # split() with no separator drops empty fields, so repeated or trailing
            # whitespace in NVCC_FLAGS does not produce empty "" flags.
            nvcc_flags = os.getenv("NVCC_FLAGS", "").split()
        else:
            define_macros += [("WITH_HIP", None)]
            nvcc_flags = []
        extra_compile_args["nvcc"] = nvcc_flags

    if sys.platform == "win32":
        define_macros += [("torchvision_EXPORTS", None)]
        define_macros += [("USE_PYTHON", None)]
        extra_compile_args["cxx"].append("/MP")

    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compile in debug mode")
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-O0")
        if "nvcc" in extra_compile_args:
            # we have to remove "-OX" and "-g" flag if exists and append
            # "-g" must match exactly: a substring test would also drop flags such as
            # "-gencode=...", "--generate-code" or "--gpu-architecture".
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or f == "-g")]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "torchvision._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    if compile_cpp_tests:
        ext_modules.append(
            extension(
                "torchvision._C_tests",
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # ------------------- Torchvision extra extensions ------------------------
    vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
    vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
    vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
    vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
    # In-place extend: torchvision._C above holds this same list object.
    include_dirs += vision_include
    library_dirs = vision_library

    # Image reading extension
    image_macros = []
    image_include = [extensions_dir]
    image_library = []
    image_link_flags = []

    if sys.platform == "win32":
        image_macros += [("USE_PYTHON", None)]

    # Locating libPNG
    libpng = shutil.which("libpng-config")
    pngfix = shutil.which("pngfix")
    png_found = libpng is not None or pngfix is not None
    print(f"PNG found: {png_found}")
    if png_found:
        if libpng is not None:
            # Linux / Mac
            png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode("utf-8")
            print(f"libpng version: {png_version}")
            png_version = parse_version(png_version)
            if png_version >= parse_version("1.6.0"):
                print("Building torchvision with PNG image support")
                png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode("utf-8")
                if "disabled" not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode("utf-8")
                _, png_include = png_include.split("-I")
                print(f"libpng include path: {png_include}")
                image_include += [png_include]
                image_link_flags.append("png")
            else:
                print("libpng installed version is less than 1.6.0, disabling PNG support")
                png_found = False
        else:
            # Windows
            png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
            png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append("libpng")

    # Locating libjpeg
    (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)

    print(f"JPEG found: {jpeg_found}")
    image_macros += [("PNG_FOUND", str(int(png_found)))]
    image_macros += [("JPEG_FOUND", str(int(jpeg_found)))]
    if jpeg_found:
        print("Building torchvision with JPEG image support")
        image_link_flags.append("jpeg")
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
    )

    print(f"NVJPEG found: {nvjpeg_found}")
    image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))]
    if nvjpeg_found:
        print("Building torchvision with NVJPEG image support")
        image_link_flags.append("nvjpeg")

    image_path = os.path.join(extensions_dir, "io", "image")
    image_src = (
        glob.glob(os.path.join(image_path, "*.cpp"))
        + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
        + glob.glob(os.path.join(image_path, "cuda", "*.cpp"))
    )

    if png_found or jpeg_found:
        ext_modules.append(
            extension(
                "torchvision.image",
                image_src,
                include_dirs=image_include + include_dirs + [image_path],
                library_dirs=image_library + library_dirs,
                define_macros=image_macros,
                libraries=image_link_flags,
                extra_compile_args=extra_compile_args,
            )
        )

    ffmpeg_exe = shutil.which("ffmpeg")
    has_ffmpeg = ffmpeg_exe is not None
    # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
    # FIXME: causes crash. See the following GitHub issues for more details.
    # FIXME: https://github.com/pytorch/pytorch/issues/65000
    # FIXME: https://github.com/pytorch/vision/issues/3367
    if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
        has_ffmpeg = False
    if has_ffmpeg:
        try:
            # This is to check if ffmpeg is installed properly.
            subprocess.check_output(["ffmpeg", "-version"])
        except subprocess.CalledProcessError:
            print("Error fetching ffmpeg version, ignoring ffmpeg.")
            has_ffmpeg = False

    print(f"FFmpeg found: {has_ffmpeg}")

    if has_ffmpeg:
        ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
        ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")

        # Ask the C compiler for its multiarch triplet; skip the probe when neither
        # $CC nor gcc is available instead of calling subprocess.run([None, ...]).
        gcc = os.environ.get("CC", shutil.which("gcc"))
        platform_tag = ""
        if gcc is not None:
            multiarch = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
            platform_tag = multiarch.stdout.strip().decode("utf-8")

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
            ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        has_ffmpeg = True
        for library in ffmpeg_libraries:
            library_found = any(
                glob.glob(os.path.join(search_path, library, "*.h"))
                for search_path in ffmpeg_include_dir + include_dirs
            )
            if not library_found:
                print(f"{library} header files were not found, disabling ffmpeg support")
                has_ffmpeg = False

    if has_ffmpeg:
        print(f"ffmpeg include path: {ffmpeg_include_dir}")
        print(f"ffmpeg library_dir: {ffmpeg_library_dir}")

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
        base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                "torchvision.video_reader",
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
                    video_reader_src_dir,
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs,
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                ],
                extra_compile_args=["-std=c++14"] if os.name != "nt" else ["/std:c++14", "/MP"],
                extra_link_args=["-std=c++14" if os.name != "nt" else "/std:c++14"],
            )
        )

    # Locating video codec
    # CUDA_HOME should be set to the cuda root directory.
    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
    # video codec header files and libraries respectively.
    video_codec_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and any(os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include)
        and any(os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include)
        and any(os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs)
    )

    print(f"video codec found: {video_codec_found}")

    if video_codec_found and has_ffmpeg:
        # GPU decoding additionally requires ffmpeg's libavcodec/bsf.h. Only report
        # it missing when that is actually the reason the decoder is skipped.
        has_bsf = any(os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir)
        if has_bsf:
            gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
            gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
            cuda_libs = os.path.join(CUDA_HOME, "lib64")
            cuda_inc = os.path.join(CUDA_HOME, "include")

            ext_modules.append(
                extension(
                    "torchvision.Decoder",
                    gpu_decoder_src,
                    include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
                    library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
                    libraries=[
                        "avcodec",
                        "avformat",
                        "avutil",
                        "swresample",
                        "swscale",
                        "nvcuvid",
                        "cuda",
                        "cudart",
                        "z",
                        "pthread",
                        "dl",
                        "nppicc",
                    ],
                    extra_compile_args=extra_compile_args,
                )
            )
        else:
            print(
                "The installed version of ffmpeg is missing the header file 'bsf.h' which is "
                "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
                " `conda install -c conda-forge ffmpeg`."
            )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``clean`` command that also deletes every path matching a ``.gitignore`` pattern.

    Each non-empty line of ``.gitignore`` (read from the current directory) is
    expanded with ``glob.glob``. Matches are removed with ``os.remove``, falling
    back to a best-effort ``shutil.rmtree`` when that fails. The standard
    distutils ``clean`` then runs.
    """

    def run(self):
        with open(".gitignore") as gitignore:
            patterns = (line for line in gitignore.read().split("\n") if line)
            for pattern in patterns:
                for path in glob.glob(pattern):
                    try:
                        os.remove(path)
                    except OSError:
                        # os.remove failed (e.g. the path is a directory): delete the tree, ignoring errors.
                        shutil.rmtree(path, ignore_errors=True)

        # Defer to the standard distutils clean command.
        super().run()


506
if __name__ == "__main__":
    print(f"Building wheel {package_name}-{version}")

    write_version_file()

    # Use the read() helper so README.rst is resolved next to this script and decoded
    # as UTF-8, rather than with the platform's locale-dependent default encoding.
    readme = read("README.rst")

    setup(
        # Metadata
        name=package_name,
        version=version,
        author="PyTorch Core Team",
        author_email="soumith@pytorch.org",
        url="https://github.com/pytorch/vision",
        description="image and video datasets and models for torch deep learning",
        long_description=readme,
        license="BSD",
        # Package info
        packages=find_packages(exclude=("test",)),
        package_data={package_name: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]},
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
        python_requires=">=3.7",
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
    )