__init__.py 7.25 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING, Optional

from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname

11
from .interface import _Backend  # noqa: F401
12
from .interface import CpuArchEnum, Platform, PlatformEnum
zhuwenwen's avatar
zhuwenwen committed
13
import torch
14

15
logger = logging.getLogger(__name__)
16

17

18
19
20
21
22
23
24
25
26
27
28
def tpu_platform_plugin() -> Optional[str]:
    """Detect a TPU host.

    ``libtpu`` can in principle be installed on a machine without TPUs, but
    that is rare, so a successful ``import libtpu`` is taken to mean the
    machine has TPUs.

    Returns:
        ``"vllm.platforms.tpu.TpuPlatform"`` if ``libtpu`` imports cleanly,
        otherwise None (any exception during the import counts as "no TPU").
    """
    try:
        import libtpu  # noqa: F401
    except Exception:
        return None
    return "vllm.platforms.tpu.TpuPlatform"
31
32


33
34
def cuda_platform_plugin() -> Optional[str]:
    """Detect an NVIDIA CUDA platform.

    NVML is initialised and asked for the device count; one or more devices
    selects CUDA. If the attempt raises an exception whose class name
    contains "nvml", the machine is still treated as CUDA-capable when the
    Jetson/Tegra marker files exist, because CUDA is supported on Jetson even
    where NVML may not be. Any other exception is re-raised.

    Returns:
        ``"vllm.platforms.cuda.CudaPlatform"`` when CUDA is detected,
        otherwise None.
    """
    found = False
    try:
        from vllm.utils import import_pynvml
        nvml = import_pynvml()
        nvml.nvmlInit()
        try:
            if nvml.nvmlDeviceGetCount() > 0:
                found = True
        finally:
            # Always release NVML, even if the device query fails.
            nvml.nvmlShutdown()
    except Exception as err:
        if "nvml" not in type(err).__name__.lower():
            # Not an NVML problem: surface it to the caller unchanged.
            raise err

        import os
        jetson_markers_present = (
            os.path.isfile("/etc/nv_tegra_release")
            or os.path.exists("/sys/class/tegra-firmware"))
        if jetson_markers_present:
            found = True

    return "vllm.platforms.cuda.CudaPlatform" if found else None


def rocm_platform_plugin() -> Optional[str]:
    """Detect a ROCm/HIP platform from the installed PyTorch build.

    Detection looks only at whether ``torch`` was built for HIP
    (``torch.version.hip`` is set). It does not enumerate devices; an
    amdsmi-based device-count probe is disabled in this tree.

    Returns:
        ``"vllm.platforms.rocm.RocmPlatform"`` for a HIP build of torch,
        otherwise None (any exception during the check counts as "not ROCm").
    """
    try:
        built_for_hip = torch.version.hip is not None
    except Exception:
        built_for_hip = False
    return "vllm.platforms.rocm.RocmPlatform" if built_for_hip else None


def hpu_platform_plugin() -> Optional[str]:
    """Detect a Habana Gaudi (HPU) environment.

    Returns:
        ``"vllm.platforms.hpu.HpuPlatform"`` if the ``habana_frameworks``
        package can be located (without importing it), otherwise None. Any
        exception raised while locating it counts as "no HPU".
    """
    try:
        import importlib.util
        spec = importlib.util.find_spec('habana_frameworks')
    except Exception:
        return None
    if spec is None:
        return None
    return "vllm.platforms.hpu.HpuPlatform"


def xpu_platform_plugin() -> Optional[str]:
    """Detect an Intel XPU environment.

    Requires that Intel Extension for PyTorch (IPEX) and the oneCCL bindings
    import cleanly, and that ``torch.xpu`` exists and reports a device as
    available.

    Returns:
        ``"vllm.platforms.xpu.XPUPlatform"`` when all checks pass, otherwise
        None (any exception during the checks counts as "no XPU").
    """
    try:
        # IPEX is expected to be installed on machines that have XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import oneccl_bindings_for_pytorch  # noqa: F401
        import torch
        xpu_ready = bool(hasattr(torch, 'xpu') and torch.xpu.is_available())
    except Exception:
        xpu_ready = False
    return "vllm.platforms.xpu.XPUPlatform" if xpu_ready else None


def cpu_platform_plugin() -> Optional[str]:
    """Detect whether vLLM should run on the CPU platform.

    The CPU platform is selected when the installed ``vllm`` distribution's
    version string contains ``"cpu"``. Failing that, it is selected when
    ``platform.machine()`` starts with ``"arm"`` (case-insensitive).

    Note: if the ``vllm`` distribution metadata cannot be read, the
    architecture fallback is never reached and None is returned.

    Returns:
        ``"vllm.platforms.cpu.CpuPlatform"`` or None.
    """
    try:
        from importlib.metadata import version
        if "cpu" in version("vllm"):
            matched = True
        else:
            import platform
            matched = platform.machine().lower().startswith("arm")
    except Exception:
        matched = False
    return "vllm.platforms.cpu.CpuPlatform" if matched else None


def neuron_platform_plugin() -> Optional[str]:
    """Detect an AWS Neuron environment.

    Returns:
        ``"vllm.platforms.neuron.NeuronPlatform"`` if ``transformers_neuronx``
        imports, otherwise None. Only ImportError is treated as "not Neuron";
        any other exception raised by the import propagates to the caller.
    """
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return None
    return "vllm.platforms.neuron.NeuronPlatform"
133
134


135
136
def openvino_platform_plugin() -> Optional[str]:
    """Detect an OpenVINO build of vLLM.

    Returns:
        ``"vllm.platforms.openvino.OpenVinoPlatform"`` when the installed
        ``vllm`` distribution's version string contains ``"openvino"``,
        otherwise None (including when the metadata cannot be read).
    """
    try:
        from importlib.metadata import version
        matched = "openvino" in version("vllm")
    except Exception:
        matched = False
    if not matched:
        return None
    return "vllm.platforms.openvino.OpenVinoPlatform"


# Builtin platform detectors keyed by platform name. Each detector takes no
# arguments and returns the qualified name of its Platform class when it
# recognises the current machine, or None otherwise.
builtin_platform_plugins = dict(
    tpu=tpu_platform_plugin,
    cuda=cuda_platform_plugin,
    rocm=rocm_platform_plugin,
    hpu=hpu_platform_plugin,
    xpu=xpu_platform_plugin,
    cpu=cpu_platform_plugin,
    neuron=neuron_platform_plugin,
    openvino=openvino_platform_plugin,
)


def _probe_platform_plugins(plugins) -> dict:
    """Run every platform detector in *plugins* exactly once.

    Args:
        plugins: mapping of plugin name -> zero-argument detector that
            returns a Platform class qualname or None.

    Returns:
        dict mapping the name of each detector that returned a non-None
        qualname to that qualname. Detectors that are not callable or that
        raise are treated as inactive and logged at debug level, so one
        broken plugin cannot abort platform detection.
    """
    activated = {}
    for name, func in plugins.items():
        # Explicit check instead of `assert`, which is stripped under -O.
        if not callable(func):
            logger.debug("Platform plugin %s is not callable; skipping.",
                         name)
            continue
        try:
            qualname = func()
        except Exception:
            logger.debug("Platform plugin %s failed during detection.",
                         name,
                         exc_info=True)
            continue
        if qualname is not None:
            activated[name] = qualname
    return activated


def resolve_current_platform_cls_qualname() -> str:
    """Return the qualname of the Platform class to use in this process.

    The builtin detectors (``builtin_platform_plugins``) and the out-of-tree
    (OOT) plugins registered under the ``vllm.platform_plugins`` group are
    each probed once. Precedence:

    1. exactly one activated OOT plugin is used, regardless of builtins;
    2. otherwise exactly one activated builtin plugin is used;
    3. otherwise ``vllm.platforms.interface.UnspecifiedPlatform``.

    Raises:
        RuntimeError: if two or more OOT plugins activate, or if no OOT
            plugin activates and two or more builtin plugins do.
    """
    platform_plugins = load_plugins_by_group('vllm.platform_plugins')

    # Probe builtins first, then OOT plugins, keeping their results apart.
    # Reusing each detector's first answer avoids re-running expensive
    # probes (e.g. NVML init) and the risk of a second call disagreeing.
    # Keeping the two namespaces separate stops an OOT plugin that shares a
    # builtin's name (e.g. "cuda") from being mistaken for it.
    activated_builtin_plugins = _probe_platform_plugins(
        builtin_platform_plugins)
    activated_oot_plugins = _probe_platform_plugins(platform_plugins)

    if len(activated_oot_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{sorted(activated_oot_plugins)}")
    if len(activated_oot_plugins) == 1:
        name, platform_cls_qualname = next(
            iter(activated_oot_plugins.items()))
        logger.info("Platform plugin %s is activated", name)
        return platform_cls_qualname

    if len(activated_builtin_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{sorted(activated_builtin_plugins)}")
    if len(activated_builtin_plugins) == 1:
        name, platform_cls_qualname = next(
            iter(activated_builtin_plugins.items()))
        logger.info("Automatically detected platform %s.", name)
        return platform_cls_qualname

    logger.info(
        "No platform detected, vLLM is running on UnspecifiedPlatform")
    return "vllm.platforms.interface.UnspecifiedPlatform"


# Cached Platform instance; stays None until ``current_platform`` is first
# accessed (resolution happens in the module-level ``__getattr__``).
_current_platform: Optional["Platform"] = None
# Formatted stack trace captured when ``current_platform`` was first
# resolved; empty until then.
_init_trace: str = ""

if TYPE_CHECKING:
    # Declaration for static type checkers only: at runtime
    # ``current_platform`` is not a real global and is produced lazily.
    current_platform: Platform


def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) for this package.

    ``current_platform`` is resolved lazily on first access and cached in
    ``_current_platform``. The call stack at that moment is saved in
    ``_init_trace``, e.g. to see what triggered resolution.

    Resolution is deferred for two reasons:
    1. Out-of-tree platform plugins need ``from vllm.platforms import
       Platform`` to subclass ``Platform``, so ``current_platform`` cannot
       be resolved while ``vllm.platforms`` itself is being imported.
    2. With out-of-tree plugins, internal vLLM code may touch
       ``current_platform`` during ``import vllm``. It must only be resolved
       after the plugins are loaded; tests guard against violations.

    Raises:
        AttributeError: for any other name not found in the module globals.
    """
    global _current_platform, _init_trace

    if name != 'current_platform':
        module_namespace = globals()
        # Per PEP 562 this hook only runs after normal lookup fails, so this
        # lookup is a fallback that should rarely, if ever, succeed.
        if name in module_namespace:
            return module_namespace[name]
        raise AttributeError(
            f"No attribute named '{name}' exists in {__name__}.")

    if _current_platform is None:
        qualname = resolve_current_platform_cls_qualname()
        platform_cls = resolve_obj_by_qualname(qualname)
        _current_platform = platform_cls()
        _init_trace = "".join(traceback.format_stack())
    return _current_platform
235
236
237
238
239
240


# Public API of this package. ``current_platform`` is listed even though it
# is materialised lazily via the module-level ``__getattr__``.
__all__ = [
    "Platform",
    "PlatformEnum",
    "current_platform",
    "CpuArchEnum",
    "_init_trace",
]