setup.py 22.7 KB
Newer Older
1
import distutils.command.clean
2
import distutils.spawn
3
import glob
4
import os
5
import shutil
6
7
import subprocess
import sys
8
9

import torch
10
11
12
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension
soumith's avatar
soumith committed
13
14


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
15
def read(*names, **kwargs):
    """Return the full text of a file addressed relative to this script.

    ``names`` are path components joined onto the directory containing this
    file (an absolute component overrides that directory, per ``os.path.join``).
    The optional ``encoding`` keyword selects the codec; it defaults to ``"utf8"``.
    """
    codec = kwargs.get("encoding", "utf8")
    target = os.path.join(os.path.dirname(__file__), *names)
    with open(target, encoding=codec) as handle:
        contents = handle.read()
    return contents

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
19

20
21
22
23
24
25
26
def get_dist(pkgname):
    """Return the installed distribution named *pkgname*, or ``None`` if it is not installed."""
    dist = None
    try:
        dist = get_distribution(pkgname)
    except DistributionNotFound:
        pass
    return dist


27
28
# Directory containing this script; version.txt and the generated version.py
# are resolved against it.
cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, "version.txt")

# The base (upstream) version is the first line of version.txt, trimmed.
with open(version_txt) as f:
    version = f.readline().strip()

package_name = "torchvision"
# Placeholder; replaced with the current git commit hash when one is available.
sha = "Unknown"

# Starting point for the DCU wheel version; local-version suffixes are
# appended to it further down in this script.
dcu_version = version

def get_abi():
    """Return the libstdc++ dual-ABI tag reported by the system ``gcc``.

    Preprocesses ``#include <string>`` as C++ and reads the value of the
    ``_GLIBCXX_USE_CXX11_ABI`` macro, returning e.g. ``"abi0"`` or ``"abi1"``.

    Returns ``"abiUnknown"`` when gcc cannot be run or the macro is not defined
    in its output (previously this case produced the bare string ``"abi"``,
    because an empty shell pipeline raised no exception).
    """
    try:
        # Feed the snippet on stdin instead of building a shell pipeline
        # (no shell=True, no dependency on the deprecated `fgrep`).
        result = subprocess.run(
            ["gcc", "-x", "c++", "-E", "-dM", "-"],
            input="#include <string>\n",
            capture_output=True,
            text=True,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # gcc missing / not executable, or undecodable output.
        return "abiUnknown"

    for line in result.stdout.splitlines():
        # Expected form: "#define _GLIBCXX_USE_CXX11_ABI 1"
        parts = line.split()
        if len(parts) == 3 and parts[0] == "#define" and parts[1] == "_GLIBCXX_USE_CXX11_ABI":
            return "abi" + parts[2]
    return "abiUnknown"


48
try:
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
    # Best effort: not a git checkout (e.g. an sdist) or git is unavailable,
    # so `sha` keeps its "Unknown" placeholder.
    pass

# A non-empty BUILD_VERSION environment variable overrides version.txt.
# Unlike upstream torchvision, the git hash is deliberately NOT appended to
# `version`; it is recorded only in `dcu_version` below.
if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")

# DCU local-version suffixes, appended in this order:
#   +git<short sha>   .<abi tag from get_abi()>   .dtk<release>   .torch<release>
if sha != "Unknown":
    dcu_version = f"{dcu_version}+git{sha[:7]}"
dcu_version = f"{dcu_version}.{get_abi()}"

rocm_path = os.getenv("ROCM_PATH")
if rocm_path:
    rocm_version_path = os.path.join(rocm_path, ".info", "rocm_version")
    with open(rocm_version_path, "r", encoding="utf-8") as info_file:
        lines = info_file.readlines()
    # Drops the last two characters of the first line, then every dot.
    # NOTE(review): this presumably expects "X.Y.Z\n" with a one-digit Z
    # (e.g. "24.04.1\n" -> "2404"); a missing trailing newline or a multi-digit
    # patch number would give a different tag -- confirm the DTK file format.
    rocm_version = lines[0][:-2].replace(".", "")
    dcu_version = f"{dcu_version}.dtk{rocm_version}"

# torch version, minus its last two characters.
# NOTE(review): assumes torch.__version__ ends in ".<one digit>" (e.g. "2.1.0"
# -> "2.1"); a local suffix such as "+git..." would be cut mid-string -- confirm.
dcu_version = f"{dcu_version}.torch{torch.__version__[:-2]}"
71
72
73


def write_version_file():
    """Generate ``torchvision/version.py`` under ``cwd`` with build metadata.

    The generated module defines ``__version__``, ``__dcu_version__`` and
    ``git_version`` from the module-level ``version``, ``dcu_version`` and
    ``sha`` values. It also defines ``cuda`` from
    ``torchvision.extension._check_cuda_version()`` when that returns a value
    greater than 0.
    """
    version_path = os.path.join(cwd, "torchvision", "version.py")
    # repr() always yields a valid, escaped Python string literal, even when a
    # version taken from BUILD_VERSION contains quotes or backslashes. For
    # ordinary version strings the output matches the previous '...' quoting.
    lines = [
        f"__version__ = {version!r}\n",
        f"__dcu_version__ = {dcu_version!r}\n",
        f"git_version = {sha!r}\n",
        "from torchvision.extension import _check_cuda_version\n",
        "if _check_cuda_version() > 0:\n",
        "    cuda = _check_cuda_version()\n",
    ]
    # Python source files are UTF-8 by default; don't depend on the locale.
    with open(version_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
82
83


84
85
86
# torch is pinned to an exact release only when PYTORCH_VERSION is non-empty.
_pytorch_pin = os.getenv("PYTORCH_VERSION")
pytorch_dep = f"torch=={_pytorch_pin}" if _pytorch_pin else "torch"

# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
pillow_ver = " >= 5.3.0, !=8.3.*"
# Keep an already-installed pillow-simd instead of pulling in stock pillow.
pillow_req = "pillow" if get_dist("pillow-simd") is None else "pillow-simd"

requirements = [
    "numpy",
    "requests",
    pytorch_dep,
    pillow_req + pillow_ver,
]

99

100
101
def find_library(name, vision_include):
    """Locate the header ``<name>.h`` of an optional native dependency.

    Search order:
      1. each folder in *vision_include*, then the in-tree ``torchvision`` dir;
      2. if not found: the conda-build prefix (``BUILD_PREFIX``) when set,
         otherwise the active conda environment when a ``conda`` executable is
         on PATH (``<prefix>/Library`` is used on Windows);
      3. if still not found and on Linux: ``/usr/include`` and
         ``/usr/local/include``.

    Args:
        name: header base name without the ``.h`` extension.
        vision_include: list of extra include folders to search first.

    Returns:
        Tuple ``(library_found, conda_installed, include_folder, lib_folder)``.
        ``include_folder``/``lib_folder`` are set only when a conda location
        was probed (otherwise ``None``). ``conda_installed`` is True only if the
        header was found in that conda location.
    """
    header = f"{name}.h"
    root = os.path.dirname(os.path.abspath(__file__))
    prefix = os.environ.get("BUILD_PREFIX", None)
    in_conda_build = prefix is not None

    include_dir = None
    lib_dir = None
    from_conda = False

    # 1. Explicit include folders and the package directory itself.
    search_dirs = vision_include + [os.path.join(root, "torchvision")]
    if any(os.path.exists(os.path.join(folder, header)) for folder in search_dirs):
        return True, from_conda, include_dir, lib_dir

    # 2. conda-build prefix or the active conda environment.
    found = False
    print(f"Running build on conda-build: {in_conda_build}")
    if in_conda_build:
        env_root = os.path.join(prefix, "Library") if os.name == "nt" else prefix
        include_dir = os.path.join(env_root, "include")
        lib_dir = os.path.join(env_root, "lib")
        found = os.path.isfile(os.path.join(include_dir, header))
        from_conda = found
    else:
        # Check if using Anaconda to produce wheels
        conda_exe = shutil.which("conda")
        print(f"Running build on conda: {conda_exe is not None}")
        if conda_exe is not None:
            py_dir = os.path.dirname(sys.executable)
            env_root = os.path.join(py_dir, "Library") if os.name == "nt" else os.path.dirname(py_dir)
            lib_dir = os.path.join(env_root, "lib")
            include_dir = os.path.join(env_root, "include")
            found = os.path.isfile(os.path.join(include_dir, header))
            from_conda = found

    # 3. System include directories (Linux only).
    if not found and sys.platform == "linux":
        found = any(os.path.exists(f"{sys_dir}/{header}") for sys_dir in ("/usr/include", "/usr/local/include"))

    return found, from_conda, include_dir, lib_dir


156
157
def get_extensions():
    """Collect the native extension modules to build for torchvision.

    Always returns the core ``torchvision._C`` extension. It compiles the CPU
    operators, plus CUDA/HIP sources when a GPU toolchain is usable or
    ``FORCE_CUDA=1``, or MPS sources when MPS is available or ``FORCE_MPS=1``.
    Depending on which libraries are found and on the ``TORCHVISION_USE_*``
    switches, it may also add:
      * ``torchvision.image`` (libpng / libjpeg / nvjpeg),
      * ``torchvision.video_reader`` (FFmpeg),
      * ``torchvision.Decoder`` (GPU video decoding).

    Environment variables read: FORCE_CUDA, FORCE_MPS, DEBUG, NVCC_FLAGS,
    TORCHVISION_USE_{PNG,JPEG,NVJPEG,FFMPEG,VIDEO_CODEC},
    TORCHVISION_INCLUDE, TORCHVISION_LIBRARY, and CC (only when FFmpeg is used).

    Side effects: on ROCm builds, runs hipify over ``torchvision/csrc/ops/cuda``
    and copies its ``*.h`` headers into ``torchvision/csrc/ops/hip``.

    Returns:
        list: extension objects for ``setup(ext_modules=...)``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "ops", "*.cpp")
    )
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
    )
    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

    print("Compiling extensions with following flags:")
    force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
    print(f"  FORCE_CUDA: {force_cuda}")
    force_mps = os.getenv("FORCE_MPS", "0") == "1"
    print(f"  FORCE_MPS: {force_mps}")
    debug_mode = os.getenv("DEBUG", "0") == "1"
    print(f"  DEBUG: {debug_mode}")
    use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
    print(f"  TORCHVISION_USE_PNG: {use_png}")
    use_jpeg = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
    print(f"  TORCHVISION_USE_JPEG: {use_jpeg}")
    use_nvjpeg = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
    print(f"  TORCHVISION_USE_NVJPEG: {use_nvjpeg}")
    use_ffmpeg = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
    print(f"  TORCHVISION_USE_FFMPEG: {use_ffmpeg}")
    use_video_codec = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
    print(f"  TORCHVISION_USE_VIDEO_CODEC: {use_video_codec}")

    nvcc_flags = os.getenv("NVCC_FLAGS", "")
    print(f"  NVCC_FLAGS: {nvcc_flags}")

    # A HIP build of torch together with a ROCm install means the CUDA kernels
    # must be translated (hipified) before compiling.
    is_rocm_pytorch = False
    if torch.__version__ >= "1.5":
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    cuda_ops_dir = os.path.join(extensions_dir, "ops", "cuda")
    hip_ops_dir = os.path.join(extensions_dir, "ops", "hip")
    if is_rocm_pytorch:
        from torch.utils.hipify import hipify_python

        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/ops/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        # The hipified kernels are picked up from ops/hip.
        source_cuda = glob.glob(os.path.join(hip_ops_dir, "*.hip"))
        # Copy over additional headers. Paths are anchored at this script's
        # directory rather than the current working directory.
        for header in glob.glob(os.path.join(cuda_ops_dir, "*.h")):
            shutil.copy(header, hip_ops_dir)
    else:
        source_cuda = glob.glob(os.path.join(cuda_ops_dir, "*.cu"))

    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))

    sources = main_file + source_cpu
    extension = CppExtension

    define_macros = []

    extra_compile_args = {"cxx": []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or force_cuda:
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            # split() with no separator never yields empty tokens: repeated
            # spaces don't become empty nvcc arguments, and "" gives [].
            nvcc_flags = nvcc_flags.split()
        else:
            define_macros += [("WITH_HIP", None)]
            nvcc_flags = []
        extra_compile_args["nvcc"] = nvcc_flags
    elif torch.backends.mps.is_available() or force_mps:
        sources += source_mps

    if sys.platform == "win32":
        define_macros += [("torchvision_EXPORTS", None)]
        define_macros += [("USE_PYTHON", None)]
        extra_compile_args["cxx"].append("/MP")

    if debug_mode:
        print("Compiling in debug mode")
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-O0")
        if "nvcc" in extra_compile_args:
            # Replace any optimization flags and a standalone "-g" with
            # "-O0 -g". Only the exact "-g" token is dropped: a substring test
            # would also discard "-gencode=...", "--gpu-architecture=..." or
            # "--generate-line-info" and silently change the build targets.
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or f == "-g")]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")
    else:
        print("Compiling with debug mode OFF")
        extra_compile_args["cxx"].append("-g0")

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "torchvision._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    # ------------------- Torchvision extra extensions ------------------------
    vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
    vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
    vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
    vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
    # NOTE(review): `+=` mutates the list object already handed to
    # torchvision._C above, so (unless Extension copied it) _C also receives
    # these include dirs.
    include_dirs += vision_include
    library_dirs = vision_library

    # Image reading extension
    image_macros = []
    image_include = [extensions_dir]
    image_library = []
    image_link_flags = []

    if sys.platform == "win32":
        image_macros += [("USE_PYTHON", None)]

    # Locating libPNG
    libpng = shutil.which("libpng-config")
    pngfix = shutil.which("pngfix")
    png_found = libpng is not None or pngfix is not None

    use_png = use_png and png_found
    if use_png:
        print("Found PNG library")
        if libpng is not None:
            # Linux / Mac
            min_version = "1.6.0"
            png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode("utf-8")
            png_version = parse_version(png_version)
            if png_version >= parse_version(min_version):
                print("Building torchvision with PNG image support")
                png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode("utf-8")
                if "disabled" not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode("utf-8")
                # Expects exactly one "-I<dir>" option from libpng-config.
                _, png_include = png_include.split("-I")
                image_include += [png_include]
                image_link_flags.append("png")
                print(f"  libpng version: {png_version}")
                print(f"  libpng include path: {png_include}")
            else:
                print("Could not add PNG image support to torchvision:")
                print(f"  libpng minimum version {min_version}, found {png_version}")
                use_png = False
        else:
            # Windows: derive lib/include from the location of pngfix.
            png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
            png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append("libpng")
    else:
        print("Building torchvision without PNG image support")
    image_macros += [("PNG_FOUND", str(int(use_png)))]

    # Locating libjpeg
    (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)

    use_jpeg = use_jpeg and jpeg_found
    if use_jpeg:
        print("Building torchvision with JPEG image support")
        print(f"  libjpeg include path: {jpeg_include}")
        print(f"  libjpeg lib path: {jpeg_lib}")
        image_link_flags.append("jpeg")
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]
    else:
        print("Building torchvision without JPEG image support")
    image_macros += [("JPEG_FOUND", str(int(use_jpeg)))]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
    )

    use_nvjpeg = use_nvjpeg and nvjpeg_found
    if use_nvjpeg:
        print("Building torchvision with NVJPEG image support")
        image_link_flags.append("nvjpeg")
    else:
        print("Building torchvision without NVJPEG image support")
    image_macros += [("NVJPEG_FOUND", str(int(use_nvjpeg)))]

    image_path = os.path.join(extensions_dir, "io", "image")
    image_src = glob.glob(os.path.join(image_path, "*.cpp")) + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))

    if is_rocm_pytorch:
        image_src += glob.glob(os.path.join(image_path, "hip", "*.cpp"))
        # we need to exclude this in favor of the hipified source
        generic_image_src = os.path.join(image_path, "image.cpp")
        if generic_image_src in image_src:
            image_src.remove(generic_image_src)
    else:
        image_src += glob.glob(os.path.join(image_path, "cuda", "*.cpp"))

    if use_png or use_jpeg:
        ext_modules.append(
            extension(
                "torchvision.image",
                image_src,
                include_dirs=image_include + include_dirs + [image_path],
                library_dirs=image_library + library_dirs,
                define_macros=image_macros,
                libraries=image_link_flags,
                extra_compile_args=extra_compile_args,
            )
        )

    # Locating ffmpeg
    ffmpeg_exe = shutil.which("ffmpeg")
    has_ffmpeg = ffmpeg_exe is not None
    ffmpeg_version = None
    # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
    # FIXME: causes crash. See the following GitHub issues for more details.
    # FIXME: https://github.com/pytorch/pytorch/issues/65000
    # FIXME: https://github.com/pytorch/vision/issues/3367
    if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
        has_ffmpeg = False
    if has_ffmpeg:
        try:
            # This is to check if ffmpeg is installed properly.
            ffmpeg_version = subprocess.check_output(["ffmpeg", "-version"])
        except subprocess.CalledProcessError:
            print("Building torchvision without ffmpeg support")
            print("  Error fetching ffmpeg version, ignoring ffmpeg.")
            has_ffmpeg = False

    use_ffmpeg = use_ffmpeg and has_ffmpeg

    if use_ffmpeg:
        ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

        # Assume an install layout of <root>/bin/ffmpeg, <root>/include, <root>/lib.
        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
        ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")

        gcc = os.environ.get("CC", shutil.which("gcc"))
        platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
        platform_tag = platform_tag.stdout.strip().decode("utf-8")

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
            ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        for library in ffmpeg_libraries:
            library_found = any(
                glob.glob(os.path.join(search_path, library, "*.h"))
                for search_path in ffmpeg_include_dir + include_dirs
            )
            if not library_found:
                print("Building torchvision without ffmpeg support")
                print(f"  {library} header files were not found, disabling ffmpeg support")
                use_ffmpeg = False
    else:
        print("Building torchvision without ffmpeg support")

    if use_ffmpeg:
        print("Building torchvision with ffmpeg support")
        print(f"  ffmpeg version: {ffmpeg_version}")
        print(f"  ffmpeg include path: {ffmpeg_include_dir}")
        print(f"  ffmpeg library_dir: {ffmpeg_library_dir}")

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
        base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                "torchvision.video_reader",
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
                    video_reader_src_dir,
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs,
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                ],
                extra_compile_args=["-std=c++17"] if os.name != "nt" else ["/std:c++17", "/MP"],
                extra_link_args=["-std=c++17" if os.name != "nt" else "/std:c++17"],
            )
        )

    # Locating video codec
    # CUDA_HOME should be set to the cuda root directory.
    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
    # video codec header files and libraries respectively.
    video_codec_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and any(os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include)
        and any(os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include)
        and any(os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs)
    )

    use_video_codec = use_video_codec and video_codec_found
    # GPU decoding additionally needs FFmpeg's bitstream-filter header. The
    # short-circuit keeps ffmpeg_include_dir untouched when FFmpeg is disabled
    # (it is undefined in that case).
    ffmpeg_has_bsf = (
        use_video_codec
        and use_ffmpeg
        and any(os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir)
    )
    if ffmpeg_has_bsf:
        print("Building torchvision with video codec support")
        gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
        gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
        cuda_libs = os.path.join(CUDA_HOME, "lib64")
        cuda_inc = os.path.join(CUDA_HOME, "include")

        ext_modules.append(
            extension(
                "torchvision.Decoder",
                gpu_decoder_src,
                include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
                library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                    "nvcuvid",
                    "cuda",
                    "cudart",
                    "z",
                    "pthread",
                    "dl",
                    "nppicc",
                ],
                extra_compile_args=extra_compile_args,
            )
        )
    else:
        print("Building torchvision without video codec support")
        # Everything else was available; only bsf.h is missing.
        if use_video_codec and use_ffmpeg:
            print(
                "  The installed version of ffmpeg is missing the header file 'bsf.h' which is "
                "  required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
                "   `conda install -c conda-forge ffmpeg`."
            )

    return ext_modules


class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes paths matched by ``.gitignore``.

    ``.gitignore`` is read from the current working directory. Each pattern
    line is applied with :func:`glob.glob` relative to that directory. Matching
    files are removed, and anything ``os.remove`` rejects (e.g. directories) is
    deleted recursively. The standard distutils ``clean`` command runs
    afterwards.
    """

    def run(self):
        with open(".gitignore") as f:
            ignores = f.read()

        for line in ignores.splitlines():
            # Trailing whitespace (including a stray "\r") is not significant
            # in .gitignore patterns.
            pattern = line.rstrip()
            # Blank lines, comments ("#...") and negations ("!..." re-includes)
            # do not name files that should be deleted.
            if not pattern or pattern.startswith(("#", "!")):
                continue
            # A leading "/" anchors the pattern at the repository root. Strip
            # it so the glob is not evaluated from the filesystem root.
            pattern = pattern.lstrip("/")
            if not pattern:
                continue
            for filename in glob.glob(pattern):
                try:
                    os.remove(filename)
                except OSError:
                    shutil.rmtree(filename, ignore_errors=True)

        super().run()


562
if __name__ == "__main__":
    # The wheel's metadata version is dcu_version (see `version=` below), so
    # report that rather than the base version from version.txt.
    print(f"Building wheel {package_name}-{dcu_version}")

    write_version_file()

    # Read the long description relative to this script and decode it as
    # UTF-8, independent of the current directory and the platform locale.
    readme = read("README.md")

    setup(
        # Metadata
        name=package_name,
        version=dcu_version,
        author="PyTorch Core Team",
        author_email="soumith@pytorch.org",
        url="https://github.com/pytorch/vision",
        description="image and video datasets and models for torch deep learning",
        long_description=readme,
        long_description_content_type="text/markdown",
        license="BSD",
        # Package info
        packages=find_packages(exclude=("test",)),
        package_data={package_name: ["*.dll", "*.dylib", "*.so"]},
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
        python_requires=">=3.8",
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
    )