__init__.py 6.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING

from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_xccl

from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum
from vllm_omni.plugins import (
    OMNI_PLATFORM_PLUGINS_GROUP,
    load_omni_plugins_by_group,
)

# Module-level logger named after this module; the platform probes below
# report their findings through it at debug level.
logger = logging.getLogger(__name__)


def cuda_omni_platform_plugin() -> str | None:
    """Check if CUDA OmniPlatform should be activated."""
    is_cuda = False
    logger.debug("Checking if CUDA OmniPlatform is available.")
    try:
        from vllm.utils.import_utils import import_pynvml

        pynvml = import_pynvml()
        pynvml.nvmlInit()
        try:
            if pynvml.nvmlDeviceGetCount() > 0:
                is_cuda = True
                logger.debug("Confirmed CUDA OmniPlatform is available.")
            else:
                logger.debug("CUDA OmniPlatform is not available because no GPU is found.")
        finally:
            pynvml.nvmlShutdown()
    except Exception as e:
        logger.debug("CUDA OmniPlatform is not available because: %s", str(e))

    return "vllm_omni.platforms.cuda.platform.CudaOmniPlatform" if is_cuda else None


def rocm_omni_platform_plugin() -> str | None:
    """Check if ROCm OmniPlatform should be activated."""
    is_rocm = False
    logger.debug("Checking if ROCm OmniPlatform is available.")
    try:
        import amdsmi

        amdsmi.amdsmi_init()
        try:
            if len(amdsmi.amdsmi_get_processor_handles()) > 0:
                is_rocm = True
                logger.debug("Confirmed ROCm OmniPlatform is available.")
            else:
                logger.debug("ROCm OmniPlatform is not available because no GPU is found.")
        finally:
            amdsmi.amdsmi_shut_down()
    except Exception as e:
        logger.debug("ROCm OmniPlatform is not available because: %s", str(e))

    return "vllm_omni.platforms.rocm.platform.RocmOmniPlatform" if is_rocm else None


def npu_omni_platform_plugin() -> str | None:
    """Check if NPU OmniPlatform should be activated."""
    is_npu = False
    logger.debug("Checking if NPU OmniPlatform is available.")
    try:
        import torch

        if hasattr(torch, "npu") and torch.npu.is_available():
            is_npu = True
            logger.debug("Confirmed NPU OmniPlatform is available.")
    except Exception as e:
        logger.debug("NPU OmniPlatform is not available because: %s", str(e))

    return "vllm_omni.platforms.npu.platform.NPUOmniPlatform" if is_npu else None


def xpu_omni_platform_plugin() -> str | None:
    """Check if XPU OmniPlatform should be activated."""
    is_xpu = False
    logger.debug("Checking if XPU OmniPlatform is available.")
    try:
        # installed IPEX if the machine has XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import torch

        if supports_xccl():
            dist_backend = "xccl"
        else:
            dist_backend = "ccl"
            import oneccl_bindings_for_pytorch  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            is_xpu = True
            from vllm_omni.platforms.xpu import XPUOmniPlatform

            XPUOmniPlatform.dist_backend = dist_backend
            logger.debug("Confirmed %s backend is available.", XPUOmniPlatform.dist_backend)
            logger.debug("Confirmed XPU platform is available.")
    except Exception as e:
        logger.debug("XPU omni platform is not available because: %s", str(e))

    return "vllm_omni.platforms.xpu.platform.XPUOmniPlatform" if is_xpu else None


# Built-in platform detectors keyed by short platform name. Each value is a
# zero-argument callable that returns the platform class's qualified name, or
# None when that platform is not available on this machine.
# resolve_current_omni_platform_cls_qualname() iterates this mapping to probe
# the built-ins before any out-of-tree plugins.
builtin_omni_platform_plugins = {
    "cuda": cuda_omni_platform_plugin,
    "rocm": rocm_omni_platform_plugin,
    "npu": npu_omni_platform_plugin,
    "xpu": xpu_omni_platform_plugin,
}


def _probe_omni_platform_plugins(plugins: dict) -> dict:
    """Run each detector once and map activated plugin names to qualnames.

    A plugin is "activated" when its detector returns a non-``None`` value.
    Non-callable entries and detectors that raise are skipped (best effort),
    but the reason is logged at debug level instead of being discarded.
    """
    activated = {}
    for name, func in plugins.items():
        # Explicit check instead of ``assert``: assertions vanish under -O.
        if not callable(func):
            logger.debug("OmniPlatform plugin %s is not callable, skipping.", name)
            continue
        try:
            qualname = func()
        except Exception as exc:
            logger.debug("OmniPlatform plugin %s failed during detection: %s", name, exc)
            continue
        if qualname is not None:
            activated[name] = qualname
    return activated


def resolve_current_omni_platform_cls_qualname() -> str:
    """Resolve the current OmniPlatform class qualified name.

    Built-in detectors are probed first, then the out-of-tree (OOT) plugins
    loaded from ``OMNI_PLATFORM_PLUGINS_GROUP``. Each detector is called
    exactly once and its result is reused, so probe side effects are not
    repeated and a flaky second call cannot change or break the outcome.
    Built-in and OOT results are tracked separately, so an OOT plugin that
    shares a built-in's name is judged by its own return value.

    Selection priority:
        1. exactly one activated OOT plugin;
        2. otherwise exactly one activated built-in detector;
        3. otherwise ``UnspecifiedOmniPlatform``.

    Returns:
        The qualified name of the platform class to instantiate.

    Raises:
        RuntimeError: If two or more OOT plugins activate, or if no OOT
            plugin activates and two or more built-in detectors do.
    """
    platform_plugins = load_omni_plugins_by_group(OMNI_PLATFORM_PLUGINS_GROUP)

    activated_builtin = _probe_omni_platform_plugins(builtin_omni_platform_plugins)
    activated_oot = _probe_omni_platform_plugins(platform_plugins)

    if len(activated_oot) >= 2:
        raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {list(activated_oot)}")
    if len(activated_oot) == 1:
        plugin_name, platform_cls_qualname = next(iter(activated_oot.items()))
        logger.info("OmniPlatform plugin %s is activated", plugin_name)
        return platform_cls_qualname

    if len(activated_builtin) >= 2:
        raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {list(activated_builtin)}")
    if len(activated_builtin) == 1:
        plugin_name, platform_cls_qualname = next(iter(activated_builtin.items()))
        logger.debug("Automatically detected OmniPlatform %s.", plugin_name)
        return platform_cls_qualname

    logger.debug("No platform detected, vLLM-Omni is running on UnspecifiedOmniPlatform")
    return "vllm_omni.platforms.interface.UnspecifiedOmniPlatform"


# Cached platform instance. Stays None until ``current_omni_platform`` is first
# requested through the module-level ``__getattr__`` below, or until a value is
# stored through the module-level ``__setattr__`` function.
_current_omni_platform = None
# Formatted call stack recorded by ``__getattr__`` when it creates the platform
# instance; empty until that happens.
_init_trace: str = ""

if TYPE_CHECKING:
    # Declared for static type checkers only; at runtime the name is served
    # lazily by the module-level ``__getattr__``.
    current_omni_platform: OmniPlatform


def __getattr__(name: str):
    """Module-level attribute hook that lazily creates the current platform.

    Python calls this hook when ``name`` is not found through normal module
    attribute lookup (PEP 562).

    Args:
        name: The attribute being requested on this module.

    Returns:
        For ``"current_omni_platform"``: the cached platform instance, created
        on first access from the class resolved by
        ``resolve_current_omni_platform_cls_qualname()``. For any other name
        present in the module globals: that global's value.

    Raises:
        AttributeError: If ``name`` is neither ``"current_omni_platform"`` nor
            a module global.
    """
    global _current_omni_platform, _init_trace

    if name != "current_omni_platform":
        module_ns = globals()
        if name in module_ns:
            return module_ns[name]
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")

    if _current_omni_platform is None:
        # First access: resolve the class, instantiate it, and remember the
        # call stack that triggered the initialization.
        qualname = resolve_current_omni_platform_cls_qualname()
        platform_cls = resolve_obj_by_qualname(qualname)
        _current_omni_platform = platform_cls()
        _init_trace = "".join(traceback.format_stack())
    return _current_omni_platform


def __setattr__(name: str, value):  # noqa: N807
    """Store ``value`` under ``name`` in this module's globals.

    ``"current_omni_platform"`` is redirected to the private
    ``_current_omni_platform`` cache; any other name must already exist in
    the module globals.

    NOTE(review): PEP 562 defines only ``__getattr__`` and ``__dir__`` as
    module-level hooks, so a plain ``module.current_omni_platform = x``
    assignment does not appear to route through this function. Confirm how
    callers are expected to reach it.

    Args:
        name: The module attribute to assign.
        value: The value to store.

    Raises:
        AttributeError: If ``name`` is neither ``"current_omni_platform"`` nor
            an existing module global.
    """
    global _current_omni_platform

    if name == "current_omni_platform":
        _current_omni_platform = value
        return

    module_ns = globals()
    if name not in module_ns:
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
    module_ns[name] = value


# Public names of this package. ``current_omni_platform`` has no module-level
# binding of its own; it is served on demand by the module-level
# ``__getattr__``.
# NOTE(review): ``from <this package> import _init_trace`` copies the string as
# it is at import time, so it would not reflect a later lazy initialization --
# confirm that consumers read it through the module object instead.
__all__ = [
    "OmniPlatform",
    "OmniPlatformEnum",
    "current_omni_platform",
    "_init_trace",
]