1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import enum
5
import os
6
import platform
7
import random
8
import sys
9
from datetime import timedelta
10
from platform import uname
11
from typing import TYPE_CHECKING, Any, NamedTuple
12

13
import numpy as np
14
import torch
15
from torch.distributed import PrefixStore, ProcessGroup
16

17
from vllm.inputs import ProcessorInputs, PromptType
18
19
from vllm.logger import init_logger

20
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils import FlexibleArgumentParser
else:
    # These names are used in signature annotations below, which are
    # evaluated when the functions are defined. Binding them to `object`
    # keeps those annotations valid without importing the vllm modules at
    # runtime (presumably to avoid import cycles / import cost).
    _Backend = object
    ModelConfig = object
    VllmConfig = object
    PoolingParams = object
    SamplingParams = object
    FlexibleArgumentParser = object
33

34
35
logger = init_logger(__name__)

36

37
38
39
40
41
def in_wsl() -> bool:
    """Report whether the process appears to run under WSL.

    All fields of `platform.uname()` are joined into one string, and the
    check looks for the case-insensitive substring "microsoft".
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_fields = " ".join(uname()).lower()
    found_marker = "microsoft" in system_fields
    return found_marker


42
43
44
class PlatformEnum(enum.Enum):
    """Identifies which hardware platform a `Platform` subclass represents.

    `Platform._enum` holds one of these members, and the `Platform.is_*`
    helpers compare against it.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    # Out-of-tree platform (see `Platform.is_out_of_tree`).
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
50
51


52
53
54
55
class CpuArchEnum(enum.Enum):
    """CPU architecture families returned by `Platform.get_cpu_architecture`."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    S390X = enum.auto()
    RISCV = enum.auto()
    # A non-empty machine name that matched none of the families above.
    OTHER = enum.auto()
    # The machine name could not be determined (empty string).
    UNKNOWN = enum.auto()


62
63
64
65
66
67
68
69
70
class DeviceCapability(NamedTuple):
    """A device capability version as a `(major, minor)` pair.

    Being a `NamedTuple`, it compares element-wise with plain
    `(major, minor)` tuples.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability formatted as `"<major>.<minor>"`."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.
        """
        # A two-digit minor would collide with the next major version
        # (e.g. 8.10 vs 9.0), so reject it.
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


79
80
class Platform:
    """Base description of a hardware platform vLLM can run on.

    Subclasses assign the class attributes below and override the hooks
    they need. The method bodies here are the fallbacks used when a
    subclass does not override them.
    """

    # Which `PlatformEnum` member this platform is; drives the `is_*` checks.
    _enum: PlatformEnum
    device_name: str
    # Name of the `torch.<device_type>` module that `__getattr__` falls back
    # to (e.g. "cuda"). Only annotated here; subclasses must assign it.
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # The backend used for distributed communication.
    dist_backend: str = ""

    # Quantization methods accepted by `verify_quantization`. An empty list
    # means no restriction is enforced. NOTE: this is a mutable class
    # attribute shared with subclasses that don't override it, so assign a
    # new list instead of mutating this one in place.
    supported_quantization: list[str] = []

    # Extra environment variable names associated with this platform
    # (consumed outside this file; same mutability caveat as above).
    additional_env_vars: list[str] = []

    # Lazily created and cached per concrete class by `get_global_graph_pool`.
    _global_graph_pool: Any | None = None

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Return True if this is the CUDA platform."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this is the ROCm platform."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return True if this is the TPU platform."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """Return True if this is the XPU platform."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this is the CPU platform."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """Return True if this is an out-of-tree (OOT) platform."""
        return self._enum == PlatformEnum.OOT

    def is_unspecified(self) -> bool:
        """Return True if no concrete platform has been specified."""
        return self._enum == PlatformEnum.UNSPECIFIED

    def get_max_output_tokens(self, prompt_len: int) -> int:
        """Return the output-token limit for a prompt of `prompt_len` tokens.

        The base platform imposes no limit (`sys.maxsize`).
        """
        return sys.maxsize

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Return whether sleep mode is available (CUDA only by default)."""
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int):
        """Map a logical device index to a physical device id.

        If `device_control_env_var` is set to a non-empty comma-separated
        list, `device_id` indexes into that list and the entry is returned
        as an int. Otherwise `device_id` is returned unchanged.
        """
        # Treat empty device control env var as unset. This is a valid
        # configuration in Ray setups where the engine is launched in
        # a CPU-only placement group located on a GPU node.
        if (
            cls.device_control_env_var in os.environ
            and os.environ[cls.device_control_env_var] != ""
        ):
            device_ids = os.environ[cls.device_control_env_var].split(",")
            physical_device_id = device_ids[device_id]
            return int(physical_device_id)
        else:
            return device_id

    @classmethod
    def import_kernels(cls) -> None:
        """Import any platform-specific C kernels."""
        try:
            import vllm._C  # noqa: F401
        except ImportError as e:
            logger.warning("Failed to import from vllm._C: %r", e)
        # The MoE kernels are optional; their absence is silently ignored.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
        """Return the ViT attention backend; `_Backend.TORCH_SDPA` by default."""
        from vllm.attention.backends.registry import _Backend

        return _Backend.TORCH_SDPA

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: _Backend,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def pre_register_and_update(
        cls, parser: FlexibleArgumentParser | None = None
    ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if `supported_quantization` is non-empty and does not
                contain `quant`.
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in {cls.device_name}."
            )

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        elif machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        elif machine.startswith("ppc"):
            return CpuArchEnum.POWERPC
        elif machine == "s390x":
            return CpuArchEnum.S390X
        elif machine.startswith("riscv"):
            return CpuArchEnum.RISCV

        # A non-empty but unrecognized name is OTHER; an empty string
        # (platform.machine() could not determine it) is UNKNOWN.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.

        True when the V1 engine is enabled or the distributed executor
        backend is "external_launcher".
        """
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config
        return (
            envs.VLLM_USE_V1
            or parallel_config.distributed_executor_backend == "external_launcher"
        )

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """
        Returns True if we register attention as one giant opaque custom op
        on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Fall back to `torch.<device_type>` for attributes not found here.

        Returns the attribute from the `torch` device module (e.g.
        `torch.cuda`) if it exists. Otherwise logs a warning and returns
        None.
        """
        # Read `device_type` without going back through `__getattr__`. The
        # base class only annotates it, so on a platform that never assigned
        # it, `self.device_type` would call `__getattr__("device_type")` and
        # recurse until RecursionError.
        try:
            device_type = object.__getattribute__(self, "device_type")
        except AttributeError:
            device_type = None
        device = getattr(torch, device_type, None) if device_type else None
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        logger.warning(
            "Current platform %s does not have '%s' attribute.",
            device_type,
            key,
        )
        return None

    def get_global_graph_pool(self) -> Any:
        """
        Return the global graph pool for this platform.

        The pool is created on first use and cached on the concrete class,
        so later calls return the same object.
        """
        cls = self.__class__
        if cls._global_graph_pool is None:
            # `graph_pool_handle` is not defined on this base class. A
            # subclass may define it; otherwise `__getattr__` resolves it
            # from `torch.<device_type>` (e.g. torch.cuda.graph_pool_handle).
            cls._global_graph_pool = self.graph_pool_handle()
        return cls._global_graph_pool

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """
        Returns the total number of compute units (CU) on single GPU.
        """
        raise NotImplementedError

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get static graph wrapper class for static graph.
        """
        return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """
        Init platform-specific torch distributed process group.
        """
        raise NotImplementedError

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: ModelConfig
    ) -> bool:
        """
        Returns if the kv_cache_dtype is supported by the current platform.
        """
        return False

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """
        Check if the dtype is supported by the current platform.
        """
        raise NotImplementedError

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """
        Returns if the hybrid kv cache is supported by the current platform.
        """
        return False

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """
        Returns if the graph mode is supported by the current platform.
        """
        return False

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return False

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        """
        Wrap the original weight loader to make it synced.

        Returns `original_weight_loader` unchanged when
        `use_sync_weight_loader()` is False.
        """
        if not cls.use_sync_weight_loader():
            return original_weight_loader

        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            # Only parameters living off the CPU are synced.
            if param.device != torch.device("cpu"):
                torch._sync(param)
            return out

        return _synced_weight_loader

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
    def get_nixl_memory_type(cls) -> str | None:
        """
        Returns the nixl memory type for the current platform.
        """
        return None

624
625
626

class UnspecifiedPlatform(Platform):
    """Placeholder platform tagged `PlatformEnum.UNSPECIFIED`."""

    _enum = PlatformEnum.UNSPECIFIED
    # Empty device type: `Platform.__getattr__` finds no `torch.<device_type>`
    # module, so unknown attributes resolve to None.
    device_type = ""