# vllm/platforms/__init__.py
# SPDX-License-Identifier: Apache-2.0

3
4
import logging
import traceback
5
from contextlib import suppress
6
7
8
9
10
11
from itertools import chain
from typing import TYPE_CHECKING, Optional

from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname

12
from .interface import _Backend  # noqa: F401
13
from .interface import CpuArchEnum, Platform, PlatformEnum
# Modified by zhuwenwen.
14
import torch
15

16
# Module-level logger used by the platform-detection helpers below.
logger = logging.getLogger(__name__)
17

18

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def vllm_version_matches_substr(substr: str) -> bool:
    """Return True if ``substr`` occurs in the installed vLLM version string.

    Raises:
        PackageNotFoundError: if the ``vllm`` distribution is not installed.
            A warning is logged before the exception is re-raised.
    """
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("vllm")
    except PackageNotFoundError:
        logger.warning(
            "The vLLM package was not found, so its version could not be "
            "inspected. This may cause platform detection to fail.")
        raise
    return substr in installed


34
35
36
37
38
39
40
41
42
43
44
def tpu_platform_plugin() -> Optional[str]:
    """Return the TPU platform qualname if ``libtpu`` imports, else None.

    Any exception raised while importing ``libtpu`` counts as "not a TPU
    machine".
    """
    try:
        # While it's technically possible to install libtpu on a
        # non-TPU machine, this is a very uncommon scenario. Therefore,
        # we assume that libtpu is installed if and only if the machine
        # has TPUs.
        import libtpu  # noqa: F401
    except Exception:
        return None
    return "vllm.platforms.tpu.TpuPlatform"
47
48


49
50
def cuda_platform_plugin() -> Optional[str]:
    """Return the CUDA platform qualname if CUDA is usable, else None.

    NVML is asked for the device count. CUDA is selected when at least one
    device is reported and the installed vLLM version does not contain
    "cpu". If the failure is an NVML error (the exception class name
    contains "nvml"), the function falls back to detecting a Jetson (Tegra)
    board from well-known files. Any other exception is re-raised to the
    caller.
    """
    is_cuda = False

    try:
        from vllm.utils import import_pynvml
        pynvml = import_pynvml()

        pynvml.nvmlInit()
        try:
            # NOTE: Edge case: vllm cpu build on a GPU machine.
            # Third-party pynvml can be imported in cpu build,
            # we need to check if vllm is built with cpu too.
            # Otherwise, vllm will always activate cuda plugin
            # on a GPU machine, even if in a cpu build.
            is_cuda = (pynvml.nvmlDeviceGetCount() > 0
                       and not vllm_version_matches_substr("cpu"))
        finally:
            pynvml.nvmlShutdown()
    except Exception as e:
        if "nvml" not in e.__class__.__name__.lower():
            # If the error is not related to NVML, re-raise it.
            raise

        # CUDA is supported on Jetson, but NVML may not be.
        import os

        def cuda_is_jetson() -> bool:
            return os.path.isfile("/etc/nv_tegra_release") \
                or os.path.exists("/sys/class/tegra-firmware")

        # Only ever promote to True here; is_cuda may already be True if
        # the NVML error came from nvmlShutdown() after a successful count.
        if cuda_is_jetson():
            is_cuda = True

    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None


def rocm_platform_plugin() -> Optional[str]:
    """Return the ROCm platform qualname if PyTorch is a HIP build, else None.

    NOTE: this only inspects ``torch.version.hip``, i.e. how the installed
    PyTorch was built. It does not probe for attached AMD devices. An
    amdsmi-based device-count probe used to live here and was disabled.
    Any failure (for example torch not importable) yields None.
    """
    try:
        import torch
        is_rocm = torch.version.hip is not None
    except Exception:
        is_rocm = False

    return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None


def hpu_platform_plugin() -> Optional[str]:
    """Return the HPU platform qualname when ``habana_frameworks`` can be
    found on the import path (without importing it), else None."""
    try:
        from importlib.util import find_spec
        spec = find_spec('habana_frameworks')
    except Exception:
        return None
    if spec is None:
        return None
    return "vllm.platforms.hpu.HpuPlatform"


def xpu_platform_plugin() -> Optional[str]:
    """Return the XPU platform qualname when IPEX and the oneCCL bindings
    import and ``torch.xpu`` reports an available device, else None."""
    try:
        # Assume IPEX is only installed on machines that have XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import oneccl_bindings_for_pytorch  # noqa: F401
        import torch
        xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
    except Exception:
        xpu_ready = False
    return "vllm.platforms.xpu.XPUPlatform" if xpu_ready else None


def cpu_platform_plugin() -> Optional[str]:
    """Return the CPU platform qualname for CPU builds or ARM hosts.

    CPU is selected when the installed vLLM version string contains "cpu".
    Otherwise it is selected when ``platform.machine()`` (lower-cased)
    starts with "arm". Detection is best-effort: any exception yields None.
    """
    is_cpu = False
    try:
        # If this raises (e.g. vLLM not installed), the ARM check below is
        # skipped as well and the result is None.
        is_cpu = vllm_version_matches_substr("cpu")
        if not is_cpu:
            import platform
            # NOTE(review): Linux ARM64 usually reports "aarch64", which does
            # not start with "arm"; only names like "arm64" match here.
            is_cpu = platform.machine().lower().startswith("arm")
    except Exception:
        pass

    return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None


def neuron_platform_plugin() -> Optional[str]:
    """Return the Neuron platform qualname if ``transformers_neuronx`` can be
    imported, else None.

    Only ImportError is treated as "not available"; other import-time errors
    propagate to the caller.
    """
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return None
    return "vllm.platforms.neuron.NeuronPlatform"
153
154


155
156
def openvino_platform_plugin() -> Optional[str]:
    """Return the OpenVINO platform qualname for OpenVINO builds of vLLM.

    Selected when the installed vLLM version string contains "openvino".
    Any exception during the lookup (e.g. vLLM not installed) yields None.
    """
    is_openvino = False
    with suppress(Exception):
        is_openvino = vllm_version_matches_substr("openvino")

    return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None


# Builtin platform detectors keyed by platform name. Each callable takes no
# arguments and returns a platform class qualname, or None if its platform
# is not detected.
builtin_platform_plugins = dict(
    tpu=tpu_platform_plugin,
    cuda=cuda_platform_plugin,
    rocm=rocm_platform_plugin,
    hpu=hpu_platform_plugin,
    xpu=xpu_platform_plugin,
    cpu=cpu_platform_plugin,
    neuron=neuron_platform_plugin,
    openvino=openvino_platform_plugin,
)


def _run_platform_plugins(plugins: dict) -> dict:
    """Call each plugin in ``plugins`` once and collect the activated ones.

    Returns a dict mapping plugin name -> platform class qualname for every
    plugin that returned a non-None value, in the iteration order of
    ``plugins``. A plugin that raises (including one that is not callable)
    is treated as not activated, and the error is logged at debug level.
    """
    activated = {}
    for name, func in plugins.items():
        try:
            platform_cls_qualname = func()
        except Exception:
            # Detection is best-effort, but keep the reason discoverable.
            logger.debug("Platform plugin %s raised during detection",
                         name, exc_info=True)
            continue
        if platform_cls_qualname is not None:
            activated[name] = platform_cls_qualname
    return activated


def resolve_current_platform_cls_qualname() -> str:
    """Return the qualified class name of the platform to instantiate.

    Plugins loaded from the ``vllm.platform_plugins`` group (out-of-tree)
    take precedence over the builtin detectors. At most one plugin may be
    activated within the group that decides. If nothing activates,
    ``vllm.platforms.interface.UnspecifiedPlatform`` is returned.

    Raises:
        RuntimeError: if two or more out-of-tree plugins activate, or, when
            no out-of-tree plugin activates, two or more builtin ones do.
    """
    platform_plugins = load_plugins_by_group('vllm.platform_plugins')

    # Each plugin is invoked exactly once and its result is reused, so
    # detectors with side effects (e.g. NVML init) are not run twice.
    # Builtin and out-of-tree results are kept separate so an out-of-tree
    # plugin that reuses a builtin name is not mistaken for the builtin one.
    builtin_results = _run_platform_plugins(builtin_platform_plugins)
    oot_results = _run_platform_plugins(platform_plugins)
    activated_builtin_plugins = list(builtin_results)
    activated_oot_plugins = list(oot_results)

    if len(activated_oot_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_oot_plugins}")
    if len(activated_oot_plugins) == 1:
        name = activated_oot_plugins[0]
        logger.info("Platform plugin %s is activated", name)
        return oot_results[name]
    if len(activated_builtin_plugins) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{activated_builtin_plugins}")
    if len(activated_builtin_plugins) == 1:
        name = activated_builtin_plugins[0]
        logger.info("Automatically detected platform %s.", name)
        return builtin_results[name]

    logger.info(
        "No platform detected, vLLM is running on UnspecifiedPlatform")
    return "vllm.platforms.interface.UnspecifiedPlatform"


# Cached platform instance. It stays None until `current_platform` is first
# accessed through the module-level `__getattr__`, which then assigns it.
_current_platform = None
# Formatted call stack captured when `current_platform` was first resolved
# (presumably to help find what triggered platform initialization).
_init_trace: str = ''

# Declared for static type checkers only. At runtime `current_platform` is
# not a real global; it is served lazily by the module-level `__getattr__`.
if TYPE_CHECKING:
    current_platform: Platform


def __getattr__(name: str):
    """Module-level attribute hook (PEP 562).

    ``current_platform`` is resolved on first access, instantiated, and
    cached in ``_current_platform``. The call stack at that moment is
    stored in ``_init_trace``. Other names are returned from this module's
    globals if present; otherwise AttributeError is raised.
    """
    global _current_platform, _init_trace
    if name == 'current_platform':
        # lazy init current_platform.
        # 1. out-of-tree platform plugins need `from vllm.platforms import
        #    Platform` so that they can inherit `Platform` class. Therefore,
        #    we cannot resolve `current_platform` during the import of
        #    `vllm.platforms`.
        # 2. when users use out-of-tree platform plugins, they might run
        #    `import vllm`, some vllm internal code might access
        #    `current_platform` during the import, and we need to make sure
        #    `current_platform` is only resolved after the plugins are loaded
        #    (we have tests for this, if any developer violate this, they will
        #    see the test failures).
        if _current_platform is None:
            platform_cls_qualname = resolve_current_platform_cls_qualname()
            _current_platform = resolve_obj_by_qualname(
                platform_cls_qualname)()
            _init_trace = "".join(traceback.format_stack())
        return _current_platform
    # Python only calls a module __getattr__ after normal lookup fails, so
    # this branch mainly serves direct calls such as __getattr__("x").
    if name in globals():
        return globals()[name]
    raise AttributeError(
        f"No attribute named '{name}' exists in {__name__}.")
252
253
254
255
256
257


# Public API. `from vllm.platforms import *` fetches every name listed here,
# so it triggers the lazy resolution of `current_platform` via __getattr__.
__all__ = [
    "Platform",
    "PlatformEnum",
    "current_platform",
    "CpuArchEnum",
    "_init_trace",
]