setup.py
import distutils.command.clean
import distutils.spawn
import glob
import os
import shutil
import subprocess
import sys

import torch
from pkg_resources import parse_version, get_distribution, DistributionNotFound
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
    with open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp:
        return fp.read()


def get_dist(pkgname):
    try:
        return get_distribution(pkgname)
    except DistributionNotFound:
        return None


cwd = os.path.dirname(os.path.abspath(__file__))

version_txt = os.path.join(cwd, "version.txt")
with open(version_txt) as f:
    version = f.readline().strip()

sha = "Unknown"
package_name = "torchvision"

try:
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
    pass

if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version += "+" + sha[:7]


def write_version_file():
    version_path = os.path.join(cwd, "torchvision", "version.py")
    with open(version_path, "w") as f:
        f.write(f"__version__ = '{version}'\n")
        f.write(f"git_version = {repr(sha)}\n")
        f.write("from torchvision.extension import _check_cuda_version\n")
        f.write("if _check_cuda_version() > 0:\n")
        f.write("    cuda = _check_cuda_version()\n")


pytorch_dep = "torch"
if os.getenv("PYTORCH_VERSION"):
    pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
soumith's avatar
soumith committed
59

60
requirements = [
61
    "typing_extensions",
62
    "numpy",
63
    "requests",
64
    pytorch_dep,
65
66
]

67
68
# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
pillow_ver = " >= 5.3.0, !=8.3.*"
69
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
70
71
requirements.append(pillow_req + pillow_ver)


def find_library(name, vision_include):
    this_dir = os.path.dirname(os.path.abspath(__file__))
    build_prefix = os.environ.get("BUILD_PREFIX", None)
    is_conda_build = build_prefix is not None

    library_found = False
    conda_installed = False
    lib_folder = None
    include_folder = None
    library_header = f"{name}.h"

    # Lookup in TORCHVISION_INCLUDE or in the package file
    package_path = [os.path.join(this_dir, "torchvision")]
    for folder in vision_include + package_path:
        candidate_path = os.path.join(folder, library_header)
        library_found = os.path.exists(candidate_path)
        if library_found:
            break

    if not library_found:
        print(f"Running build on conda-build: {is_conda_build}")
        if is_conda_build:
            # Add conda headers/libraries
            if os.name == "nt":
                build_prefix = os.path.join(build_prefix, "Library")
            include_folder = os.path.join(build_prefix, "include")
            lib_folder = os.path.join(build_prefix, "lib")
            library_header_path = os.path.join(include_folder, library_header)
            library_found = os.path.isfile(library_header_path)
            conda_installed = library_found
        else:
            # Check if using Anaconda to produce wheels
            conda = shutil.which("conda")
            is_conda = conda is not None
            print(f"Running build on conda: {is_conda}")
            if is_conda:
                python_executable = sys.executable
                py_folder = os.path.dirname(python_executable)
                if os.name == "nt":
                    env_path = os.path.join(py_folder, "Library")
                else:
                    env_path = os.path.dirname(py_folder)
                lib_folder = os.path.join(env_path, "lib")
                include_folder = os.path.join(env_path, "include")
                library_header_path = os.path.join(include_folder, library_header)
                library_found = os.path.isfile(library_header_path)
                conda_installed = library_found

        if not library_found:
            if sys.platform == "linux":
                library_found = os.path.exists(f"/usr/include/{library_header}")
                library_found = library_found or os.path.exists(f"/usr/local/include/{library_header}")

    return library_found, conda_installed, include_folder, lib_folder
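# find_library() is used further down to locate libjpeg: it reports whether the header was
# found, whether it came out of a conda environment, and, in the conda case, which
# include/lib folders need to be added when building the image extension.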


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "ops", "*.cpp")
    )
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
    )

    is_rocm_pytorch = False

    if torch.__version__ >= "1.5":
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        from torch.utils.hipify import hipify_python

        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/ops/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip"))
        # Copy over additional files
        for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"):
            shutil.copy(file, "torchvision/csrc/ops/hip")
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))

    sources = main_file + source_cpu
    extension = CppExtension
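    # `sources` now holds the C++ operator sources (plain, autograd, cpu and quantized);
    # `extension` defaults to CppExtension and is switched to CUDAExtension below when a
    # CUDA or ROCm toolchain is usable (or when FORCE_CUDA is set).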

    compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1"
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, "test")
        models_dir = os.path.join(this_dir, "torchvision", "csrc", "models")
        test_file = glob.glob(os.path.join(test_dir, "*.cpp"))
        source_models = glob.glob(os.path.join(models_dir, "*.cpp"))

        test_file = [os.path.join(test_dir, s) for s in test_file]
        source_models = [os.path.join(models_dir, s) for s in source_models]
        tests = test_file + source_models
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []

    extra_compile_args = {"cxx": []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1":
189
190
        extension = CUDAExtension
        sources += source_cuda
191
        if not is_rocm_pytorch:
192
193
194
            define_macros += [("WITH_CUDA", None)]
            nvcc_flags = os.getenv("NVCC_FLAGS", "")
            if nvcc_flags == "":
195
196
                nvcc_flags = []
            else:
197
                nvcc_flags = nvcc_flags.split(" ")
Soumith Chintala's avatar
Soumith Chintala committed
198
        else:
199
            define_macros += [("WITH_HIP", None)]
200
            nvcc_flags = []
201
        extra_compile_args["nvcc"] = nvcc_flags

    if sys.platform == "win32":
        define_macros += [("torchvision_EXPORTS", None)]
205
        define_macros += [("USE_PYTHON", None)]
206
        extra_compile_args["cxx"].append("/MP")

    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compile in debug mode")
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-O0")
        if "nvcc" in extra_compile_args:
            # remove any existing "-O<n>" and "-g" flags before appending the debug ones
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "torchvision._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    if compile_cpp_tests:
        ext_modules.append(
            extension(
                "torchvision._C_tests",
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # ------------------- Torchvision extra extensions ------------------------
    vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
    vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
    vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
    vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
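    # Example (illustrative): both variables take os.pathsep-separated lists, e.g. on Linux:
    #   TORCHVISION_INCLUDE=/opt/nvcodec/include TORCHVISION_LIBRARY=/opt/nvcodec/lib python setup.py install
    # (the /opt/nvcodec paths are made up for the example).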
    include_dirs += vision_include
    library_dirs = vision_library

    # Image reading extension
    image_macros = []
    image_include = [extensions_dir]
    image_library = []
    image_link_flags = []

    if sys.platform == "win32":
        image_macros += [("USE_PYTHON", None)]

    # Locating libPNG
    libpng = shutil.which("libpng-config")
    pngfix = shutil.which("pngfix")
    png_found = libpng is not None or pngfix is not None
    print(f"PNG found: {png_found}")
    if png_found:
        if libpng is not None:
            # Linux / Mac
            png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode("utf-8")
            print(f"libpng version: {png_version}")
            png_version = parse_version(png_version)
            if png_version >= parse_version("1.6.0"):
                print("Building torchvision with PNG image support")
                png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode("utf-8")
                if "disabled" not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode("utf-8")
                _, png_include = png_include.split("-I")
                print(f"libpng include path: {png_include}")
                image_include += [png_include]
                image_link_flags.append("png")
            else:
                print("libpng installed version is less than 1.6.0, disabling PNG support")
                png_found = False
        else:
            # Windows
            png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
            png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append("libpng")

    # Locating libjpeg
    (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)

    print(f"JPEG found: {jpeg_found}")
    image_macros += [("PNG_FOUND", str(int(png_found)))]
    image_macros += [("JPEG_FOUND", str(int(jpeg_found)))]
    if jpeg_found:
        print("Building torchvision with JPEG image support")
        image_link_flags.append("jpeg")
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
    )

    print(f"NVJPEG found: {nvjpeg_found}")
    image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))]
    if nvjpeg_found:
        print("Building torchvision with NVJPEG image support")
        image_link_flags.append("nvjpeg")

    image_path = os.path.join(extensions_dir, "io", "image")
    image_src = (
        glob.glob(os.path.join(image_path, "*.cpp"))
        + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
        + glob.glob(os.path.join(image_path, "cuda", "*.cpp"))
    )

    if png_found or jpeg_found:
        ext_modules.append(
            extension(
                "torchvision.image",
                image_src,
                include_dirs=image_include + include_dirs + [image_path],
                library_dirs=image_library + library_dirs,
                define_macros=image_macros,
                libraries=image_link_flags,
                extra_compile_args=extra_compile_args,
            )
        )

    ffmpeg_exe = shutil.which("ffmpeg")
344
    has_ffmpeg = ffmpeg_exe is not None
345
346
347
348
    # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
    # FIXME: causes crash. See the following GitHub issues for more details.
    # FIXME: https://github.com/pytorch/pytorch/issues/65000
    # FIXME: https://github.com/pytorch/vision/issues/3367
349
    if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
350
        has_ffmpeg = False
351
352
    if has_ffmpeg:
        try:
353
354
355
            # This is to check if ffmpeg is installed properly.
            subprocess.check_output(["ffmpeg", "-version"])
        except subprocess.CalledProcessError:
356
            print("Error fetching ffmpeg version, ignoring ffmpeg.")
357
358
            has_ffmpeg = False

359
360
    use_ffmpeg = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
    has_ffmpeg = has_ffmpeg and use_ffmpeg
361
    print(f"FFmpeg found: {has_ffmpeg}")

    if has_ffmpeg:
        ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
        ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")

        gcc = os.environ.get("CC", shutil.which("gcc"))
        platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
        platform_tag = platform_tag.stdout.strip().decode("utf-8")

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
            ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        has_ffmpeg = True
        for library in ffmpeg_libraries:
            library_found = False
            for search_path in ffmpeg_include_dir + include_dirs:
                full_path = os.path.join(search_path, library, "*.h")
                library_found |= len(glob.glob(full_path)) > 0

            if not library_found:
                print(f"{library} header files were not found, disabling ffmpeg support")
                has_ffmpeg = False

    if has_ffmpeg:
        print(f"ffmpeg include path: {ffmpeg_include_dir}")
        print(f"ffmpeg library_dir: {ffmpeg_library_dir}")

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
        base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                "torchvision.video_reader",
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
                    video_reader_src_dir,
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs,
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                ],
                extra_compile_args=["-std=c++14"] if os.name != "nt" else ["/std:c++14", "/MP"],
                extra_link_args=["-std=c++14" if os.name != "nt" else "/std:c++14"],
            )
        )

    # Locating video codec
    # CUDA_HOME should be set to the cuda root directory.
    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
    # video codec header files and libraries respectively.
    video_codec_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs])
    )

    print(f"video codec found: {video_codec_found}")

    if (
        video_codec_found
        and has_ffmpeg
        and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
    ):
        gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
        gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
        cuda_libs = os.path.join(CUDA_HOME, "lib64")
        cuda_inc = os.path.join(CUDA_HOME, "include")

        ext_modules.append(
            extension(
                "torchvision.Decoder",
                gpu_decoder_src,
                include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
                library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                    "nvcuvid",
                    "cuda",
                    "cudart",
                    "z",
                    "pthread",
                    "dl",
                    "nppicc",
                ],
                extra_compile_args=extra_compile_args,
            )
        )
    else:
        print(
            "The installed version of ffmpeg is missing the header file 'bsf.h' which is "
            "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
            " `conda install -c conda-forge ffmpeg`."
        )
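    # To enable the GPU decoder, the build needs CUDA, a working ffmpeg, and the video codec
    # headers/libraries (cuviddec.h, nvcuvid.h, libnvcuvid.so, typically from the NVIDIA Video
    # Codec SDK) made visible through TORCHVISION_INCLUDE / TORCHVISION_LIBRARY as checked above.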

    return ext_modules


class clean(distutils.command.clean.clean):
    def run(self):
        with open(".gitignore") as f:
            ignores = f.read()
            for wildcard in filter(None, ignores.split("\n")):
                for filename in glob.glob(wildcard):
                    try:
                        os.remove(filename)
                    except OSError:
                        shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)
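# Example (illustrative): `python setup.py clean` removes every path matching a .gitignore
# pattern before delegating to the stock distutils clean command.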


if __name__ == "__main__":
509
    print(f"Building wheel {package_name}-{version}")
510
511
512

    write_version_file()

513
    with open("README.rst") as f:
514
515
516
517
518
519
        readme = f.read()

    setup(
        # Metadata
        name=package_name,
        version=version,
520
521
522
523
        author="PyTorch Core Team",
        author_email="soumith@pytorch.org",
        url="https://github.com/pytorch/vision",
        description="image and video datasets and models for torch deep learning",
524
        long_description=readme,
525
        license="BSD",
526
        # Package info
527
        packages=find_packages(exclude=("test",)),
528
        package_data={package_name: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]},
529
530
531
532
533
534
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
535
        python_requires=">=3.7",
536
        cmdclass={
537
538
539
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
540
    )