# __init__.py
1
2
from torchvision import models
from torchvision import datasets
3
from torchvision import ops
4
5
from torchvision import transforms
from torchvision import utils
6

7
8
try:
    # The ``version`` module is optional here (presumably generated at build
    # time — confirm); if it is missing, ``__version__`` is simply not defined.
    from .version import __version__  # noqa: F401
except ImportError:
    pass

# Name of the package used to load images; changed via set_image_backend().
_image_backend = 'PIL'


def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Args:
        backend (string): Name of the image backend. One of {'PIL', 'accimage'}.
            The :mod:`accimage` package uses the Intel IPP library. It is
            generally faster than PIL, but does not support as many operations.

    Raises:
        ValueError: If ``backend`` is not one of the supported names. The
            currently selected backend is left unchanged in that case.
    """
    global _image_backend
    if backend not in ('PIL', 'accimage'):
        raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
                         .format(backend))
    _image_backend = backend


def get_image_backend():
    """
    Gets the name of the package used to load images.

    Returns:
        str: The current image backend as last set by ``set_image_backend``
        (``'PIL'`` until it is changed).
    """
    return _image_backend


def _check_cuda_matches():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install.

    Compares the CUDA major/minor version reported by torchvision's compiled
    ``_C`` extension (``_C.CUDA_VERSION``) against ``torch.version.cuda``. The
    check is skipped when ``_C`` exposes no ``CUDA_VERSION`` or when PyTorch
    reports no CUDA version.

    Raises:
        RuntimeError: If the CUDA major or minor versions differ.
    """
    import torch
    from torchvision import _C
    if not hasattr(_C, "CUDA_VERSION") or torch.version.cuda is None:
        return

    # CUDA_VERSION is encoded as ``1000 * major + 10 * minor`` (e.g. 9020 for
    # 9.2, 10010 for 10.1). Decoding arithmetically avoids slicing string
    # digits, which depended on how many digits the major version has.
    tv_major, remainder = divmod(int(_C.CUDA_VERSION), 1000)
    tv_minor = remainder // 10

    t_version = torch.version.cuda.split('.')
    t_major = int(t_version[0])
    t_minor = int(t_version[1])
    if (t_major, t_minor) != (tv_major, tv_minor):
        raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
                           "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
                           "Please reinstall the torchvision that matches your PyTorch install."
                           .format(t_major, t_minor, tv_major, tv_minor))


_check_cuda_matches()