interface.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import enum
5
import os
6
import platform
7
import sys
8
from datetime import timedelta
9
from typing import TYPE_CHECKING, Any, NamedTuple
10

11
12
import torch

13
from vllm.logger import init_logger
14
from vllm.v1.attention.backends.registry import AttentionBackendEnum
15

16
if TYPE_CHECKING:
17
18
19
    from torch.distributed import PrefixStore, ProcessGroup

    from vllm.config import VllmConfig
20
    from vllm.inputs import ProcessorInputs
21
22
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
23
    from vllm.utils.argparse_utils import FlexibleArgumentParser
24
    from vllm.v1.attention.selector import AttentionSelectorConfig
25
else:
26
    FlexibleArgumentParser = object
27

28
29
# Logger shared by the platform hooks defined in this module.
logger = init_logger(__name__)

30

31
32
def in_wsl() -> bool:
    """Return True when the process appears to run under WSL.

    The check is a heuristic: WSL kernels mention "microsoft" in one of the
    ``platform.uname()`` fields.
    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    return any("microsoft" in field.lower() for field in platform.uname())
34
35


36
class PlatformEnum(enum.Enum):
    """Tag identifying which kind of hardware a ``Platform`` targets.

    Values are produced by ``enum.auto()`` and only serve as identities;
    compare members rather than their numeric values.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    # Out-of-tree platform (queried via ``Platform.is_out_of_tree``).
    OOT = enum.auto()
    # No concrete platform (used by ``UnspecifiedPlatform``).
    UNSPECIFIED = enum.auto()
46
47


48
49
50
51
class CpuArchEnum(enum.Enum):
    """CPU architecture families returned by ``Platform.get_cpu_architecture``.

    ``OTHER`` means ``platform.machine()`` returned a non-empty string that
    matched none of the known families; ``UNKNOWN`` means it returned an
    empty string.
    """

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    OTHER = enum.auto()
    UNKNOWN = enum.auto()


58
59
60
61
class DeviceCapability(NamedTuple):
    """A ``(major, minor)`` device compute-capability pair.

    Ordering operators are implemented only against other
    ``DeviceCapability`` values. For any other operand they return
    ``NotImplemented``, so Python falls back to the other operand's reflected
    method; since this is a tuple subclass, plain tuples such as ``(8, 0)``
    still compare element-wise.
    """

    major: int
    minor: int

    def __lt__(self, other: Any) -> bool:
        if isinstance(other, DeviceCapability):
            return (self.major, self.minor) < (other.major, other.minor)
        return NotImplemented

    def __le__(self, other: Any) -> bool:
        if isinstance(other, DeviceCapability):
            return (self.major, self.minor) <= (other.major, other.minor)
        return NotImplemented

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, DeviceCapability):
            return (self.major, self.minor) == (other.major, other.minor)
        return NotImplemented

    def __ge__(self, other: Any) -> bool:
        if isinstance(other, DeviceCapability):
            return (self.major, self.minor) >= (other.major, other.minor)
        return NotImplemented

    def __gt__(self, other: Any) -> bool:
        if isinstance(other, DeviceCapability):
            return (self.major, self.minor) > (other.major, other.minor)
        return NotImplemented

    def as_version_str(self) -> str:
        """Render the capability as ``"<major>.<minor>"``, e.g. ``"9.0"``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Encode the capability as the integer `<major><minor>` (e.g. 9.0 -> 90).

        The encoding only round-trips when the minor version is a single
        digit, which is asserted.
        """
        assert self.minor >= 0 and self.minor < 10
        return 10 * self.major + self.minor


100
101
class Platform:
    """Base description of a hardware platform vLLM can run on.

    Concrete platforms subclass this and override the class attributes and
    hooks below. The defaults here are conservative: most capability queries
    answer "not supported", and device-specific operations raise
    ``NotImplementedError``.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # environment variables that need to be set to 1 to prevent ray from
    # setting the visible devices e.g.
    # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES
    ray_noset_device_env_vars: list[str] = []

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # Quantization methods accepted by verify_quantization(); an empty list
    # means no restriction is enforced.
    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    # Lazily populated by get_global_graph_pool(); the assignment there is made
    # on the concrete subclass, so each platform class gets its own pool.
    _global_graph_pool: Any | None = None

    @property
    def pass_key(self) -> str:
        """Inductor config key for the PassManager custom pass"""
        return "post_grad_custom_post_pass"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.CUDA``."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.ROCM``."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.TPU``."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.XPU``."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.CPU``."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.OOT``."""
        return self._enum == PlatformEnum.OOT

    def is_unspecified(self) -> bool:
        """Whether this platform is tagged ``PlatformEnum.UNSPECIFIED``."""
        return self._enum == PlatformEnum.UNSPECIFIED

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Upper bound on generated tokens for a prompt of ``prompt_len``.

        The base implementation imposes no limit (``sys.maxsize``).
        """
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Whether sleep mode is available (True for CUDA and ROCm)."""
        # TODO: Actually only mi3xx has the sleep mode support now
        # for ROCm, but currently we don't have a way to detect the
        # exact GPU model statelessly here. So we return True for
        # all ROCm platforms for now.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_pass_manager_cls(cls) -> str:
        """
        Get the pass manager class for this platform.
        It will be registered as a custom pass under the current_platform.pass_key.
        """
        return "vllm.compilation.passes.pass_manager.PostGradPassManager"

    @classmethod
    def get_compile_backend(cls) -> str:
        """
        Get the custom compile backend for current platform.
        """
        return cls.simple_compile_backend

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int) -> int:
        """Map a logical device index to the physical index it refers to.

        When ``cls.device_control_env_var`` holds a non-empty comma-separated
        list, ``device_id`` indexes into that list and the selected entry is
        returned as an int. Otherwise ``device_id`` is returned unchanged.

        Raises:
            IndexError: if ``device_id`` is out of range for the list.
            ValueError: if the selected entry is not an integer.
        """
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
        visible_devices = os.environ.get(cls.device_control_env_var, "")
        if not visible_devices:
            return device_id
        return int(visible_devices.split(",")[device_id])

    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels."""
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        # The MoE kernels are optional; their absence is not reported.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Vision (ViT) attention backends this platform accepts."""
        return [
            AttentionBackendEnum.TORCH_SDPA,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """
        Get the vision attention backend class of a device.

        NOTE: ViT Attention should be checked and override in the platform-specific
        implementation. we should not override this in any other places, like
        the model_executor/models/<model_name>.py.

        We check if the backend is None or not:
            1. If not, check if the backend is supported by the platform.
            2. If None, continue to the default selection logic.
        """
        if backend is not None:
            supported = cls.get_supported_vit_attn_backends()
            assert backend in supported, (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {supported}"
            )
            logger.info_once(f"Using backend {backend} for vit attention")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
        )
        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def is_device_capability_family(
        cls,
        capability: int,
        device_id: int = 0,
    ) -> bool:
        """
        Returns True if the device capability is any <major>.x.
        Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).

        `capability` is an integer `<major><minor>`; only its major part
        (`capability // 10`) is compared.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False
        # Compare the major version directly instead of going through
        # to_int(), which asserts a single-digit minor version that is
        # irrelevant to a family check.
        return current_capability.major == capability // 10

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        """
        Apply the platform-specific default values to the config.

        This function is called during the initialization of global VllmConfig, after
        parsing cli arguments.
        It can modify the defaults of the config according to the platform. For example,
        it can enable custom_ops based on the enabled features.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure block_size is compatible with the attention backend.
        """
        from vllm.config.cache import CacheConfig

        cache_config = vllm_config.cache_config
        if cache_config.user_specified_block_size:
            # User specified --block-size; keep it.
            return

        model_config = vllm_config.model_config
        # model_config may be None during testing.
        # Skip hybrid models — their block_size is managed by
        # HybridAttentionMambaModelConfig.
        if model_config is None or model_config.is_hybrid:
            cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
            return

        from vllm.config.vllm import (
            get_layers_from_vllm_config,
            set_current_vllm_config,
        )
        from vllm.model_executor.layers.attention_layer_base import (
            AttentionLayerBase,
        )

        attn_layers = get_layers_from_vllm_config(
            vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        if not attn_layers:
            cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
            return

        # The backend of the first attention layer decides the block size.
        first_layer = next(iter(attn_layers.values()))
        backend_cls = first_layer.get_attn_backend()
        with set_current_vllm_config(vllm_config):
            preferred = backend_cls.get_preferred_block_size(
                CacheConfig.DEFAULT_BLOCK_SIZE
            )
        if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
            logger.info(
                "Setting kv cache block size to %d for %s backend.",
                preferred,
                backend_cls.get_name(),
            )
        cache_config.block_size = preferred

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``supported_quantization`` is non-empty and does
                not contain ``quant``.
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        elif machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        elif machine.startswith("ppc"):
            return CpuArchEnum.POWERPC
        elif machine == "s390x":
            return CpuArchEnum.S390X
        elif machine.startswith("riscv"):
            return CpuArchEnum.RISCV

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.
        """
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        processed_inputs: "ProcessorInputs",
        params: "SamplingParams | PoolingParams",
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Fall back to ``torch.<device_type>`` for attributes not found here.

        Returns the attribute from the ``torch.<device_type>`` module when it
        exists and is not ``None``; otherwise logs a warning and returns
        ``None``.

        Raises:
            AttributeError: if ``device_type`` itself was never assigned.
        """
        if key == "device_type":
            # `device_type` is only annotated on this base class. If a subclass
            # never assigns it, evaluating `self.device_type` below would
            # re-enter __getattr__ and recurse until RecursionError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'device_type'"
            )

        device = getattr(torch, self.device_type, None)
        # NOTE: the attribute may exist on the device module but be `None`;
        # that case is treated the same as a missing attribute.
        attr = None if device is None else getattr(device, key, None)
        if attr is not None:
            return attr

        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            self.device_type,
            key,
        )
        return None

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        The pool is created on first use and cached on the concrete platform
        class, so later calls return the same object.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            # `graph_pool_handle` is not defined on Platform; it is resolved
            # through __getattr__ from the `torch.<device_type>` module.
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: "PrefixStore",
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> "ProcessGroup":
        """
        Init platform-specific torch distributed process group.
        """
        raise NotImplementedError

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """
        Check if the dtype is supported by the current platform.
        """
        raise NotImplementedError

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False

    @classmethod
    def use_custom_op_collectives(cls) -> bool:
        """
        Whether this platform should use torch.ops.vllm.* custom ops for collectives.

        Returns False by default - platforms must explicitly opt-in.
        """
        return False

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return False

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        """
        Wrap the original weight loader to make it synced.

        When use_sync_weight_loader() is False the loader is returned as-is;
        otherwise the wrapper calls torch._sync on every non-CPU parameter
        after loading it.
        """
        if not cls.use_sync_weight_loader():
            return original_weight_loader

        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            if param.device != torch.device("cpu"):
                torch._sync(param)
            return out

        return _synced_weight_loader

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
    def get_nixl_memory_type(cls) -> str | None:
        """
        Returns the nixl memory type for the current platform.
        """
        return None

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.
        """
        return max_model_len

    @classmethod
    def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
        """
        Set some additional forward context for the current platform if needs.
        """
        return {}

    @classmethod
    def num_compute_units(cls, device_id: int = 0) -> int:
        """
        Get the number of compute units for the current platform.
        (NVIDIA SM / AMD CU / Intel EU)
        """
        raise NotImplementedError(
            "num_compute_units is not implemented for the current platform."
        )

778
779
780

class UnspecifiedPlatform(Platform):
    """Platform tagged ``PlatformEnum.UNSPECIFIED``.

    Its empty ``device_type`` means ``Platform.__getattr__`` finds no
    ``torch.<device_type>`` module, so unknown attributes resolve to ``None``.
    """

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""