# vllm/platforms/interface.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import enum
5
import os
6
import platform
7
import random
8
import sys
9
from datetime import timedelta
10
from platform import uname
11
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
12

13
import numpy as np
14
import torch
15
from torch.distributed import PrefixStore, ProcessGroup
16

17
from vllm.inputs import ProcessorInputs, PromptType
18
19
from vllm.logger import init_logger

20
# The names below are only needed for type annotations. At runtime they are
# bound to None so that annotations referencing them without quotes (e.g.
# `vllm_config: VllmConfig`) still evaluate when the class body is executed,
# without importing these modules at load time (presumably to avoid import
# cycles / heavy imports).
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
    from vllm.lora.request import LoRARequest
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils import FlexibleArgumentParser
else:
    _Backend = None
    ModelConfig = None
    VllmConfig = None
    LoRARequest = None
    PoolingParams = None
    SamplingParams = None
    FlexibleArgumentParser = None

logger = init_logger(__name__)


def in_wsl() -> bool:
    """Report whether the process runs under Windows Subsystem for Linux.

    Detection is a case-insensitive search for "microsoft" in the fields of
    `platform.uname()`.
    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    # "microsoft" contains no space, so checking each field separately is
    # equivalent to searching the space-joined string.
    return any("microsoft" in field.lower() for field in uname())


class PlatformEnum(enum.Enum):
    """Identifies the kind of hardware platform a `Platform` represents."""

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    # Out-of-tree platform, i.e. one not defined in the vLLM source tree.
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()


class CpuArchEnum(enum.Enum):
    """CPU architecture families recognized by `Platform.get_cpu_architecture`."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    # A machine string was reported but matched none of the families above.
    OTHER = enum.auto()
    # No machine string could be determined at all.
    UNKNOWN = enum.auto()


class DeviceCapability(NamedTuple):
    """A device capability version expressed as a `(major, minor)` pair.

    Being a NamedTuple, instances compare element-wise with plain tuples,
    e.g. `DeviceCapability(8, 9) >= (8, 0)`.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability formatted as `"<major>.<minor>"`."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10, (
            f"minor version must be a single digit, got {self.minor}"
        )
        return self.major * 10 + self.minor


class Platform:
    """Base description of a hardware platform vLLM can run on.

    Platform-specific subclasses override the class attributes and methods
    below. Several defaults here raise NotImplementedError; others return
    conservative values (e.g. no FP8 support, no custom allreduce).
    """

    _enum: PlatformEnum
    device_name: str
    # Name of the torch device module (`torch.<device_type>`) that
    # `__getattr__` falls back to for unknown attributes.
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # NOTE: the two lists below are class-level objects shared by every
    # subclass that does not override them; subclasses should assign a new
    # list rather than mutate these in place.
    # An empty list means "no restriction" (see `verify_quantization`).
    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    # Lazily created, cached per class by `get_global_graph_pool`.
    _global_graph_pool: Optional[Any] = None

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_unspecified(self) -> bool:
        return self._enum == PlatformEnum.UNSPECIFIED

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Return the output-token cap for a prompt; unlimited by default."""
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int):
        """Map a logical device index to a physical device id.

        If `cls.device_control_env_var` is set to a non-empty comma-separated
        list, `device_id` indexes into that list and the selected entry is
        returned as an int. Otherwise `device_id` is returned unchanged.
        """
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
        visible_devices = os.environ.get(cls.device_control_env_var, "")
        if visible_devices == "":
            return device_id
        device_ids = visible_devices.split(",")
        return int(device_ids[device_id])

    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels."""
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        # The MoE kernels are optional; their absence is not reported.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
        """Return the attention backend for vision encoders (default SDPA)."""
        from vllm.attention.backends.registry import _Backend

        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the platform reports no device capability.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the platform reports no device capability.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Does nothing when `seed` is None.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(
        cls, parser: Optional[FlexibleArgumentParser] = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if `supported_quantization` is non-empty and does not
                contain `quant`. An empty list accepts every method.
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC
        if machine == "s390x":
            return CpuArchEnum.S390X
        if machine.startswith("riscv"):
            return CpuArchEnum.RISCV

        # platform.machine() returns "" when the value cannot be determined.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.

        True when the V1 engine is enabled or the distributed executor backend
        is "external_launcher".
        """
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config
        return (
            envs.VLLM_USE_V1
            or parallel_config.distributed_executor_backend == "external_launcher"
        )

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Fall back to the platform's torch device module for unknown names.

        Called only when normal attribute lookup fails. If
        `torch.<device_type>` exists and has `key`, that attribute is returned;
        otherwise a warning is logged and None is returned.
        """
        if key == "device_type":
            # `device_type` is only annotated (not assigned) on this base
            # class. Without this guard, an instance whose class never sets
            # it would recurse forever: reading `self.device_type` below
            # re-enters `__getattr__("device_type")`.
            raise AttributeError(
                f"{type(self).__name__} does not define 'device_type'"
            )
        device = getattr(torch, self.device_type, None)
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            self.device_type,
            key,
        )
        return None

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        Created on first use via `self.graph_pool_handle()` and cached on the
        concrete platform class, so later calls return the same object.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """
        Returns the total number of compute units (CU) on single GPU.
        """
        raise NotImplementedError

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """
        Init platform-specific torch distributed process group.

        The base implementation supports no backend and always raises
        RuntimeError.
        """
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """
        Returns if the kv_cache_dtype is supported by the current platform.
        """
        return False

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """
        Check if the dtype is supported by the current platform.
        """
        raise NotImplementedError

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return False

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        """
        Wrap the original weight loader to make it synced.

        Returns `original_weight_loader` unchanged when
        `use_sync_weight_loader()` is False. Otherwise returns a wrapper that
        calls it and then `torch._sync(param)` for non-CPU parameters.
        """
        if not cls.use_sync_weight_loader():
            return original_weight_loader

        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            if param.device != torch.device("cpu"):
                torch._sync(param)
            return out

        return _synced_weight_loader

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        """
        Returns the nixl memory type for the current platform.
        """
        return None


class UnspecifiedPlatform(Platform):
    """Platform tagged `PlatformEnum.UNSPECIFIED` with an empty device type."""

    _enum = PlatformEnum.UNSPECIFIED
    # Empty string: there is no `torch.<device_type>` module to fall back to.
    device_type = ""