# extension.py: loads torchvision's compiled C++ extension (_C) and checks CUDA version compatibility.
import ctypes
import os
import sys
from warnings import warn

6
import torch
7

8
from ._internally_replaced_utils import _get_extension_path
# Start pessimistic: the compiled ``_C`` ops count as unavailable until the
# library actually loads (a successful load rebinds both names below).
_HAS_OPS: bool = False


def _has_ops() -> bool:
    """Return whether torchvision's compiled C++ ops were loaded.

    This is the fallback definition and therefore always answers ``False``.
    """
    return False


try:
    lib_path = _get_extension_path("_C")
    torch.ops.load_library(lib_path)
except (ImportError, OSError):
    # Deliberate best effort: if the ``_C`` library is missing or cannot be
    # loaded, ``_HAS_OPS`` stays False and the fallback ``_has_ops`` (which
    # returns False) stays in place. Importing this module still succeeds, and
    # callers such as ``_assert_has_ops`` raise a descriptive error later.
    pass
else:
    # Only advertise the ops once the library is registered with torch.ops.
    _HAS_OPS = True

    def _has_ops():  # noqa: F811
        """Return ``True``: torchvision's compiled C++ ops were loaded."""
        return True


def _assert_has_ops():
    """Raise ``RuntimeError`` unless torchvision's custom C++ ops are available.

    Availability is decided by ``_has_ops()``. The error message points the
    user at the PyTorch/torchvision compatibility matrix and asks them to
    compare ``torch.__version__`` with ``torchvision.__version__``.
    """
    if _has_ops():
        return

    hint = (
        "Couldn't load custom C++ ops. This can happen if your PyTorch and "
        "torchvision versions are incompatible, or if you had errors while compiling "
        "torchvision from source. For further information on the compatible versions, check "
        "https://github.com/pytorch/vision#installation for the compatibility matrix. "
        "Please check your PyTorch version with torch.__version__ and your torchvision "
        "version with torchvision.__version__ and verify if they are compatible, and if not "
        "please reinstall torchvision so that it matches your PyTorch install."
    )
    raise RuntimeError(hint)


def _decode_cuda_version(version):
    """Split an integer CUDA version into ``(major, minor)``.

    The integer is assumed to use the ``major * 1000 + minor * 10`` encoding
    (as in the ``CUDA_VERSION`` macro), e.g. ``11080`` -> ``(11, 8)`` and
    ``9020`` -> ``(9, 2)``. Arithmetic decoding, unlike indexing into the
    decimal string, also handles two-digit minor versions
    (``12100`` -> ``(12, 10)``).
    """
    major, remainder = divmod(int(version), 1000)
    return major, remainder // 10


def _check_cuda_version():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install

    Returns:
        ``-1`` when torchvision's compiled ops were not loaded. Otherwise the
        integer returned by ``torch.ops.torchvision._cuda_version()``.

    Raises:
        RuntimeError: if both sides report a CUDA version and their
            major/minor components differ.
    """
    if not _HAS_OPS:
        return -1

    _version = torch.ops.torchvision._cuda_version()
    # A -1 from the ops, or torch.version.cuda being None, presumably marks a
    # build without CUDA on that side; there is then nothing to compare.
    if _version == -1 or torch.version.cuda is None:
        return _version

    tv_major, tv_minor = _decode_cuda_version(_version)
    t_parts = torch.version.cuda.split(".")
    t_major = int(t_parts[0])
    t_minor = int(t_parts[1])
    if (t_major, t_minor) != (tv_major, tv_minor):
        raise RuntimeError(
            "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
            f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
            f"CUDA Version={tv_major}.{tv_minor}. "
            "Please reinstall the torchvision that matches your PyTorch install."
        )
    return _version


def _load_library(lib_name):
    """Resolve the extension library ``lib_name`` and register it with ``torch.ops``.

    On Windows with Python 3.8+, ``_get_extension_path`` configures the DLL
    search path through ``os.add_dll_directory``. Older Windows Pythons lack
    that call. For them, the library is first loaded explicitly with
    ``LoadLibraryExW``, using flags that add the DLL's own directory and the
    default directories to the search. If kernel32 does not expose
    ``LoadLibraryExW``, only a warning is emitted.
    """
    path = _get_extension_path(lib_name)

    legacy_windows = os.name == "nt" and sys.version_info < (3, 8)
    if legacy_windows:
        search_default_dirs = 0x1000  # LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
        search_dll_load_dir = 0x100  # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        load_library_ex = getattr(kernel32, "LoadLibraryExW", None)
        if load_library_ex is None:
            warn("LoadLibraryExW is missing in kernel32.dll")
        else:
            load_library_ex(path, None, search_default_dirs | search_dll_load_dir)

    torch.ops.load_library(path)


# Import-time side effect: fail fast if PyTorch and torchvision were compiled
# against different CUDA versions (_check_cuda_version raises RuntimeError).
_check_cuda_version()