"tests/models/language/generation/test_common.py" did not exist on "e1957c6ebdd4860f832c26ae4de4195d10803723"
__init__.py 9.66 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import logging
import traceback
from itertools import chain
6
from typing import TYPE_CHECKING
7

8
from vllm import envs
9
from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group
10
from vllm.utils.import_utils import resolve_obj_by_qualname
11
from vllm.utils.torch_utils import supports_xccl
12
13

from .interface import CpuArchEnum, Platform, PlatformEnum
14

15
logger = logging.getLogger(__name__)
16

17

18
19
20
21
22
def vllm_version_matches_substr(substr: str) -> bool:
    """
    Check whether the installed vLLM version string contains ``substr``.

    The version is read from the installed distribution metadata via
    ``importlib.metadata.version("vllm")`` and tested with plain substring
    containment (no version parsing is performed).

    Args:
        substr: Text to look for inside the version string.

    Returns:
        ``True`` if ``substr`` occurs anywhere in the version string,
        ``False`` otherwise.

    Raises:
        importlib.metadata.PackageNotFoundError: If no ``vllm`` distribution
            metadata can be found. A warning is logged before the exception
            is propagated to the caller.
    """
    from importlib.metadata import PackageNotFoundError, version

    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        logger.warning(
            "The vLLM package was not found, so its version could not be "
            "inspected. This may cause platform detection to fail."
        )
        # Bare `raise` re-raises the active exception unchanged.
        raise
    return substr in vllm_version


35
def tpu_platform_plugin() -> str | None:
    """
    Detect whether the TPU platform should be used.

    Returns:
        The qualified name of the TPU platform class to load:
        ``"tpu_inference.platforms.tpu_platform.TpuPlatform"`` when
        ``envs.VLLM_TPU_USING_PATHWAYS`` is truthy, otherwise
        ``"vllm.platforms.tpu.TpuPlatform"`` when ``libtpu`` can be imported.
        ``None`` if neither condition holds.
    """
    logger.debug("Checking if TPU platform is available.")

    # Check for Pathways TPU proxy
    if envs.VLLM_TPU_USING_PATHWAYS:
        logger.debug("Confirmed TPU platform is available via Pathways proxy.")
        return "tpu_inference.platforms.tpu_platform.TpuPlatform"

    # Check for libtpu installation
    try:
        # While it's technically possible to install libtpu on a
        # non-TPU machine, this is a very uncommon scenario. Therefore,
        # we assume that libtpu is installed only if the machine
        # has TPUs.
        import libtpu  # noqa: F401
    except Exception as e:
        # Any import-time failure (not only ImportError) means "no TPU".
        logger.debug("TPU platform is not available because: %s", str(e))
        return None

    logger.debug("Confirmed TPU platform is available.")
    return "vllm.platforms.tpu.TpuPlatform"
57
58


59
def cuda_platform_plugin() -> str | None:
    """
    Detect whether the CUDA platform should be used.

    Detection first goes through NVML (via ``import_pynvml``): CUDA is
    selected when at least one device is reported and the installed vLLM
    version string does not contain ``"cpu"``. If the NVML path raises an
    exception whose class name contains ``"nvml"``, the function falls back
    to checking for Jetson marker paths on the filesystem.

    Returns:
        ``"vllm.platforms.cuda.CudaPlatform"`` if CUDA is detected,
        otherwise ``None``.

    Raises:
        Exception: Any exception raised during the NVML check whose class
            name does not contain ``"nvml"`` (case-insensitive) is re-raised
            unchanged.
    """
    is_cuda = False
    logger.debug("Checking if CUDA platform is available.")
    try:
        from vllm.utils.import_utils import import_pynvml

        pynvml = import_pynvml()
        pynvml.nvmlInit()
        try:
            # NOTE: Edge case: vllm cpu build on a GPU machine.
            # Third-party pynvml can be imported in cpu build,
            # we need to check if vllm is built with cpu too.
            # Otherwise, vllm will always activate cuda plugin
            # on a GPU machine, even if in a cpu build.
            # Both facts are queried exactly once and reused for the
            # decision and for the diagnostic messages below.
            has_gpu = pynvml.nvmlDeviceGetCount() > 0
            is_cpu_build = vllm_version_matches_substr("cpu")
            is_cuda = has_gpu and not is_cpu_build

            if not has_gpu:
                logger.debug("CUDA platform is not available because no GPU is found.")
            if is_cpu_build:
                logger.debug(
                    "CUDA platform is not available because vLLM is built with CPU."
                )
            if is_cuda:
                logger.debug("Confirmed CUDA platform is available.")
        finally:
            # Always pair nvmlInit() with nvmlShutdown().
            pynvml.nvmlShutdown()
    except Exception as e:
        logger.debug("Exception happens when checking CUDA platform: %s", str(e))
        if "nvml" not in e.__class__.__name__.lower():
            # If the error is not related to NVML, re-raise it.
            raise

        # CUDA is supported on Jetson, but NVML may not be.
        import os

        on_jetson = os.path.isfile("/etc/nv_tegra_release") or os.path.exists(
            "/sys/class/tegra-firmware"
        )
        if on_jetson:
            logger.debug("Confirmed CUDA platform is available on Jetson.")
            is_cuda = True
        else:
            logger.debug("CUDA platform is not available because: %s", str(e))

    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None


110
def rocm_platform_plugin() -> str | None:
    """
    Detect whether the ROCm platform should be used.

    ROCm is selected when ``amdsmi`` can be imported and initialised and it
    reports at least one processor handle. Any exception raised along the
    way is logged at debug level and treated as "not available".

    Returns:
        ``"vllm.platforms.rocm.RocmPlatform"`` if ROCm is detected,
        otherwise ``None``.
    """
    is_rocm = False
    logger.debug("Checking if ROCm platform is available.")
    try:
        import amdsmi

        amdsmi.amdsmi_init()
        try:
            is_rocm = len(amdsmi.amdsmi_get_processor_handles()) > 0
            if is_rocm:
                logger.debug("Confirmed ROCm platform is available.")
            else:
                logger.debug("ROCm platform is not available because no GPU is found.")
        finally:
            # Always pair amdsmi_init() with amdsmi_shut_down().
            amdsmi.amdsmi_shut_down()
    except Exception as e:
        logger.debug("ROCm platform is not available because: %s", str(e))

    return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None


131
def xpu_platform_plugin() -> str | None:
    """
    Detect whether the XPU platform should be used.

    XPU is selected when ``intel_extension_for_pytorch`` and ``torch`` can
    be imported, a distributed backend can be chosen, and
    ``torch.xpu.is_available()`` is true. Any exception raised along the way
    is logged at debug level and treated as "not available".

    Side effects:
        When XPU is detected, ``XPUPlatform.dist_backend`` is set to
        ``"xccl"`` if ``supports_xccl()`` is true, otherwise to ``"ccl"``.

    Returns:
        ``"vllm.platforms.xpu.XPUPlatform"`` if XPU is detected,
        otherwise ``None``.
    """
    is_xpu = False
    logger.debug("Checking if XPU platform is available.")
    try:
        # IPEX is expected to be installed only on machines that have XPUs,
        # so a failed import here is treated as "no XPU".
        import intel_extension_for_pytorch  # noqa: F401
        import torch

        if supports_xccl():
            dist_backend = "xccl"
        else:
            dist_backend = "ccl"
            # The "ccl" backend needs the oneCCL bindings to be importable.
            import oneccl_bindings_for_pytorch  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            is_xpu = True
            from vllm.platforms.xpu import XPUPlatform

            XPUPlatform.dist_backend = dist_backend
            logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend)
            logger.debug("Confirmed XPU platform is available.")
    except Exception as e:
        logger.debug("XPU platform is not available because: %s", str(e))

    return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None


158
def cpu_platform_plugin() -> str | None:
    """
    Detect whether the CPU platform should be used.

    CPU is selected when the installed vLLM version string contains
    ``"cpu"``, or failing that, when ``sys.platform`` starts with
    ``"darwin"``. Any exception raised along the way (including a failed
    version lookup) is logged at debug level and treated as "not available".

    Returns:
        ``"vllm.platforms.cpu.CpuPlatform"`` if CPU is detected,
        otherwise ``None``.
    """
    is_cpu = False
    logger.debug("Checking if CPU platform is available.")
    try:
        is_cpu = vllm_version_matches_substr("cpu")
        if is_cpu:
            logger.debug(
                "Confirmed CPU platform is available because vLLM is built with CPU."
            )
        else:
            import sys

            is_cpu = sys.platform.startswith("darwin")
            if is_cpu:
                logger.debug(
                    "Confirmed CPU platform is available because the machine is MacOS."
                )
    except Exception as e:
        logger.debug("CPU platform is not available because: %s", str(e))

    return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None


# Built-in platform detectors, keyed by platform name. Each value is a
# zero-argument callable returning the qualified name of the platform class
# when that platform is detected, or None otherwise.
builtin_platform_plugins = {
    "tpu": tpu_platform_plugin,
    "cuda": cuda_platform_plugin,
    "rocm": rocm_platform_plugin,
    "xpu": xpu_platform_plugin,
    "cpu": cpu_platform_plugin,
}


def resolve_current_platform_cls_qualname() -> str:
    """
    Decide which platform class to use and return its qualified name.

    Every built-in detector in ``builtin_platform_plugins`` and every plugin
    returned by ``load_plugins_by_group(PLATFORM_PLUGINS_GROUP)`` is called
    once. A plugin counts as "activated" when it returns a non-``None``
    qualified name. Selection rules, in priority order:

    1. Exactly one activated out-of-tree plugin: it wins.
    2. Otherwise, exactly one activated built-in plugin: it wins.
    3. Nothing activated: ``"vllm.platforms.interface.UnspecifiedPlatform"``.

    Returns:
        The qualified name of the selected platform class.

    Raises:
        RuntimeError: If two or more out-of-tree plugins are activated, or
            if no out-of-tree plugin is activated and two or more built-in
            plugins are.
    """
    platform_plugins = load_plugins_by_group(PLATFORM_PLUGINS_GROUP)

    # name -> qualified class name, for every plugin that reported itself as
    # available. Built-in and out-of-tree results are kept apart so that a
    # name used by both groups cannot be attributed to the wrong callable.
    activated_builtin: dict[str, str] = {}
    activated_oot: dict[str, str] = {}
    candidates = chain(
        (
            (name, func, activated_builtin)
            for name, func in builtin_platform_plugins.items()
        ),
        ((name, func, activated_oot) for name, func in platform_plugins.items()),
    )
    for name, func, activated in candidates:
        try:
            if not callable(func):
                raise TypeError(f"platform plugin {name!r} is not callable")
            qualname = func()
        except Exception as e:
            # Detection is best-effort: a failing plugin is skipped, but the
            # reason is recorded instead of being silently discarded.
            logger.debug(
                "Platform plugin %s failed during detection: %s", name, e
            )
            continue
        if qualname is not None:
            # Keep the result so the winning plugin is not invoked again.
            activated[name] = qualname

    if len(activated_oot) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{list(activated_oot)}"
        )
    if len(activated_oot) == 1:
        ((plugin_name, platform_cls_qualname),) = activated_oot.items()
        logger.info("Platform plugin %s is activated", plugin_name)
        return platform_cls_qualname

    if len(activated_builtin) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{list(activated_builtin)}"
        )
    if len(activated_builtin) == 1:
        ((plugin_name, platform_cls_qualname),) = activated_builtin.items()
        logger.debug("Automatically detected platform %s.", plugin_name)
        return platform_cls_qualname

    logger.debug("No platform detected, vLLM is running on UnspecifiedPlatform")
    return "vllm.platforms.interface.UnspecifiedPlatform"


# Cached platform instance. Stays None until `current_platform` is first
# requested through the module-level `__getattr__`.
_current_platform = None
# Formatted call stack captured at the moment `_current_platform` was first
# created; empty until then.
_init_trace: str = ""

if TYPE_CHECKING:
    # Declaration for static type checkers only: at runtime the attribute is
    # served by the module-level `__getattr__`, not by a real module global.
    current_platform: Platform


def __getattr__(name: str):
    """
    Module-level attribute hook that lazily creates ``current_platform``.

    Args:
        name: Name of the attribute being looked up on this module.

    Returns:
        For ``"current_platform"``: the cached platform instance, creating
        it on first access. For any other name present in this module's
        globals: that global's value.

    Raises:
        AttributeError: If ``name`` is neither ``"current_platform"`` nor a
            module global.
    """
    global _current_platform, _init_trace

    if name == "current_platform":
        # lazy init current_platform.
        # 1. out-of-tree platform plugins need `from vllm.platforms import
        #    Platform` so that they can inherit `Platform` class. Therefore,
        #    we cannot resolve `current_platform` during the import of
        #    `vllm.platforms`.
        # 2. when users use out-of-tree platform plugins, they might run
        #    `import vllm`, some vllm internal code might access
        #    `current_platform` during the import, and we need to make sure
        #    `current_platform` is only resolved after the plugins are loaded
        #    (we have tests for this, if any developer violate this, they will
        #    see the test failures).
        if _current_platform is None:
            platform_cls_qualname = resolve_current_platform_cls_qualname()
            _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()
            # Record the call stack of the access that triggered creation.
            _init_trace = "".join(traceback.format_stack())
        return _current_platform

    module_vars = globals()
    if name in module_vars:
        return module_vars[name]
    raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
265
266


267
268
269
270
271
272
273
274
def __setattr__(name: str, value):
    """
    Assign ``value`` to a module-level name, redirecting ``current_platform``.

    Args:
        name: Attribute name to assign.
        value: New value.

    Raises:
        AttributeError: If ``name`` is neither ``"current_platform"`` nor an
            existing module global.
    """
    # NOTE(review): PEP 562 defines only module-level `__getattr__` and
    # `__dir__` as hooks, so a plain `module.attr = value` presumably does
    # not route through this function -- confirm how callers invoke it.
    global _current_platform

    if name == "current_platform":
        # The public name is backed by the private cache variable.
        _current_platform = value
        return

    module_vars = globals()
    if name not in module_vars:
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
    module_vars[name] = value
275
276


277
# Names exported by `from <this module> import *`. `current_platform` is not
# a real module global; it is served lazily by the module-level `__getattr__`.
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]