__init__.py 3.17 KB
Newer Older
1
from .interface import _Backend  # noqa: F401
2
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
zhuwenwen's avatar
zhuwenwen committed
3
import torch
4

5
# Module-level handle to the detected platform. Declared here for type
# checkers; the actual instance is assigned by the if/elif chain at the
# bottom of this module.
current_platform: Platform
6

7
8
9
10
11
12
# NOTE: we generally don't use `torch.version.cuda` / `torch.version.hip`
# because they only indicate the build configuration, not the runtime
# environment. For example, people can install a cuda build of pytorch but
# run on tpu. (The ROCm check further down is currently an exception: it
# keys off `torch.version.hip`.)

# TPU detection: the host is treated as a TPU machine when `libtpu` can be
# imported.
is_tpu = False
try:
    # While it's technically possible to install libtpu on a non-TPU machine,
    # this is a very uncommon scenario. Therefore, we assume that libtpu is
    # installed if and only if the machine has TPUs.
    import libtpu  # noqa: F401
    is_tpu = True
except Exception:
    # Best-effort probe: any failure while importing libtpu simply leaves
    # `is_tpu` at False.
    pass

# CUDA detection: ask NVML (via pynvml) whether at least one device is
# visible at runtime; if NVML cannot be used at all, fall back to looking
# for Jetson/Tegra marker paths.
is_cuda = False

try:
    import pynvml
    pynvml.nvmlInit()
    try:
        if pynvml.nvmlDeviceGetCount() > 0:
            is_cuda = True
    finally:
        # Pair every successful nvmlInit() with nvmlShutdown(), even when
        # the device-count query raises.
        pynvml.nvmlShutdown()
except Exception:
    # CUDA is supported on Jetson, but NVML may not be.
    import os

    def cuda_is_jetson() -> bool:
        """Return True if either Tegra marker path exists on this host."""
        return os.path.isfile("/etc/nv_tegra_release") \
            or os.path.exists("/sys/class/tegra-firmware")

    if cuda_is_jetson():
        is_cuda = True
41
42
43

# ROCm detection.
# NOTE: unlike the other probes in this module, this check relies on the
# PyTorch build configuration (`torch.version.hip`) rather than on a runtime
# device query. The amdsmi-based runtime probe is kept below, commented out,
# for reference.
is_rocm = False

try:
    if torch.version.hip is not None:
        is_rocm = True
    # import amdsmi
    # amdsmi.amdsmi_init()
    # try:
    #     if len(amdsmi.amdsmi_get_processor_handles()) > 0:
    #         is_rocm = True
    # finally:
    #     amdsmi.amdsmi_shut_down()
except Exception:
    # Best-effort probe: any failure leaves `is_rocm` at False.
    pass
56

57
58
59
60
61
62
63
# HPU detection: the platform counts as present when the `habana_frameworks`
# package can be located. The package is only looked up, not imported.
try:
    from importlib import util
    is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
    # Lookup machinery failed; treat the platform as absent.
    is_hpu = False

64
65
66
# XPU detection: requires both Intel extension packages to import and
# `torch.xpu` to report an available device.
is_xpu = False

try:
    # Presumably (mirroring the libtpu assumption used for TPUs) IPEX is
    # only installed if the machine has XPUs.
    import intel_extension_for_pytorch  # noqa: F401
    import oneccl_bindings_for_pytorch  # noqa: F401
    import torch
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        is_xpu = True
except Exception:
    # Best-effort probe: a missing package or failing query leaves
    # `is_xpu` at False.
    pass

76
77
78
79
80
81
82
# CPU-build detection: the installed `vllm` distribution is treated as a CPU
# build when its version string contains "cpu".
try:
    from importlib.metadata import version
    is_cpu = "cpu" in version("vllm")
except Exception:
    # Distribution metadata unavailable; not a CPU build as far as we know.
    is_cpu = False

83
84
85
86
87
88
89
# Neuron detection: the platform counts as present when
# `transformers_neuronx` imports successfully.
try:
    import transformers_neuronx  # noqa: F401
except ImportError:
    is_neuron = False
else:
    is_neuron = True

90
91
92
93
94
95
96
# OpenVINO-build detection: the installed `vllm` distribution is treated as
# an OpenVINO build when its version string contains "openvino".
try:
    from importlib.metadata import version
    is_openvino = "openvino" in version("vllm")
except Exception:
    # Distribution metadata unavailable; leave the flag off.
    is_openvino = False

97
# Select the platform implementation. Branches are checked in priority
# order and the first flag that is set wins; each backend module is imported
# lazily, only inside the branch that selects it.
if is_tpu:
    # people might install pytorch built with cuda but run on tpu
    # so we need to check tpu first
    from .tpu import TpuPlatform
    current_platform = TpuPlatform()
elif is_cuda:
    from .cuda import CudaPlatform
    current_platform = CudaPlatform()
elif is_rocm:
    from .rocm import RocmPlatform
    current_platform = RocmPlatform()
elif is_hpu:
    from .hpu import HpuPlatform
    current_platform = HpuPlatform()
elif is_xpu:
    from .xpu import XPUPlatform
    current_platform = XPUPlatform()
elif is_cpu:
    from .cpu import CpuPlatform
    current_platform = CpuPlatform()
elif is_neuron:
    from .neuron import NeuronPlatform
    current_platform = NeuronPlatform()
elif is_openvino:
    from .openvino import OpenVinoPlatform
    current_platform = OpenVinoPlatform()
else:
    # No probe matched: fall back to the placeholder platform.
    current_platform = UnspecifiedPlatform()
125

126
# Names exported by `from <this package> import *`.
__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']