"vllm/vscode:/vscode.git/clone" did not exist on "efa9084628b32787ae1901a2d1e9b80f7d08809b"
interface.py 8.12 KB
Newer Older
1
import enum
2
import platform
3
import random
4
from platform import uname
5
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
6

7
import numpy as np
8
9
import torch

10
11
from vllm.logger import init_logger

12
13
14
15
16
# `VllmConfig` is only used in annotations. Static type checkers take the
# real import; at runtime the name is bound to None instead (presumably to
# avoid importing vllm.config from here — confirm), so annotations that use
# it evaluate to None.
if not TYPE_CHECKING:
    VllmConfig = None
else:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

19

20
21
22
23
24
def in_wsl() -> bool:
    """Return True when the interpreter appears to run under WSL.

    Per the referenced WSL issue, the check is a case-insensitive search for
    "microsoft" across all fields reported by ``platform.uname()``.
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    fingerprint = " ".join(uname()).lower()
    return "microsoft" in fingerprint


25
26
27
28
29
30
31
32
33
34
35
36
37
38
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()


39
40
41
class PlatformEnum(enum.Enum):
    """Identity tag for a hardware platform.

    Each ``Platform`` subclass stores one member in ``_enum``, and the
    ``Platform.is_*`` predicates compare against it. Values come from
    ``enum.auto()`` and depend on declaration order.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OPENVINO = enum.auto()
    # Out-of-tree platform (checked by ``Platform.is_out_of_tree``).
    OOT = enum.auto()
    # Tag used by ``UnspecifiedPlatform``.
    UNSPECIFIED = enum.auto()
50
51


52
53
54
55
56
57
58
59
class CpuArchEnum(enum.Enum):
    """CPU architecture families returned by ``Platform.get_cpu_architecture``."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    # Machine string was reported but matched no known family.
    OTHER = enum.auto()
    # Machine string was empty, i.e. the architecture could not be determined.
    UNKNOWN = enum.auto()


60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class DeviceCapability(NamedTuple):
    """A device capability expressed as a ``(major, minor)`` pair.

    Because this is a ``NamedTuple``, instances compare like plain tuples;
    ``Platform.has_device_capability`` relies on that for ``(major, minor)``
    comparisons.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as ``"<major>.<minor>"`` (e.g. ``"8.6"``)."""
        major, minor = self
        return f"{major}.{minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        The encoding is only unambiguous for a single-digit minor version
        (0-9); that assumption is checked with an assertion.
        """
        major, minor = self
        assert 0 <= minor < 10
        return 10 * major + minor


77
78
class Platform:
    """Base interface describing a hardware platform vLLM can run on.

    Concrete platforms subclass this, set ``_enum`` and the class attributes
    below, and override the classmethods whose base versions are
    placeholders (several of them raise ``NotImplementedError``).
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # Quantization methods accepted by `verify_quantization`; an empty list
    # means no restriction is enforced. This list object is a class attribute
    # shared by every subclass that does not assign its own, so subclasses
    # should rebind it rather than mutate it in place.
    supported_quantization: list[str] = []

    # Platform-identity predicates: each compares `self._enum` against a
    # single PlatformEnum member.
    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        return self._enum == PlatformEnum.NEURON

    def is_openvino(self) -> bool:
        return self._enum == PlatformEnum.OPENVINO

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`.

        True for both CUDA and ROCm platforms.
        """
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        """Get the attention backend class of a device.

        The base implementation ignores every argument and returns an empty
        string. Platforms are expected to override it. (The string is
        presumably a qualified class name; confirm against subclasses.)
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of :func:`torch.cuda.get_device_capability`.

        The base platform reports no capability (``None``), which makes
        `has_device_capability` return False.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)

        Returns False when the platform reports no capability at all.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # DeviceCapability is a NamedTuple, so this is a lexicographic
            # (major, minor) comparison.
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device.

        Raises:
            NotImplementedError: on the base platform.
        """
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes.

        Raises:
            NotImplementedError: on the base platform.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.

        Raises:
            NotImplementedError: on the base platform.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        The base implementation does nothing.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        An empty ``supported_quantization`` list disables the check.

        Raises:
            ValueError: if the platform restricts quantization methods and
                ``quant`` is not one of them.
        """
        if (cls.supported_quantization
                and quant not in cls.supported_quantization):
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.

        Classification uses ``platform.machine()`` (case-insensitive):
        x86 spellings map to ``X86``, ``arm*``/``aarch*`` to ``ARM`` and
        ``ppc*`` to ``POWERPC``. Any other non-empty string maps to ``OTHER``.
        An empty string, which ``platform.machine()`` returns when the value
        cannot be determined, maps to ``UNKNOWN``.
        """
        machine = platform.machine().lower()

        # "x86" is reported e.g. by 32-bit Windows builds; i486/i586 are older
        # 32-bit Linux spellings. All belong to the x86 family.
        x86_names = ("x86_64", "amd64", "x86", "i386", "i486", "i586", "i686")
        if machine in x86_names:
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform.

        Returns False (and logs a warning on every call) under WSL, otherwise
        True.
        """
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

265
266
267

class UnspecifiedPlatform(Platform):
    """Platform tagged ``PlatformEnum.UNSPECIFIED``; has no device type.

    Presumably the fallback when no concrete platform is detected; confirm
    with the platform-resolution code.
    """

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""