"vllm/vscode:/vscode.git/clone" did not exist on "cc90419e89c358f906e17a5ec484fbe04092c277"
interface.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import enum
3
import platform
4
import random
5
from platform import uname
6
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
7

8
import numpy as np
9
10
import torch

11
from vllm.inputs import PromptType
12
13
from vllm.logger import init_logger

14
if TYPE_CHECKING:
15
    from vllm.config import ModelConfig, VllmConfig
16
17
18
    from vllm.lora.request import LoRARequest
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
19
    from vllm.utils import FlexibleArgumentParser
20
else:
21
    ModelConfig = None
22
    VllmConfig = None
23
24
25
    LoRARequest = None
    PoolingParams = None
    SamplingParams = None
26
    FlexibleArgumentParser = None
27

28
29
# Module-level logger for this file, created through vLLM's init_logger helper.
logger = init_logger(__name__)

30

31
32
33
34
35
def in_wsl() -> bool:
    """Report whether the process appears to run under WSL.

    The check joins every field of ``platform.uname()`` with spaces and looks
    for the substring ``"microsoft"`` case-insensitively.
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_description = " ".join(uname()).lower()
    return "microsoft" in system_description


36
37
38
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
39
    TRITON_ATTN_VLLM_V1 = enum.auto()
40
41
42
43
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
44
45
    TRITON_MLA = enum.auto()  # Supported by V1
    FLASHMLA = enum.auto()  # Supported by V1
46
47
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
48
    PALLAS_VLLM_V1 = enum.auto()
49
    IPEX = enum.auto()
50
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
51
52
53
    NO_ATTENTION = enum.auto()


54
55
56
class PlatformEnum(enum.Enum):
    """Kinds of hardware platform a ``Platform`` subclass can represent.

    Values come from ``enum.auto()`` and follow declaration order.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    # Out-of-tree platform (checked by Platform.is_out_of_tree).
    OOT = enum.auto()
    # Tag carried by UnspecifiedPlatform.
    UNSPECIFIED = enum.auto()
64
65


66
67
68
69
70
71
72
73
class CpuArchEnum(enum.Enum):
    """CPU architecture families reported by ``Platform.get_cpu_architecture``."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    # A machine string was reported but matched no known family.
    OTHER = enum.auto()
    # No machine string was reported at all.
    UNKNOWN = enum.auto()


74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class DeviceCapability(NamedTuple):
    """A ``(major, minor)`` device capability pair.

    Being a NamedTuple, instances compare lexicographically with plain
    ``(major, minor)`` tuples.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as ``"<major>.<minor>"``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Pack the capability into a single integer ``<major><minor>``.

        This only works when the minor version is a single decimal digit,
        which is asserted.
        """
        major, minor = self
        assert 0 <= minor < 10
        return minor + 10 * major


91
92
class Platform:
    """Base interface describing the hardware platform vLLM runs on.

    Concrete platforms subclass this, fill in the class attributes below and
    override whichever hooks need platform-specific behavior. The defaults
    here are conservative:

    - capability queries report "nothing";
    - feature flags report ``False``;
    - hooks with no sensible default raise ``NotImplementedError``.
    """

    # The PlatformEnum member this class represents; every ``is_*``
    # predicate compares against it.
    _enum: PlatformEnum
    # Interpolated into the error raised by verify_quantization.
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # Platform-agnostic name of the environment variable that controls device
    # visibility, e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend used for simple, standalone functions.
    # "inductor" keeps the same default as PyTorch itself.
    # NOTE: the forward part of the model is compiled with a separate,
    # vLLM-specific strategy.
    simple_compile_backend: str = "inductor"

    # Quantization methods accepted by verify_quantization; an empty list
    # disables that check entirely.
    supported_quantization: list[str] = []

    # NOTE(review): not read anywhere in this class; presumably extra
    # environment variable names relevant to the platform — confirm.
    additional_env_vars: list[str] = []

    def is_cuda(self) -> bool:
        """Return True if this platform is CUDA."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this platform is ROCm."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True if this platform is TPU."""
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        """Return True if this platform is HPU."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Return True if this platform is XPU."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this platform is CPU."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Return True if this platform is Neuron."""
        return self._enum == PlatformEnum.NEURON

    def is_out_of_tree(self) -> bool:
        """Return True if this platform is an out-of-tree (OOT) platform."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Return True for CUDA or ROCm.

        Stateless counterpart of :func:`torch.cuda.is_available`.
        """
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the attention backend class for this device.

        The base implementation ignores every argument and returns an empty
        string; subclasses override it to choose a backend.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless counterpart of :func:`torch.cuda.get_device_capability`.

        The base implementation returns None (capability unknown).
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Check whether the device meets a minimum capability.

        ``capability`` may be given in either of two forms:

        - a ``(major, minor)`` tuple, compared lexicographically against
          the device capability;
        - an integer ``<major><minor>`` (see :meth:`DeviceCapability.to_int`).

        Returns False when the platform reports no capability at all.
        """
        device_cap = cls.get_device_capability(device_id=device_id)
        if device_cap is None:
            return False

        if not isinstance(capability, tuple):
            return device_cap.to_int() >= capability

        return device_cap >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of device ``device_id``; must be overridden."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the uuid of a device (e.g. its PCI bus ID); must be
        overridden."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of a device in bytes; must be overridden."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Report whether async output is supported; must be overridden.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """Return a device-specific wrapper around `torch.inference_mode`.

        Prefer this over calling torch directly: some hardware backends,
        such as TPU, do not support `torch.inference_mode`. Those platforms
        override this method to return `torch.no_grad` instead.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Seed Python's ``random``, NumPy and torch with ``seed``.

        ``torch.manual_seed`` seeds every device. Passing None leaves all
        generators untouched.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Hook for pre-registration or update actions on this platform.

        It runs before the global VllmConfig is initialized and before CLI
        arguments are parsed. Out-of-tree platforms use it to register or
        update configuration. For example, an out-of-tree quantization
        config can be imported and registered here dynamically.

        The base implementation does nothing.
        """

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Validate ``vllm_config`` against this platform, adjusting it if needed.

        Implementations may raise when the configuration is incompatible, or
        mutate it in place to make it compatible (the object is passed by
        reference). The base implementation does nothing.
        """

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Check whether this platform supports model architecture ``model_arch``.

        Overrides may raise an error or emit a warning depending on support.
        The base implementation accepts every architecture.
        """

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Raise ValueError if ``quant`` is not in ``supported_quantization``.

        An empty ``supported_quantization`` accepts every method.
        """
        allowed = cls.supported_quantization
        if not allowed or quant in allowed:
            return
        raise ValueError(
            f"{quant} quantization is currently not supported in "
            f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Classify the host CPU from ``platform.machine()``.

        An empty machine string maps to UNKNOWN; any other unrecognized
        string maps to OTHER.
        """
        arch = platform.machine().lower()

        if not arch:
            return CpuArchEnum.UNKNOWN
        if arch in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if arch.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if arch.startswith("ppc"):
            return CpuArchEnum.POWERPC
        return CpuArchEnum.OTHER

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Report whether pinned host memory can be used on this platform.

        Returns False (and logs a warning) under WSL, True otherwise.
        """
        if not in_wsl():
            return True
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        logger.warning("Using 'pin_memory=False' as WSL is detected. "
                       "This may slow down the performance.")
        return False

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the current memory usage in bytes; must be overridden.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for this platform; must be overridden.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Return the device-specific communicator class for distributed
        communication, as a dotted import path.
        """
        return ("vllm.distributed.device_communicators."
                "base_device_communicator.DeviceCommunicatorBase")

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Report whether this platform supports FP8 types (base: False).
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Report whether the preferred FP8 type here is FNUZ (base: False).

        FP8 has two representations: OCP FP8 and FNUZ FP8.

        - OCP specification: https://tinyurl.com/b7jvwpft.
        - FNUZ specification: https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All
        other hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Return the preferred FP8 dtype on this platform.

        See is_fp8_fnuz for background on the two FP8 formats.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Decide whether LogitsProcessor gathers logits with all-gather.

        True when V1 is enabled via ``envs.VLLM_USE_V1``, or when the
        distributed executor backend is ``"external_launcher"``.
        """
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_cfg = get_current_vllm_config().parallel_config
        return envs.VLLM_USE_V1 or (
            parallel_cfg.distributed_executor_backend == "external_launcher")

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Report whether this platform can run V1 for ``model_config``.

        The base implementation always returns False.
        """
        return False

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Report whether custom allreduce is supported here (base: False).
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
    ) -> None:
        """Raise if this request is unsupported on this platform.

        The base implementation accepts every request.
        """

403
404
405

class UnspecifiedPlatform(Platform):
    """Placeholder platform tagged with ``PlatformEnum.UNSPECIFIED``.

    Its ``device_type`` is the empty string.
    """

    device_type = ""
    _enum = PlatformEnum.UNSPECIFIED