# interface.py (19.9 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import enum
5
import os
6
import platform
7
import random
8
import sys
9
from datetime import timedelta
10
from typing import TYPE_CHECKING, Any, NamedTuple
11

12
import numpy as np
13
14
import torch

15
16
from vllm.logger import init_logger

17
if TYPE_CHECKING:
    from torch.distributed import PrefixStore, ProcessGroup

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheDType
    from vllm.inputs import ProcessorInputs, PromptType
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
    # `FlexibleArgumentParser` appears in an unquoted annotation
    # (`pre_register_and_update`), which is evaluated when the method is
    # defined, so a runtime name must exist. Binding `object` keeps
    # `FlexibleArgumentParser | None` valid without importing the real class
    # at runtime (presumably to avoid an import cycle - TODO confirm).
    FlexibleArgumentParser = object

logger = init_logger(__name__)

32

33
34
def in_wsl() -> bool:
    """Return True when running under Windows Subsystem for Linux (WSL).

    Detection is a case-insensitive search for "microsoft" in the
    space-joined fields of ``platform.uname()``.
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(platform.uname()).lower()
36
37


38
39
40
class PlatformEnum(enum.Enum):
    """Identifies which hardware platform a `Platform` subclass represents.

    `OOT` marks an out-of-tree platform (checked by
    `Platform.is_out_of_tree`); `UNSPECIFIED` is the value used by
    `UnspecifiedPlatform`.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
46
47


48
49
50
51
class CpuArchEnum(enum.Enum):
    """CPU architecture families returned by `Platform.get_cpu_architecture`.

    `OTHER` is returned for a non-empty but unrecognized machine string,
    `UNKNOWN` when the machine string is empty.
    """

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    OTHER = enum.auto()
    UNKNOWN = enum.auto()


58
59
60
61
class DeviceCapability(NamedTuple):
    """Compute capability of a device as a ``(major, minor)`` pair.

    As a ``NamedTuple`` it inherits lexicographic tuple comparison, so
    instances compare against each other and against plain
    ``(major, minor)`` tuples, e.g. ``DeviceCapability(8, 9) >= (8, 0)``.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Format the capability as ``"<major>.<minor>"``, e.g. ``"8.9"``."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.

        Raises:
            ValueError: If ``minor`` is outside ``[0, 9]``; such a value would
                collide with another capability's encoding (e.g. ``(8, 12)``
                would encode as ``92``).
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, after which the encoding would be silently wrong.
        if not 0 <= self.minor < 10:
            raise ValueError(
                f"Cannot encode device capability {self.major}.{self.minor} "
                "as an integer: minor version must be a single digit."
            )
        return self.major * 10 + self.minor


100
101
class Platform:
    """Base description of a hardware platform vLLM can run on.

    Concrete platforms subclass this, set ``_enum``, ``device_name`` and
    ``device_type``, and override the hooks below. The defaults here are
    conservative: optional capabilities report ``False``/``None``/empty
    values, and hooks that need real device knowledge raise
    ``NotImplementedError``.
    """

    # Which PlatformEnum member this platform is; set by subclasses.
    _enum: PlatformEnum
    # Human-readable device name (used in `verify_quantization` errors).
    device_name: str
    # Attribute name of the matching ``torch`` submodule; `__getattr__`
    # forwards unknown attributes to ``getattr(torch, device_type)``.
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # Quantization methods allowed on this platform; empty means "no
    # restriction" (see `verify_quantization`).
    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    # Lazily created by `get_global_graph_pool`; stored per subclass.
    _global_graph_pool: Any | None = None

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Return True if this is the CUDA platform."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this is the ROCm platform."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True if this is the TPU platform."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """Return True if this is the XPU platform."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this is the CPU platform."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """Return True if this is an out-of-tree (plugin) platform."""
        return self._enum == PlatformEnum.OOT

    def is_unspecified(self) -> bool:
        """Return True if no concrete platform was specified."""
        return self._enum == PlatformEnum.UNSPECIFIED

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Upper bound on output tokens for a prompt of ``prompt_len`` tokens.

        The base platform imposes no limit and returns ``sys.maxsize``.
        """
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Return True if sleep mode may be used on this platform."""
        # TODO: Actually only mi3xx has the sleep mode support now
        # for ROCm, but currently we don't have a way to detect the
        # exact GPU model statelessly here. So we return True for
        # all ROCm platforms for now.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int):
        """Map a logical device index to a physical device index.

        When ``device_control_env_var`` holds a non-empty comma-separated
        list, ``device_id`` indexes into it; otherwise ``device_id`` is
        returned unchanged.

        Raises:
            IndexError: If ``device_id`` is out of range for the env list.
            ValueError: If the selected entry is not an integer.
        """
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
        if (
            cls.device_control_env_var in os.environ
            and os.environ[cls.device_control_env_var] != ""
        ):
            device_ids = os.environ[cls.device_control_env_var].split(",")
            physical_device_id = device_ids[device_id]
            return int(physical_device_id)
        else:
            return device_id

    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels."""
        # `vllm._C` failing to import is logged; `vllm._moe_C` is optional
        # and its absence is silently ignored.
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_vit_attn_backend(
        cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
        """Return the attention backend for vision encoders (default SDPA)."""
        # Import AttentionBackendEnum here to avoid circular import.
        from vllm.attention.backends.registry import AttentionBackendEnum

        return AttentionBackendEnum.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Get the attention backend class of a device.

        The base platform returns an empty string.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the capability of the device is unknown.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the capability of the device is unknown.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Does nothing when ``seed`` is None.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        An empty ``supported_quantization`` list means every method is
        allowed.

        Raises:
            ValueError: If ``quant`` is not in a non-empty
                ``supported_quantization`` list.
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.

        Unrecognized non-empty machine strings map to ``OTHER``; an empty
        machine string maps to ``UNKNOWN``.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        elif machine.startswith("arm") or machine.startswith("aarch"):
            return CpuArchEnum.ARM
        elif machine.startswith("ppc"):
            return CpuArchEnum.POWERPC
        elif machine == "s390x":
            return CpuArchEnum.S390X
        elif machine.startswith("riscv"):
            return CpuArchEnum.RISCV

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.

        Returns the dotted import path of the class.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.
        """
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: "PromptType",
        params: "SamplingParams | PoolingParams",
        processed_inputs: "ProcessorInputs",
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Forward unknown attributes to the ``torch.<device_type>`` module.

        Only invoked when normal attribute lookup fails. Returns the
        attribute from the torch device module if it has it; otherwise logs
        a warning and returns ``None`` rather than raising AttributeError.
        """
        # Read `device_type` without going through `self.device_type`: the
        # base class only annotates it, so when a subclass leaves it unset
        # that lookup would re-enter `__getattr__` and recurse until
        # RecursionError.
        device_type = self.__dict__.get(
            "device_type", getattr(type(self), "device_type", None)
        )
        device = getattr(torch, device_type, None) if device_type else None
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            device_type,
            key,
        )
        return None

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        Created on first use via ``self.graph_pool_handle()`` (resolved
        through `__getattr__` on the torch device module) and cached on the
        concrete class.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.

        Returns the dotted import path of the class.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: "PrefixStore",
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> "ProcessGroup":
        """
        Init platform-specific torch distributed process group.
        """
        raise NotImplementedError

    @classmethod
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        """
        Check if the dtype is supported by the current platform.
        """
        raise NotImplementedError

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return False

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        """
        Wrap the original weight loader to make it synced.

        Returns ``original_weight_loader`` unchanged when
        `use_sync_weight_loader` is False. Otherwise returns a wrapper that
        calls it and then ``torch._sync(param)`` for non-CPU parameters.
        """
        if not cls.use_sync_weight_loader():
            return original_weight_loader

        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            # CPU tensors are skipped; only device params are synced.
            if param.device != torch.device("cpu"):
                torch._sync(param)
            return out

        return _synced_weight_loader

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
    def get_nixl_memory_type(cls) -> str | None:
        """
        Returns the nixl memory type for the current platform.
        """
        return None

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.

        The base platform returns it unchanged.
        """
        return max_model_len

635
636
637

class UnspecifiedPlatform(Platform):
    """Placeholder platform: ``_enum`` is UNSPECIFIED, ``device_type`` empty.

    With an empty ``device_type`` no torch device module is resolved, so
    attribute forwarding in ``Platform.__getattr__`` yields ``None``.
    """

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""