# interface.py
1
# SPDX-License-Identifier: Apache-2.0
2
import enum
3
import os
4
import platform
5
import random
6
from datetime import timedelta
7
from platform import uname
8
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
9

10
import numpy as np
11
import torch
12
from torch.distributed import PrefixStore, ProcessGroup
13

14
from vllm.inputs import ProcessorInputs, PromptType
15
16
from vllm.logger import init_logger

17
# These names are only needed for annotations. Static type checkers see the
# real classes; at runtime they are bound to None so that annotations which
# mention them can still be evaluated without importing the modules.
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.lora.request import LoRARequest
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = None
    VllmConfig = None
    LoRARequest = None
    PoolingParams = None
    SamplingParams = None
    FlexibleArgumentParser = None
30

31
32
# Module-level logger, created by vLLM's `init_logger` helper for this module.
logger = init_logger(__name__)

33

34
35
36
37
38
def in_wsl() -> bool:
    """Return True when any `uname()` field mentions "microsoft".

    The match is case-insensitive.
    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    system_description = " ".join(uname())
    return "microsoft" in system_description.lower()


39
40
41
class _Backend(enum.Enum):
    """Identifiers for the attention backends.

    Values come from `enum.auto()`, so they follow declaration order; keep
    the order stable if the numeric values matter to any caller.
    """

    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    TRITON_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    TRITON_MLA_VLLM_V1 = enum.auto()
    FLASHMLA_VLLM_V1 = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()


62
63
64
class PlatformEnum(enum.Enum):
    """Tags identifying a platform.

    Values come from `enum.auto()` and therefore follow declaration order.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
72
73


74
75
76
77
78
79
80
81
@enum.unique
class CpuArchEnum(enum.Enum):
    """CPU architecture families."""

    X86 = enum.auto()  # x86 family
    ARM = enum.auto()  # ARM family
    POWERPC = enum.auto()  # PowerPC family
    OTHER = enum.auto()  # some other architecture
    UNKNOWN = enum.auto()  # architecture not known


82
83
84
85
86
87
88
89
90
class DeviceCapability(NamedTuple):
    """A `(major, minor)` device capability version."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Format the capability as the string `<major>.<minor>`."""
        major, minor = self
        return f"{major}.{minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        major, minor = self
        assert 0 <= minor < 10
        return 10 * major + minor


99
100
class Platform:
    """Base class describing a hardware platform.

    The defaults defined here are neutral: capability queries answer
    ``False``/``None``, device queries raise ``NotImplementedError`` and the
    configuration hooks do nothing. Platform-specific subclasses set the
    class attributes below and override the methods they support.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # Quantization method names accepted by `verify_quantization`; an empty
    # list disables that check (every method is accepted).
    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.CUDA`."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.ROCM`."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.TPU`."""
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.HPU`."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.XPU`."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.CPU`."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.NEURON`."""
        return self._enum == PlatformEnum.NEURON

    def is_out_of_tree(self) -> bool:
        """Return True when this platform is tagged `PlatformEnum.OOT`."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Return True only when this platform is tagged `PlatformEnum.CUDA`."""
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int) -> int:
        """Map a logical device index to the ID listed in the control env var.

        When the environment variable named by `device_control_env_var` is
        not set, `device_id` is returned unchanged. Otherwise the variable is
        split on commas and the entry at position `device_id` is returned as
        an integer.

        Raises:
            RuntimeError: if the environment variable is set to an empty
                string.
            IndexError: if `device_id` is outside the listed entries.
            ValueError: if the selected entry is not an integer.
        """
        if cls.device_control_env_var not in os.environ:
            return device_id

        device_ids = os.environ[cls.device_control_env_var].split(",")
        # "".split(",") yields [""], i.e. the variable is set but empty.
        if device_ids == [""]:
            msg = (f"{cls.device_control_env_var} is set to empty string, "
                   "which means current platform support is disabled. If "
                   "you are using ray, please unset the environment "
                   f"variable `{cls.device_control_env_var}` inside the "
                   "worker/actor. Check "
                   "https://github.com/vllm-project/vllm/issues/8402 for "
                   "more information.")
            raise RuntimeError(msg)
        return int(device_ids[device_id])

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Get the attention backend class of a device.

        The base implementation returns an empty string.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of [torch.cuda.get_device_capability][].

        The base implementation returns None.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when `get_device_capability` reports None.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # DeviceCapability is a NamedTuple, so this is a tuple comparison.
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Does nothing when `seed` is None.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        An empty `supported_quantization` list accepts every method.

        Raises:
            ValueError: if `supported_quantization` is non-empty and does not
                contain `quant`.
        """
        if (cls.supported_quantization
                and quant not in cls.supported_quantization):
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        # A non-empty but unrecognised name is OTHER; an empty one is UNKNOWN.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.
        """
        # Imported inside the method rather than at module level.
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config
        return (envs.VLLM_USE_V1
                or parallel_config.distributed_executor_backend
                == "external_launcher")

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return False

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Fall back to an attribute of the `torch.<device_type>` module.

        Python calls this only after normal attribute lookup has failed. If
        `torch` has an attribute named `self.device_type` and that object has
        `key`, that attribute is returned. Otherwise a warning is logged and
        None is returned.

        Raises:
            AttributeError: if `device_type` itself is not defined.
        """
        # `self.device_type` below would re-enter this method without bound if
        # the attribute was never defined, so fail fast with a clear error.
        if key == "device_type":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}")

        device = getattr(torch, self.device_type, None)
        if device is not None and hasattr(device, key):
            return getattr(device, key)

        logger.warning("Current platform %s does not have '%s' attribute.",
                       self.device_type, key)
        return None

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """
        Returns the total number of compute units (CU) on single GPU.
        """
        raise NotImplementedError

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """
        Get piecewise backend class for piecewise graph.
        """
        return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """
        Init platform-specific torch distributed process group.

        The base implementation always raises RuntimeError.
        """
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

507
508
509

class UnspecifiedPlatform(Platform):
    """Platform tagged `PlatformEnum.UNSPECIFIED`, with an empty device type."""

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""