import distutils.command.clean
import distutils.spawn
import glob
import io
import os
import shutil
import subprocess
import sys

import torch
from pkg_resources import parse_version, get_distribution, DistributionNotFound
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
    with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp:
        return fp.read()


def get_dist(pkgname):
    try:
        return get_distribution(pkgname)
    except DistributionNotFound:
        return None


cwd = os.path.dirname(os.path.abspath(__file__))

version_txt = os.path.join(cwd, "version.txt")
with open(version_txt, "r") as f:
    version = f.readline().strip()
sha = "Unknown"
package_name = "torchvision"

try:
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
    pass

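# BUILD_VERSION, if set, overrides the version read from version.txt; otherwise the
# short git hash is appended as a local version suffix (e.g. "0.1.2+abc1234").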
if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version += "+" + sha[:7]


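# Generate torchvision/version.py so the installed package can report the exact
# version, git revision, and CUDA version it was built with.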
def write_version_file():
    version_path = os.path.join(cwd, "torchvision", "version.py")
    with open(version_path, "w") as f:
        f.write("__version__ = '{}'\n".format(version))
        f.write("git_version = {}\n".format(repr(sha)))
        f.write("from torchvision.extension import _check_cuda_version\n")
        f.write("if _check_cuda_version() > 0:\n")
        f.write("    cuda = _check_cuda_version()\n")


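# PYTORCH_VERSION, if set, pins the torch requirement to that exact version.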
pytorch_dep = "torch"
if os.getenv("PYTORCH_VERSION"):
    pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
    "numpy",
    pytorch_dep,
]

# Excluding 8.3.0 because of https://github.com/pytorch/vision/issues/4146
pillow_ver = " >= 5.3.0, !=8.3.0"
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
requirements.append(pillow_req + pillow_ver)


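# Search for the header of an optional third-party library ("<name>.h"): first in
# the TORCHVISION_INCLUDE directories and the bundled torchvision/ folder, then in
# the active conda environment, and finally in the standard system include paths on
# Linux. Returns (library_found, conda_installed, include_folder, lib_folder).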
def find_library(name, vision_include):
    this_dir = os.path.dirname(os.path.abspath(__file__))
    build_prefix = os.environ.get("BUILD_PREFIX", None)
    is_conda_build = build_prefix is not None

    library_found = False
    conda_installed = False
    lib_folder = None
    include_folder = None
    library_header = "{0}.h".format(name)

    # Look for the header in TORCHVISION_INCLUDE or in the torchvision package directory
    package_path = [os.path.join(this_dir, "torchvision")]
    for folder in vision_include + package_path:
        candidate_path = os.path.join(folder, library_header)
        library_found = os.path.exists(candidate_path)
        if library_found:
            break

    if not library_found:
        print("Running build on conda-build: {0}".format(is_conda_build))
        if is_conda_build:
            # Add conda headers/libraries
            if os.name == "nt":
                build_prefix = os.path.join(build_prefix, "Library")
            include_folder = os.path.join(build_prefix, "include")
            lib_folder = os.path.join(build_prefix, "lib")
            library_header_path = os.path.join(include_folder, library_header)
            library_found = os.path.isfile(library_header_path)
            conda_installed = library_found
        else:
            # Check if using Anaconda to produce wheels
            conda = distutils.spawn.find_executable("conda")
            is_conda = conda is not None
            print("Running build on conda: {0}".format(is_conda))
            if is_conda:
                python_executable = sys.executable
                py_folder = os.path.dirname(python_executable)
                if os.name == "nt":
                    env_path = os.path.join(py_folder, "Library")
                else:
                    env_path = os.path.dirname(py_folder)
                lib_folder = os.path.join(env_path, "lib")
                include_folder = os.path.join(env_path, "include")
                library_header_path = os.path.join(include_folder, library_header)
                library_found = os.path.isfile(library_header_path)
                conda_installed = library_found

        if not library_found:
            if sys.platform == "linux":
                library_found = os.path.exists("/usr/include/{0}".format(library_header))
                library_found = library_found or os.path.exists("/usr/local/include/{0}".format(library_header))

    return library_found, conda_installed, include_folder, lib_folder


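# Build the list of C++/CUDA extension modules: the core torchvision._C ops, the
# optional torchvision._C_tests models test extension, the torchvision.image
# extension (PNG / JPEG / NVJPEG) and the torchvision.video_reader extension
# (FFmpeg), depending on what is available on the system.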
def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "ops", "*.cpp")
    )
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
    )

    is_rocm_pytorch = False
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
    if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        from torch.utils.hipify import hipify_python

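        # Translate the CUDA sources under ops/cuda into HIP so they can be built
        # against ROCm.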
        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="torchvision/csrc/ops/cuda/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip"))
        # Copy over additional files
        for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"):
            shutil.copy(file, "torchvision/csrc/ops/hip")

    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))

    sources = main_file + source_cpu
    extension = CppExtension

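    # WITH_CPP_MODELS_TEST=1 additionally builds the C++ model tests as the
    # torchvision._C_tests extension.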
    compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1"
    if compile_cpp_tests:
        test_dir = os.path.join(this_dir, "test")
        models_dir = os.path.join(this_dir, "torchvision", "csrc", "models")
        test_file = glob.glob(os.path.join(test_dir, "*.cpp"))
        source_models = glob.glob(os.path.join(models_dir, "*.cpp"))

        test_file = [os.path.join(test_dir, s) for s in test_file]
        source_models = [os.path.join(models_dir, s) for s in source_models]
        tests = test_file + source_models
        tests_include_dirs = [test_dir, models_dir]

    define_macros = []

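    # Compile the CUDA/HIP kernels when a GPU toolchain is available, or when
    # FORCE_CUDA=1 is set (e.g. when building on a machine without a visible GPU).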
    extra_compile_args = {"cxx": []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1":
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            nvcc_flags = os.getenv("NVCC_FLAGS", "")
            if nvcc_flags == "":
                nvcc_flags = []
            else:
                nvcc_flags = nvcc_flags.split(" ")
        else:
            define_macros += [("WITH_HIP", None)]
            nvcc_flags = []
        extra_compile_args["nvcc"] = nvcc_flags

    if sys.platform == "win32":
        define_macros += [("torchvision_EXPORTS", None)]

        extra_compile_args["cxx"].append("/MP")

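    # DEBUG=1 compiles all extensions with -g -O0, e.g. `DEBUG=1 python setup.py develop`.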
    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compile in debug mode")
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-O0")
        if "nvcc" in extra_compile_args:
            # remove any existing "-O*" and "-g" flags before appending our own
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")

    sources = [os.path.join(extensions_dir, s) for s in sources]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "torchvision._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    if compile_cpp_tests:
        ext_modules.append(
            extension(
                "torchvision._C_tests",
                tests,
                include_dirs=tests_include_dirs,
                define_macros=define_macros,
                extra_compile_args=extra_compile_args,
            )
        )

    # ------------------- Torchvision extra extensions ------------------------
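    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY are os.pathsep-separated lists of
    # extra directories to search for the optional image and video libraries below.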
    vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
    vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
    vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
    vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
    include_dirs += vision_include
    library_dirs = vision_library

    # Image reading extension
    image_macros = []
    image_include = [extensions_dir]
    image_library = []
    image_link_flags = []

    # Locating libPNG
    libpng = distutils.spawn.find_executable("libpng-config")
    pngfix = distutils.spawn.find_executable("pngfix")
    png_found = libpng is not None or pngfix is not None
    print("PNG found: {0}".format(png_found))
    if png_found:
        if libpng is not None:
            # Linux / Mac
            png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode("utf-8")
            print("libpng version: {0}".format(png_version))
            png_version = parse_version(png_version)
            if png_version >= parse_version("1.6.0"):
                print("Building torchvision with PNG image support")
                png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode("utf-8")
                if "disabled" not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode("utf-8")
                _, png_include = png_include.split("-I")
                print("libpng include path: {0}".format(png_include))
                image_include += [png_include]
                image_link_flags.append("png")
            else:
                print("libpng installed version is less than 1.6.0, disabling PNG support")
                png_found = False
        else:
            # Windows
            png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
            png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append("libpng")

    # Locating libjpeg
    (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)

    print("JPEG found: {0}".format(jpeg_found))
    image_macros += [("PNG_FOUND", str(int(png_found)))]
    image_macros += [("JPEG_FOUND", str(int(jpeg_found)))]
    if jpeg_found:
        print("Building torchvision with JPEG image support")
        image_link_flags.append("jpeg")
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
    )

    print("NVJPEG found: {0}".format(nvjpeg_found))
    image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))]
    if nvjpeg_found:
        print("Building torchvision with NVJPEG image support")
        image_link_flags.append("nvjpeg")

    image_path = os.path.join(extensions_dir, "io", "image")
    image_src = (
        glob.glob(os.path.join(image_path, "*.cpp"))
        + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
        + glob.glob(os.path.join(image_path, "cuda", "*.cpp"))
    )

    if png_found or jpeg_found:
        ext_modules.append(
            extension(
                "torchvision.image",
                image_src,
                include_dirs=image_include + include_dirs + [image_path],
                library_dirs=image_library + library_dirs,
                define_macros=image_macros,
                libraries=image_link_flags,
                extra_compile_args=extra_compile_args,
            )
        )

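    # Locating FFmpeg: the video_reader extension is only built when the ffmpeg
    # binary works and the headers of the required libav* libraries can be found.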
    ffmpeg_exe = distutils.spawn.find_executable("ffmpeg")
    has_ffmpeg = ffmpeg_exe is not None
    # FIXME: Building torchvision with ffmpeg on macOS or with Python 3.9
    # FIXME: causes a crash. See the following GitHub issues for more details.
    # FIXME: https://github.com/pytorch/pytorch/issues/65000
    # FIXME: https://github.com/pytorch/vision/issues/3367
    if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
        has_ffmpeg = False
    if has_ffmpeg:
        try:
            # This is to check if ffmpeg is installed properly.
            subprocess.check_output(["ffmpeg", "-version"])
        except subprocess.CalledProcessError:
            print("Error fetching ffmpeg version, ignoring ffmpeg.")
            has_ffmpeg = False

    print("FFmpeg found: {}".format(has_ffmpeg))

    if has_ffmpeg:
        ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
        ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")

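        # On Debian/Ubuntu the FFmpeg headers and libraries live in a multiarch
        # subdirectory (e.g. x86_64-linux-gnu), which `gcc -print-multiarch` reports.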
        gcc = distutils.spawn.find_executable("gcc")
        platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
        platform_tag = platform_tag.stdout.strip().decode("utf-8")

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
            ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        has_ffmpeg = True
        for library in ffmpeg_libraries:
            library_found = False
            for search_path in ffmpeg_include_dir + include_dirs:
                full_path = os.path.join(search_path, library, "*.h")
                library_found |= len(glob.glob(full_path)) > 0

            if not library_found:
                print(f"{library} header files were not found, disabling ffmpeg support")
                has_ffmpeg = False

    if has_ffmpeg:
        print("ffmpeg include path: {}".format(ffmpeg_include_dir))
        print("ffmpeg library_dir: {}".format(ffmpeg_library_dir))

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
        base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                "torchvision.video_reader",
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
                    video_reader_src_dir,
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs,
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                ],
                extra_compile_args=["-std=c++14"] if os.name != "nt" else ["/std:c++14", "/MP"],
                extra_link_args=["-std=c++14" if os.name != "nt" else "/std:c++14"],
            )
        )

    return ext_modules


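# `python setup.py clean` removes everything matched by the patterns in .gitignore,
# in addition to the default distutils clean behaviour.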
class clean(distutils.command.clean.clean):
    def run(self):
        with open(".gitignore", "r") as f:
            ignores = f.read()
            for wildcard in filter(None, ignores.split("\n")):
                for filename in glob.glob(wildcard):
                    try:
                        os.remove(filename)
                    except OSError:
                        shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


if __name__ == "__main__":
    print("Building wheel {}-{}".format(package_name, version))

    write_version_file()

    with open("README.rst") as f:
        readme = f.read()

    setup(
        # Metadata
        name=package_name,
        version=version,
        author="PyTorch Core Team",
        author_email="soumith@pytorch.org",
        url="https://github.com/pytorch/vision",
        description="image and video datasets and models for torch deep learning",
        long_description=readme,
        license="BSD",
        # Package info
        packages=find_packages(exclude=("test",)),
        package_data={package_name: ["*.dll", "*.dylib", "*.so", "*.categories"]},
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
    )