# SPDX-License-Identifier: Apache-2.0

import enum
import platform
import random
from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch

from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    # Outside of type checking these names are bound to None so that the
    # annotations referring to them further down still evaluate.
    VllmConfig = None
    FlexibleArgumentParser = None

logger = init_logger(__name__)

23

24
25
26
27
28
def in_wsl() -> bool:
    """Return True when the ``uname`` fields mention "microsoft" (WSL)."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    uname_text = " ".join(uname())
    return "microsoft" in uname_text.lower()


29
30
31
32
33
34
35
36
class _Backend(enum.Enum):
    """Identifiers of the attention backends that can be selected.

    Values come from ``enum.auto()``, so the declaration order below
    determines each member's integer value; append new members with care.
    """
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    TRITON_MLA = enum.auto()
    TRITON_MLA_VLLM_V1 = enum.auto()
    FLASHMLA = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()


48
49
50
class PlatformEnum(enum.Enum):
    """Kinds of platform that a ``Platform`` can identify itself as.

    Values come from ``enum.auto()``, so the declaration order below
    determines each member's integer value.
    """
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OPENVINO = enum.auto()
    OOT = enum.auto()  # out-of-tree platform
    UNSPECIFIED = enum.auto()
59
60


61
62
63
64
65
66
67
68
@enum.unique
class CpuArchEnum(enum.Enum):
    """CPU architecture families reported by the platform layer."""
    # Values are assigned by ``enum.auto()`` in declaration order.
    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    # A machine string was reported but matched none of the families above.
    OTHER = enum.auto()
    # No machine string was reported at all.
    UNKNOWN = enum.auto()


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class DeviceCapability(NamedTuple):
    """A ``(major, minor)`` device capability, comparable like a tuple."""
    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as the string ``<major>.<minor>``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Collapse the capability into the integer ``<major><minor>``.

        The minor version must be a single digit; this is checked with an
        ``assert``.
        """
        major, minor = self
        assert 0 <= minor < 10
        return major * 10 + minor


86
87
class Platform:
    """Base class describing a hardware platform.

    Subclasses set ``_enum``, ``device_name`` and ``device_type`` and override
    the classmethods below; several of the base implementations raise
    ``NotImplementedError`` while others return a neutral default.
    """
    # Which PlatformEnum member this platform is; read by the is_*() helpers.
    _enum: PlatformEnum
    # Name of the device; used in the verify_quantization() error message.
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # Quantization methods accepted by verify_quantization(); an empty list
    # disables the check, i.e. every method is accepted.
    supported_quantization: list[str] = []

    def is_cuda(self) -> bool:
        """Return True if this platform is CUDA."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this platform is ROCm."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True if this platform is TPU."""
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        """Return True if this platform is HPU."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Return True if this platform is XPU."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this platform is CPU."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Return True if this platform is Neuron."""
        return self._enum == PlatformEnum.NEURON

    def is_openvino(self) -> bool:
        """Return True if this platform is OpenVINO."""
        return self._enum == PlatformEnum.OPENVINO

    def is_out_of_tree(self) -> bool:
        """Return True if this platform is an out-of-tree (OOT) platform."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Get the attention backend class of a device.

        The base implementation ignores its arguments and returns an empty
        string.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of :func:`torch.cuda.get_device_capability`.

        The base implementation returns ``None``.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)

        Returns ``False`` when :meth:`get_device_capability` reports ``None``.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # DeviceCapability is a NamedTuple, so this is a tuple comparison.
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Nothing is seeded when ``seed`` is ``None``.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``supported_quantization`` is non-empty and does
                not contain ``quant``.
        """
        supported = cls.supported_quantization
        # An empty list means no restriction, so every method passes.
        if supported and quant not in supported:
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        # An empty machine string means the architecture could not be read.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.

        The base implementation returns the dotted path of
        ``DeviceCommunicatorBase``.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
333

334
335
336

class UnspecifiedPlatform(Platform):
    """Platform tagged ``PlatformEnum.UNSPECIFIED`` with an empty device type."""
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""