__init__.py 3.44 KB
Newer Older
1
import os
2
import warnings
Bruno Korbar's avatar
Bruno Korbar committed
3
from modulefinder import Module
4

5
import torch
6

7
8
9
10
# Don't re-order these, we need to load the _C extension (done when importing
# .extensions) before entering _meta_registrations.
from .extension import _HAS_OPS  # usort:skip
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
11

12
13
# Best-effort: expose the package version when the generated ``version``
# module is present; otherwise silently leave ``__version__`` undefined.
try:
    from .version import __version__  # noqa: F401
except ImportError:
    pass
16

Bruno Korbar's avatar
Bruno Korbar committed
17

18
# Check if torchvision is being imported within the root folder.
# Only warn when the compiled ops failed to load (``_HAS_OPS`` is falsy) AND
# this file resolves to ``<cwd>/torchvision``, i.e. the interpreter was started
# from the source checkout, where the package directory shadows an install.
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
    os.path.realpath(os.getcwd()), "torchvision"
):
    message = (
        "You are importing torchvision within its own root folder ({}). "
        "This is not expected to work and may give errors. Please exit the "
        "torchvision project source and relaunch your python interpreter."
    )
    warnings.warn(message.format(os.getcwd()))

29
# Module-level defaults for the loader backends; the set_*/get_* helpers in
# this module read and rebind these names via ``global``.
_image_backend = "PIL"  # package used to load images
_video_backend = "pyav"  # package used to decode videos

33
34
35
36
37
38

def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Args:
        backend (str): Name of the image backend. One of {'PIL', 'accimage'}.
            The :mod:`accimage` package uses the Intel IPP library. It is
            generally faster than PIL, but does not support as many operations.

    Raises:
        ValueError: If ``backend`` is not one of the supported names. The
            current backend is left unchanged in that case.
    """
    global _image_backend
    if backend not in ["PIL", "accimage"]:
        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
    _image_backend = backend


def get_image_backend():
    """
    Return the name of the package currently selected for loading images.

    This is the module-level ``_image_backend`` value ('PIL' at import time,
    or whatever ``set_image_backend`` last stored).
    """
    current = _image_backend
    return current
54
55


56
57
58
59
60
61
62
def set_video_backend(backend):
    """
    Specifies the package used to decode videos.

    Args:
        backend (str): Name of the video backend. One of {'pyav', 'video_reader', 'cuda'}.
            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
            binding for the FFmpeg libraries.
            The :mod:`video_reader` package includes a native C++ implementation on
            top of FFMPEG libraries, and a python API of TorchScript custom operator.
            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
            The ``cuda`` backend is accepted only when torchvision was built with
            GPU video decoding support (``io._HAS_GPU_VIDEO_DECODER``).

    .. note::
        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
        backend, please compile torchvision from source.

    Raises:
        ValueError: If ``backend`` is not a recognised backend name.
        RuntimeError: If ``backend`` is recognised but unavailable in this build
            ('video_reader' without ``io._HAS_VIDEO_OPT``, or 'cuda' without
            ``io._HAS_GPU_VIDEO_DECODER``). The current backend is left unchanged.
    """
    global _video_backend
    if backend not in ["pyav", "video_reader", "cuda"]:
        # f-string (not ``%``) so a tuple argument cannot break the formatting.
        raise ValueError(f"Invalid video backend '{backend}'. Options are 'pyav', 'video_reader' and 'cuda'")
    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
        # TODO: better messages
        raise RuntimeError(
            "video_reader video backend is not available. Please compile torchvision from source and try again"
        )
    if backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
        # TODO: better messages
        raise RuntimeError("cuda video backend is not available.")
    _video_backend = backend
85
86
87


def get_video_backend():
    """
    Returns the currently active video backend used to decode videos.

    Returns:
        str: Name of the video backend. One of {'pyav', 'video_reader', 'cuda'}.
    """
    return _video_backend


98
99
def _is_tracing():
    return torch._C._get_tracing_state()
100
101
102
103
104
105


def disable_beta_transforms_warning():
    """Do nothing; kept only so existing callers do not break.

    See https://github.com/pytorch/vision/issues/7896
    """
    return None