interface.py 31.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import enum
5
import os
6
import platform
7
import sys
8
from datetime import timedelta
9
from typing import TYPE_CHECKING, Any, NamedTuple
10

11
12
import torch

13
from vllm.logger import init_logger
14
from vllm.v1.attention.backends.registry import AttentionBackendEnum
15

16
if TYPE_CHECKING:
17
18
19
    from torch.distributed import PrefixStore, ProcessGroup

    from vllm.config import VllmConfig
20
    from vllm.config.kernel import IrOpPriorityConfig
21
    from vllm.inputs import EngineInput
22
23
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
24
    from vllm.utils.argparse_utils import FlexibleArgumentParser
25
    from vllm.v1.attention.backend import AttentionBackend
26
    from vllm.v1.attention.selector import AttentionSelectorConfig
27
else:
28
    FlexibleArgumentParser = object
29

30
31
logger = init_logger(__name__)

32

33
34
def in_wsl() -> bool:
    """Return True if any field of `platform.uname()` mentions "microsoft".

    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    return "microsoft" in " ".join(platform.uname()).lower()
36
37


38
class PlatformEnum(enum.Enum):
    """Enumeration of supported hardware platforms."""

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
48
49


50
51
52
53
class CpuArchEnum(enum.Enum):
    """Enumeration of CPU architecture families."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    OTHER = enum.auto()
    UNKNOWN = enum.auto()


60
61
62
63
class DeviceCapability(NamedTuple):
    """A `(major, minor)` device capability pair.

    The rich comparison methods below only handle other
    `DeviceCapability` instances directly; for any other operand they
    return `NotImplemented` so Python falls back to the operand's own
    (reflected) comparison.
    """

    major: int
    minor: int

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) < (other.major, other.minor)

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) <= (other.major, other.minor)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) == (other.major, other.minor)

    def __ge__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) >= (other.major, other.minor)

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, DeviceCapability):
            return NotImplemented
        return (self.major, self.minor) > (other.major, other.minor)

    def as_version_str(self) -> str:
        """Format the capability as `"<major>.<minor>"`."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


102
103
class Platform:
    _enum: PlatformEnum
104
    device_name: str
105
    device_type: str
106

107
108
109
110
    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"
111

112
113
114
115
    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""
116
117
118
119
120
121
122

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

123
124
125
126
127
    # environment variables that need to be set to 1 to prevent ray from
    # setting the visible devices e.g.
    # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES
    ray_noset_device_env_vars: list[str] = []

128
129
130
131
132
133
    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"
134

135
136
137
    # The backend used for distributed communication.
    dist_backend: str = ""

138
    supported_quantization: list[str] = []
139

140
141
    additional_env_vars: list[str] = []

142
    _global_graph_pool: Any | None = None
143

144
145
146
147
148
    @property
    def pass_key(self) -> str:
        """Inductor config key for the PassManager custom pass"""
        return "post_grad_custom_post_pass"

149
150
151
152
153
154
155
156
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

157
158
159
160
161
162
    def is_cuda(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.CUDA`."""
        return self._enum is PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.ROCM`."""
        return self._enum is PlatformEnum.ROCM

163
164
165
    def is_tpu(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.TPU`."""
        return self._enum is PlatformEnum.TPU

166
167
168
    def is_xpu(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.XPU`."""
        return self._enum is PlatformEnum.XPU

169
170
171
    def is_cpu(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.CPU`."""
        return self._enum is PlatformEnum.CPU

172
173
174
    def is_zen_cpu(self) -> bool:
        """Report whether this is a Zen CPU platform; the base class says no."""
        return False

175
176
177
    def is_out_of_tree(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.OOT`."""
        return self._enum is PlatformEnum.OOT

178
179
180
    def is_unspecified(self) -> bool:
        """Return True when this platform's `_enum` is `PlatformEnum.UNSPECIFIED`."""
        return self._enum is PlatformEnum.UNSPECIFIED

181
182
183
    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Upper bound on output tokens for a prompt.

        The base implementation ignores `prompt_len` and returns
        `sys.maxsize`.
        """
        unbounded = sys.maxsize
        return unbounded

184
    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][].

        True when this platform's `_enum` is CUDA or ROCM.
        """
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

188
    def is_sleep_mode_available(self) -> bool:
        """Return True when this platform's `_enum` is CUDA or ROCM."""
        # TODO: Actually only mi3xx has the sleep mode support now
        # for ROCm, but currently we don't have a way to detect the
        # exact GPU model statelessly here. So we return True for
        # all ROCm platforms for now.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
194

195
196
197
198
199
200
    @classmethod
    def get_pass_manager_cls(cls) -> str:
        """
        Get the pass manager class for this platform.
        It will be registered as a custom pass under the current_platform.pass_key.
        """
201
        return "vllm.compilation.passes.pass_manager.PostGradPassManager"
202
203
204
205
206
207
208
209

    @classmethod
    def get_compile_backend(cls) -> str:
        """
        Get the custom compile backend for current platform.
        """
        return cls.simple_compile_backend

210
211
    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int):
212
213
214
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
215
216
217
218
        if (
            cls.device_control_env_var in os.environ
            and os.environ[cls.device_control_env_var] != ""
        ):
219
220
221
222
223
224
            device_ids = os.environ[cls.device_control_env_var].split(",")
            physical_device_id = device_ids[device_id]
            return int(physical_device_id)
        else:
            return device_id

225
    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels.

        A failure to import `vllm._C` is logged as a warning; a failure to
        import `vllm._moe_C` is silently ignored.
        """
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
234

235
    @classmethod
236
237
    def get_attn_backend_cls(
        cls,
238
        selected_backend: "AttentionBackendEnum",
239
        attn_selector_config: "AttentionSelectorConfig",
240
        num_heads: int | None = None,
241
    ) -> str:
242
243
        """Get the attention backend class of a device."""
        return ""
244

245
246
247
248
249
250
251
252
253
254
255
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """List the ViT attention backends this platform supports.

        The base implementation offers only `TORCH_SDPA`.
        """
        supported = [AttentionBackendEnum.TORCH_SDPA]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """
        Get the vision attention backend class of a device.

        NOTE: ViT Attention should be checked and overridden in the
        platform-specific implementation. We should not override this in any
        other places, like the model_executor/models/<model_name>.py.

        We check if the backend is None or not:
            1. If not, check if the backend is supported by the platform.
            2. If None, continue to the default selection logic.

        The base implementation does not use `head_size` or `dtype`.
        """
        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            # The two message parts are separate sentences; keep a
            # separator between them so the text stays readable.
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
        )
        return AttentionBackendEnum.TORCH_SDPA

282
283
284
285
    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
286
    ) -> DeviceCapability | None:
287
        """Stateless version of [torch.cuda.get_device_capability][]."""
288
        return None
289

290
291
292
    @classmethod
    def has_device_capability(
        cls,
293
        capability: tuple[int, int] | int,
294
295
296
297
298
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

299
        The `capability` argument can either be:
300

301
302
303
        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
304
305
306
307
308
309
310
311
312
313
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

314
315
316
    @classmethod
    def is_device_capability(
        cls,
317
        capability: tuple[int, int] | int,
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
    @classmethod
    def is_device_capability_family(
        cls,
        capability: int,
        device_id: int = 0,
    ) -> bool:
        """
        Returns True if the device capability is any <major>.x.
        Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False
        return (current_capability.to_int() // 10) == (capability // 10)

353
354
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
355
356
357
        """Get the name of a device."""
        raise NotImplementedError

358
359
360
361
362
    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

363
364
365
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
366
367
        raise NotImplementedError

368
369
    @classmethod
    def inference_mode(cls):
370
371
372
373
374
375
376
377
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

378
379
380
381
382
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
383
        raise NotImplementedError
384

385
    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.

        The base implementation does nothing.
        """
        pass

401
402
403
404
405
406
407
408
409
410
411
412
413
414
    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """
        Apply the platform-specific default values to the config.

        This function is called during the initialization of global VllmConfig, after
        parsing cli arguments.
        It can modify the defaults of the config according to the platform. For example,
        it can enable custom_ops based on the enabled features.

        The config is passed by reference, so it can be modified in place.
        """
        pass

415
    @classmethod
416
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
417
418
419
420
421
422
423
424
425
426
427
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
    @classmethod
    def _find_non_ssm_backend(
        cls, vllm_config: "VllmConfig"
    ) -> "type[AttentionBackend] | None":
        """Return the first attention backend among the model layers for
        which `is_ssm()` is false, or None if there is none."""
        from vllm.config.vllm import get_layers_from_vllm_config
        from vllm.model_executor.layers.attention_layer_base import (
            AttentionLayerBase,
        )

        layers_by_name = get_layers_from_vllm_config(
            vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        # Lazily query each layer so the scan stops at the first match.
        candidates = (layer.get_attn_backend() for layer in layers_by_name.values())
        return next((backend for backend in candidates if not backend.is_ssm()), None)

448
449
450
451
    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure block_size is compatible with the attention backend.
        For hybrid models, also aligns block_size with mamba page sizes.

        Returns without changes when `model_config` is falsy or when no
        non-SSM attention backend is found. `vllm_config.cache_config` is
        modified in place.
        """
        from vllm.config.cache import CacheConfig
        from vllm.config.vllm import set_current_vllm_config

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config

        # model_config may be None during testing.
        if not model_config:
            return

        backend_cls = cls._find_non_ssm_backend(vllm_config)
        if backend_cls is None:
            return

        # Phase 1: Pick block size from backend (skip if user set --block-size)
        if not cache_config.user_specified_block_size:
            with set_current_vllm_config(vllm_config):
                preferred = backend_cls.get_preferred_block_size(
                    CacheConfig.DEFAULT_BLOCK_SIZE
                )
            if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
                logger.info(
                    "Setting kv cache block size to %d for %s backend.",
                    preferred,
                    backend_cls.get_name(),
                )
            cache_config.block_size = preferred

        # Phase 2: Align block/mamba sizes for hybrid models
        # (may override user settings).
        if model_config.is_hybrid:
            cls._align_hybrid_block_size(vllm_config, backend_cls)

    @classmethod
    def _align_hybrid_block_size(
        cls,
        vllm_config: "VllmConfig",
        backend_cls: "type[AttentionBackend]",
    ) -> None:
        """
        For hybrid attention/mamba models, ensure that the attention page
        size is >= the mamba page size, and pad the mamba page size to match.

        `vllm_config.cache_config` is modified in place (`block_size`,
        `mamba_block_size`, `mamba_page_size_padded`). Returns early when the
        computed mamba page size is 0.
        """
        from math import lcm

        from vllm.config.vllm import set_current_vllm_config
        from vllm.model_executor.models import ModelRegistry
        from vllm.utils.math_utils import cdiv
        from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
        from vllm.v1.attention.backend import MultipleOf
        from vllm.v1.kv_cache_interface import (
            FullAttentionSpec,
            MambaSpec,
            MLAAttentionSpec,
            get_kv_quant_mode,
        )

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype)

        # Compute attention page size for 1 token. Both spec classes take
        # the same arguments here; only the class differs for MLA models.
        attn_spec_cls = MLAAttentionSpec if model_config.use_mla else FullAttentionSpec
        attn_page_size_1_token = attn_spec_cls(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
            kv_quant_mode=kv_quant_mode,
        ).page_size_bytes

        # Compute mamba page size
        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=-1,
        ).page_size_bytes

        if mamba_page_size == 0:
            return

        # mamba_block_size here should either be user specified value or None
        mamba_block_size = (
            cache_config.mamba_block_size
            if cache_config.user_specified_mamba_block_size
            else None
        )

        # Get kernel block alignment from the backend's supported sizes
        with set_current_vllm_config(vllm_config):
            kernel_block_alignment_size = max(
                min(
                    s.base if isinstance(s, MultipleOf) else s
                    for s in backend_cls.get_supported_kernel_block_sizes()
                ),
                cache_config.block_size,
            )

        if cache_config.mamba_cache_mode == "all":
            # With prefix caching, align to mamba chunk size for kernel perf
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.
            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            assert base_chunk_size is not None
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, use minimum block size that satisfies
            # both backend alignment and mamba page size compatibility
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size,
                kernel_block_alignment_size * attn_page_size_1_token,
            )

        if cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = cache_config.block_size

        # Pad mamba page size to exactly match attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token
        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            return

        if (
            cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size
        ):
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )

623
624
625
626
627
628
629
630
631
632
633
634
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

635
636
637
638
639
    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
640
        if cls.supported_quantization and quant not in cls.supported_quantization:
641
            raise ValueError(
642
643
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )
644

645
646
647
648
649
650
651
652
653
654
655
656
657
658
    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.

        Classifies the lower-cased `platform.machine()` string. Returns
        `CpuArchEnum.OTHER` for a non-empty string that matches no known
        family and `CpuArchEnum.UNKNOWN` when the string is empty.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC
        if machine == "s390x":
            return CpuArchEnum.S390X
        if machine.startswith("riscv"):
            return CpuArchEnum.RISCV

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

666
667
668
669
670
671
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform.

        Returns False (and logs a warning) when WSL is detected, True
        otherwise.
        """
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

679
    @classmethod
680
    def get_current_memory_usage(
681
        cls, device: torch.types.Device | None = None
682
    ) -> float:
683
684
685
686
687
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

688
689
690
691
692
693
694
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

695
    @classmethod
696
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

716
717
718
719
720
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
Mengqing Cao's avatar
Mengqing Cao committed
721
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
722

723
724
725
726
727
728
729
    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

760
761
762
763
764
    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.
        """
765
        return True
766

767
768
769
770
771
772
773
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

774
775
776
777
778
779
780
781
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

782
783
784
    @classmethod
    def validate_request(
        cls,
785
        processed_inputs: "EngineInput",
786
        params: "SamplingParams | PoolingParams",
787
788
789
    ) -> None:
        """Raises if this request is unsupported on this platform"""

790
    def __getattr__(self, key: str):
791
792
793
794
795
        # Pickle checks dunder methods like __getstate__. If we return None
        # for them, pickle treats it like a real value and tries to call it.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)

796
        device = getattr(torch, self.device_type, None)
797
        if device is not None and hasattr(device, key):
798
799
800
801
802
803
804
805
806
807
808
809
            attr = getattr(device, key)
            # NOTE: `hasattr(device, key)=True` can only avoid AttributeError,
            # but the value of this attr could be `None`.
            if attr is not None:
                return attr

        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            self.device_type,
            key,
        )
        return None
810

811
812
    def get_global_graph_pool(self) -> Any:
        """
813
        Return the global graph pool for this platform.
814
815
816
817
818
819
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

820
    @classmethod
821
    def get_static_graph_wrapper_cls(cls) -> str:
822
        """
823
        Get static graph wrapper class for static graph.
824
        """
825
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
826

827
828
829
830
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
831
        prefix_store: "PrefixStore",
832
833
834
        group_rank: int,
        group_size: int,
        timeout: timedelta,
835
    ) -> "ProcessGroup":
836
837
838
        """
        Init platform-specific torch distributed process group.
        """
839
        raise NotImplementedError
840

841
    @classmethod
842
    def check_if_supports_dtype(cls, dtype: torch.dtype):
843
844
845
846
847
        """
        Check if the dtype is supported by the current platform.
        """
        raise NotImplementedError

848
849
850
851
852
853
854
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False

855
856
857
858
859
860
861
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False

862
863
864
865
866
867
868
    @classmethod
    def support_deep_gemm(cls) -> bool:
        """
        Returns if DeepGEMM is supported by the current platform.
        """
        return False

869
870
871
872
873
874
875
876
877
    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        """
        Whether this platform should use torch.ops.vllm.* custom ops for collectives.

        Returns False by default - platforms must explicitly opt-in.
        """
        return False

878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return False

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        """
        Wrap the original weight loader to make it synced.
        """
        if not cls.use_sync_weight_loader():
            return original_weight_loader

        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            if param.device != torch.device("cpu"):
                torch._sync(param)
            return out

        return _synced_weight_loader

901
902
903
    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
904
        Returns a mapping from device_type to a tuple of supported
905
906
907
908
909
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
910
    def get_nixl_memory_type(cls) -> str | None:
911
912
913
914
915
        """
        Returns the nixl memory type for the current platform.
        """
        return None

916
917
918
919
920
921
922
    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.
        """
        return max_model_len

923
924
925
926
927
928
929
    @classmethod
    def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
        """
        Set some additional forward context for the current platform if needs.
        """
        return {}

930
931
932
933
934
935
936
937
938
939
    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """
        Get the number of compute units for the current platform.
        (NVIDIA SM / AMD CU / Intel EU)
        """
        raise NotImplementedError(
            "num_compute_units is not implemented for the current platform."
        )

940
941
942
943
944
945
946
947
948
949
    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """
        Build the default IR op priority for the current platform.

        ``vllm_config`` is not consulted here; the result is always
        ``IrOpPriorityConfig.with_default(["native"])``.
        """
        # Imported at call time: the module-level import of this name is
        # guarded by TYPE_CHECKING.
        from vllm.config.kernel import IrOpPriorityConfig

        # Native is always the default choice. Platforms can override this
        # behavior.
        default_priority = ["native"]
        return IrOpPriorityConfig.with_default(default_priority)

950
951
952

class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
953
    device_type = ""