# extension.py
import os
import sys

4
import torch
5

6
from ._internally_replaced_utils import _get_extension_path
7

8

9
_HAS_OPS = False
10

11

12
13
def _has_ops():
    return False
14
15
16


try:
    # On Windows, Python 3.8 provides `os.add_dll_directory`, which is used to
    # configure the DLL search path. To find CUDA-related DLLs we need to make
    # sure the conda environment/bin path is registered. Please take a look:
    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
    # Note: if a path can't be added with add_dll_directory, it is simply ignored.
    if os.name == "nt" and sys.version_info < (3, 9):
        # PATH may be missing from the environment. Treat that as "nothing to
        # add" instead of letting a KeyError escape, because KeyError is not
        # covered by the (ImportError, OSError) handler below.
        # os.pathsep is ";" on nt, the only platform this branch runs on.
        for path in os.environ.get("PATH", "").split(os.pathsep):
            if os.path.exists(path):
                try:
                    os.add_dll_directory(path)  # type: ignore[attr-defined]
                except Exception:
                    # Deliberately best-effort: unusable entries are skipped.
                    pass

    lib_path = _get_extension_path("_C")
    torch.ops.load_library(lib_path)
    _HAS_OPS = True

    def _has_ops():  # noqa: F811
        """Report that the compiled ops were loaded successfully."""
        return True

except (ImportError, OSError):
    # Loading the extension is optional. On failure `_HAS_OPS` stays False and
    # the module-level fallback `_has_ops` (which returns False) stays in effect.
    pass


44
45
46
47
48
49
50
51
52
53
54
55
56
def _assert_has_ops():
    """Ensure the compiled torchvision C++ ops are available.

    Raises:
        RuntimeError: if ``_has_ops()`` reports that the ops are unavailable.
            The message points the user at the PyTorch/torchvision
            compatibility matrix.
    """
    if _has_ops():
        return

    message = (
        "Couldn't load custom C++ ops. This can happen if your PyTorch and "
        "torchvision versions are incompatible, or if you had errors while compiling "
        "torchvision from source. For further information on the compatible versions, check "
        "https://github.com/pytorch/vision#installation for the compatibility matrix. "
        "Please check your PyTorch version with torch.__version__ and your torchvision "
        "version with torchvision.__version__ and verify if they are compatible, and if not "
        "please reinstall torchvision so that it matches your PyTorch install."
    )
    raise RuntimeError(message)


57
def _check_cuda_version():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install.

    Returns:
        ``-1`` if the compiled ops are not loaded. Otherwise, the integer
        returned by ``torch.ops.torchvision._cuda_version()``. A value of
        ``-1`` from that call is treated as "no CUDA version to compare".

    Raises:
        RuntimeError: if PyTorch and torchvision were compiled with different
            CUDA major versions.
    """
    if not _HAS_OPS:
        return -1
    from torch.version import cuda as torch_version_cuda

    _version = torch.ops.torchvision._cuda_version()
    if _version != -1 and torch_version_cuda is not None:
        # The version integer is laid out as major * 1000 + minor * 10
        # (e.g. 9020 -> 9.2, 11080 -> 11.8). Decode it arithmetically rather
        # than by indexing into its decimal string. Digit indexing ignores the
        # hundreds place, so it misreads two-digit minors.
        tv_major, tv_rest = divmod(_version, 1000)
        tv_minor = tv_rest // 10

        t_version = torch_version_cuda.split(".")
        t_major = int(t_version[0])
        t_minor = int(t_version[1])
        # Only a major-version mismatch is fatal; minors are reported for context.
        if t_major != tv_major:
            raise RuntimeError(
                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                f"CUDA Version={tv_major}.{tv_minor}. "
                "Please reinstall the torchvision that matches your PyTorch install."
            )
    return _version


87
88
89
90
91
def _load_library(lib_name):
    """Resolve ``lib_name`` with ``_get_extension_path`` and load it via ``torch.ops.load_library``.

    Args:
        lib_name: name of the extension library to resolve and load.
    """
    torch.ops.load_library(_get_extension_path(lib_name))


92
# Run the CUDA compatibility check at import time. A PyTorch/torchvision CUDA
# major-version mismatch then surfaces immediately as a RuntimeError.
_check_cuda_version()