interface.py 18.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import enum
4
import os
5
import platform
6
import random
7
import sys
8
from datetime import timedelta
9
from platform import uname
10
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
11

12
import numpy as np
13
import torch
14
from torch.distributed import PrefixStore, ProcessGroup
15

16
from vllm.inputs import ProcessorInputs, PromptType
17
18
from vllm.logger import init_logger

19
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.lora.request import LoRARequest
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils import FlexibleArgumentParser
else:
    # At runtime these names are only needed so that the type annotations in
    # this module (evaluated when each `def` runs, since annotations are not
    # postponed here) resolve. Binding them to None avoids importing the
    # modules at runtime (presumably to avoid import cycles — confirm).
    ModelConfig = None
    VllmConfig = None
    LoRARequest = None
    PoolingParams = None
    SamplingParams = None
    FlexibleArgumentParser = None
32

33
34
# Module-level logger. In this file it is used by
# Platform.is_pin_memory_available and Platform.__getattr__ to log warnings.
logger = init_logger(__name__)

35

36
37
38
39
40
def in_wsl() -> bool:
    """Report whether the process is running inside Windows Subsystem for Linux.

    Detection looks for the substring "microsoft" (case-insensitively) in the
    fields returned by ``platform.uname()``.
    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    # "microsoft" contains no spaces, so searching each field separately is
    # the same as searching the space-joined uname string.
    return any("microsoft" in field.lower() for field in uname())


41
42
43
class _Backend(enum.Enum):
    """Identifiers for the attention backends a platform can select.

    Values come from ``enum.auto()`` and therefore follow declaration order;
    inserting or reordering members changes the numeric values.
    """
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    TRITON_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    FLASHINFER_VLLM_V1 = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    TRITON_MLA_VLLM_V1 = enum.auto()
    FLASHMLA_VLLM_V1 = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    CUTLASS_MLA = enum.auto()
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    DIFFERENTIAL_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    XFORMERS_VLLM_V1 = enum.auto()
67
68


69
70
71
class PlatformEnum(enum.Enum):
    """Identifies which hardware platform a ``Platform`` subclass represents.

    ``OOT`` marks out-of-tree platforms (see ``Platform.is_out_of_tree``);
    ``UNSPECIFIED`` is used by ``UnspecifiedPlatform``.
    """
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
78
79


80
81
82
83
84
85
86
87
class CpuArchEnum(enum.Enum):
    """CPU architecture families returned by ``Platform.get_cpu_architecture``.

    Values are spelled out explicitly and match what ``enum.auto()`` would
    assign in declaration order.
    """
    X86 = 1
    ARM = 2
    POWERPC = 3
    # A non-empty machine string that matched none of the families above.
    OTHER = 4
    # The machine string was empty.
    UNKNOWN = 5


88
89
90
91
92
93
94
95
96
class DeviceCapability(NamedTuple):
    """A device compute capability as a ``(major, minor)`` pair.

    Being a ``NamedTuple``, instances compare like ordinary tuples, so
    ``DeviceCapability(9, 0) >= (8, 9)`` is a valid check.
    """
    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as ``"<major>.<minor>"``, e.g. ``"8.6"``."""
        return "{0.major}.{0.minor}".format(self)

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`,
        e.g. ``(8, 6)`` becomes ``86``.

        The packing only works for a single-digit minor version; that
        precondition is asserted.
        """
        major, minor = self
        assert 0 <= minor < 10
        return 10 * major + minor


105
106
class Platform:
    """Base class describing a hardware platform that vLLM can run on.

    Concrete platforms subclass it, assign the class attributes below and
    override the hooks they need. Unless stated otherwise, the default hooks
    either return a conservative value or raise ``NotImplementedError``.
    """

    # Which platform this class represents; compared by the ``is_*`` helpers.
    # Only annotated here, so every concrete subclass must assign it.
    _enum: PlatformEnum
    # Device name used in messages such as the ``verify_quantization`` error.
    # Only annotated here; subclasses assign it.
    device_name: str
    # Name of the ``torch`` attribute for this device; ``__getattr__``
    # forwards unknown attributes to ``torch.<device_type>``.
    # Only annotated here; subclasses assign it.
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # Quantization methods accepted by ``verify_quantization``; an empty list
    # disables that check. This list object is shared with every subclass
    # that does not override it, so subclasses should assign a new list
    # rather than mutate this one.
    supported_quantization: list[str] = []

    # Additional environment variable names for this platform. Not read in
    # this class (presumably consumed by callers — confirm). Same sharing
    # caveat as ``supported_quantization``.
    additional_env_vars: list[str] = []

    # Cache for ``get_global_graph_pool``; populated on the concrete subclass.
    _global_graph_pool: Optional[Any] = None

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Return True if this is the CUDA platform."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this is the ROCm platform."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True if this is the TPU platform."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """Return True if this is the XPU platform."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this is the CPU platform."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Return True if this is the Neuron platform."""
        return self._enum == PlatformEnum.NEURON

    def is_out_of_tree(self) -> bool:
        """Return True if this platform is provided out of tree."""
        return self._enum == PlatformEnum.OOT

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Return the output-token limit for a prompt of ``prompt_len`` tokens.

        The default imposes no limit (``sys.maxsize``).
        """
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][].

        True for both CUDA and ROCm platforms.
        """
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Return True if sleep mode is available (CUDA only by default)."""
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int) -> int:
        """Map a logical device index to a physical device index.

        If ``device_control_env_var`` is set to a non-empty comma-separated
        list of ids, ``device_id`` indexes into that list and the selected
        entry is returned as an int. Otherwise ``device_id`` is returned
        unchanged.

        Raises:
            IndexError: ``device_id`` is out of range for the listed ids.
            ValueError: the selected entry is not an integer.
        """
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
        visible_devices = os.environ.get(cls.device_control_env_var, "")
        if visible_devices == "":
            return device_id
        return int(visible_devices.split(",")[device_id])

    @classmethod
    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
        """Return the attention backend for ViT; defaults to TORCH_SDPA."""
        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
        """Get the attention backend class of a device.

        Returned as a string; the base implementation returns "".
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of [torch.cuda.get_device_capability][].

        The base implementation returns None (no capability information).
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the device capability is unknown.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the device capability is unknown.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Seeds Python's `random`, NumPy and torch. Does nothing when `seed`
        is None.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        An empty ``supported_quantization`` list means every method is
        accepted.

        Raises:
            ValueError: ``quant`` is not in a non-empty
                ``supported_quantization``.
        """
        if (cls.supported_quantization
                and quant not in cls.supported_quantization):
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.

        Classification is based on ``platform.machine()``: an empty machine
        string yields UNKNOWN; an unrecognized non-empty one yields OTHER.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        elif machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        elif machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform.

        Returns False (with a warning) under WSL, True otherwise.
        """
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.

        Returned as a dotted import path string.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.

        True when ``VLLM_USE_V1`` is set or the distributed executor backend
        is "external_launcher".
        """
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config
        return (envs.VLLM_USE_V1
                or parallel_config.distributed_executor_backend
                == "external_launcher")

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return False

    @classmethod
    def default_v1(cls, model_config: ModelConfig) -> bool:
        """
        Returns whether the current platform supports v1 by default.

        Defaults to whatever ``supports_v1`` reports.
        """
        return cls.supports_v1(model_config)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        The base implementation accepts every request.
        """

    def __getattr__(self, key: str):
        """Forward unknown attributes to ``torch.<device_type>``.

        Python calls this only after normal attribute lookup fails. If the
        torch device module exists and has ``key``, that attribute is
        returned; otherwise a warning is logged and None is returned.

        Raises:
            AttributeError: ``device_type`` itself is not set on this
                platform.
        """
        if key == "device_type":
            # `device_type` is only annotated on the base class. Without this
            # guard, a platform that never assigns it would re-enter
            # __getattr__ on the `self.device_type` read below and recurse
            # until RecursionError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                "'device_type'")
        device = getattr(torch, self.device_type, None)
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        logger.warning("Current platform %s does not have '%s'"
                       " attribute.", self.device_type, key)
        return None

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        Created on first use via ``self.graph_pool_handle()`` (resolved
        through ``__getattr__`` unless a subclass defines it) and cached as a
        class attribute on the concrete platform class.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """
        Returns the total number of compute units (CU) on single GPU.
        """
        raise NotImplementedError

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.

        Returned as a dotted import path string.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """
        Init platform-specific torch distributed process group.

        Raises:
            RuntimeError: always in the base implementation.
        """
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """
        Returns if the kv_cache_dtype is supported by the current platform.
        """
        return False

571
572
573

class UnspecifiedPlatform(Platform):
    """Placeholder platform whose enum value is ``PlatformEnum.UNSPECIFIED``.

    ``device_type`` is the empty string, so no ``torch`` device module is
    associated with it.
    """
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""