interface.py 17.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import enum
4
import os
5
import platform
6
import random
7
from datetime import timedelta
8
from platform import uname
9
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
10

11
import numpy as np
12
import torch
13
from torch.distributed import PrefixStore, ProcessGroup
14

15
from vllm.inputs import ProcessorInputs, PromptType
16
17
from vllm.logger import init_logger

18
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.lora.request import LoRARequest
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.utils import FlexibleArgumentParser
else:
    # These names are used in annotations that are evaluated when the
    # methods below are defined (there is no `from __future__ import
    # annotations` in this module), e.g. `Optional[FlexibleArgumentParser]`,
    # so they must be bound to *something* at runtime. `None` is a
    # placeholder; presumably the real imports are skipped at runtime to
    # avoid import cycles / heavy imports — TODO confirm.
    ModelConfig = None
    VllmConfig = None
    LoRARequest = None
    PoolingParams = None
    SamplingParams = None
    FlexibleArgumentParser = None
31

32
33
# Module-level logger; Platform uses it for warnings such as the WSL
# pin-memory notice and missing device attributes in `__getattr__`.
logger = init_logger(__name__)

34

35
36
37
38
39
def in_wsl() -> bool:
    """Report whether the process appears to run under WSL.

    Heuristic: all fields returned by `platform.uname()` are joined and
    searched, case-insensitively, for the substring "microsoft".
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    uname_text = " ".join(uname()).lower()
    return "microsoft" in uname_text


40
41
42
class _Backend(enum.Enum):
    """Identifiers for attention backend implementations.

    `Platform.get_attn_backend_cls` receives one of these as
    `selected_backend`. Values come from `enum.auto()`, so they follow
    declaration order; reordering members changes their values.
    """

    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    TRITON_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    ROCM_AITER_MLA = enum.auto()  # Supported by V1
    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    FLASHINFER_VLLM_V1 = enum.auto()
    TRITON_MLA = enum.auto()  # Supported by V1
    TRITON_MLA_VLLM_V1 = enum.auto()
    FLASHMLA_VLLM_V1 = enum.auto()
    FLASHMLA = enum.auto()  # Supported by V1
    CUTLASS_MLA_VLLM_V1 = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    DUAL_CHUNK_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()
64
65


66
67
68
class PlatformEnum(enum.Enum):
    """Kind of hardware platform.

    `Platform._enum` holds one of these, and the `Platform.is_*` predicates
    compare against it. `OOT` marks an out-of-tree platform
    (`Platform.is_out_of_tree`).
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
76
77


78
79
80
81
82
83
84
85
class CpuArchEnum(enum.Enum):
    """Coarse CPU architecture family.

    Returned by `Platform.get_cpu_architecture`, which maps the lowercased
    `platform.machine()` string onto one of these members.
    """

    X86 = enum.auto()  # x86 / x86-64 machine strings
    ARM = enum.auto()  # machine strings starting with "arm" or "aarch"
    POWERPC = enum.auto()  # machine strings starting with "ppc"
    OTHER = enum.auto()  # non-empty machine string of any other family
    UNKNOWN = enum.auto()  # empty machine string


86
87
88
89
90
91
92
93
94
class DeviceCapability(NamedTuple):
    """Device capability as a `(major, minor)` version pair.

    Being a `NamedTuple`, instances compare like plain tuples, so
    `DeviceCapability(8, 9) >= (8, 0)` works directly.
    """

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability as a dotted string, e.g. ``"8.9"``."""
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer `<major><minor>`.

        It is assumed that the minor version is always a single digit.

        Raises:
            ValueError: if `minor` is not in ``[0, 9]``. Otherwise the
                encoding would be ambiguous, e.g. ``(8, 10)`` and ``(9, 0)``
                would both map to ``90``.
        """
        # Explicit check rather than `assert`, so the guard is still
        # enforced when Python runs with `-O`.
        if not 0 <= self.minor < 10:
            raise ValueError(
                f"Minor version must be a single digit, got {self.minor}")
        return self.major * 10 + self.minor


103
104
class Platform:
    """Base interface for a hardware platform supported by vLLM.

    Platform-specific behaviour is provided by subclassing and overriding the
    class attributes and classmethods below. The defaults here are
    conservative: capability queries return ``None``/``False``/empty values,
    configuration hooks are no-ops, and operations with no sensible default
    raise ``NotImplementedError``.
    """

    # Platform identity. Only annotated here; concrete subclasses assign them.
    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # Quantization methods accepted by `verify_quantization`. An empty list
    # disables the check (every method is accepted).
    supported_quantization: list[str] = []

    # Names of additional environment variables for this platform. Not read
    # inside this class; NOTE(review): confirm how callers consume it.
    additional_env_vars: list[str] = []

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Returns the supported dtypes for the current platform."""
        # Be careful with the order of the dtypes. The first dtype will
        # be used as the default dtype fallback for the current platform,
        # when encountering unsupported dtypes in "auto" dtype.
        return [torch.bfloat16, torch.float16, torch.float32]

    def is_cuda(self) -> bool:
        """Whether this is the CUDA platform."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Whether this is the ROCm platform."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Whether this is the TPU platform."""
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        """Whether this is the HPU platform."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Whether this is the XPU platform."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Whether this is the CPU platform."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Whether this is the Neuron platform."""
        return self._enum == PlatformEnum.NEURON

    def is_out_of_tree(self) -> bool:
        """Whether this platform is tagged `PlatformEnum.OOT` (out-of-tree)."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of [torch.cuda.is_available][]."""
        # True for both the CUDA and the ROCm platform.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_sleep_mode_available(self) -> bool:
        """Whether sleep mode is available; only CUDA reports True here."""
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int) -> int:
        """Map a logical device index to a physical device index.

        If `cls.device_control_env_var` is set, its value is split on commas
        and `device_id` indexes into the resulting list. Otherwise
        `device_id` is returned unchanged.

        Raises:
            RuntimeError: if the environment variable is set to an empty
                string (platform support disabled).
        """
        if cls.device_control_env_var in os.environ:
            device_ids = os.environ[cls.device_control_env_var].split(",")
            if device_ids == [""]:
                msg = (f"{cls.device_control_env_var} is set to empty string, "
                       "which means current platform support is disabled. If "
                       "you are using ray, please unset the environment "
                       f"variable `{cls.device_control_env_var}` inside the "
                       "worker/actor. Check "
                       "https://github.com/vllm-project/vllm/issues/8402 for "
                       "more information.")
                raise RuntimeError(msg)
            physical_device_id = device_ids[device_id]
            return int(physical_device_id)
        else:
            return device_id

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Get the attention backend class of a device.

        The class is returned as a string; the base implementation returns
        an empty string.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of [torch.cuda.get_device_capability][]."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the platform reports no capability.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def is_device_capability(
        cls,
        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform has exactly the specified device capability.

        The `capability` argument can either be:

        - A tuple `(major, minor)`.
        - An integer `<major><minor>`. (See
        [`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])

        Returns False when the platform reports no capability.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability == capability

        return current_capability.to_int() == capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Seeds Python's `random`, NumPy and torch; does nothing when `seed`
        is None.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.

        The base implementation delegates to `torch.cuda.set_device`.
        """
        torch.cuda.set_device(device)

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Do some pre-registration or update action for the current platform.

        This function is called before global VllmConfig is initialized or cli
        arguments are parsed. It's used for out-of-tree platforms to register or
        update the configuration.

        For example, the out-of-tree quantization config can be imported and
        registered here dynamically.
        """
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if `supported_quantization` is non-empty and does not
                contain `quant`.
        """
        if (cls.supported_quantization
                and quant not in cls.supported_quantization):
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.

        UNKNOWN is returned only when `platform.machine()` is empty; any
        unrecognised non-empty value maps to OTHER.
        """
        machine = platform.machine().lower()

        # "x86" is what 32-bit Windows reports (PROCESSOR_ARCHITECTURE).
        if machine in ("x86_64", "amd64", "x86", "i386", "i686"):
            return CpuArchEnum.X86
        elif machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        elif machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform.

        Returns False (and logs a warning) under WSL, True otherwise.
        """
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.
        """
        raise NotImplementedError

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """
        return float("-inf"), float("inf")

    @classmethod
    def can_update_inplace(cls) -> bool:
        """
        Checks if the platform allows inplace memory updates
        """
        return True

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """
        Returns how much padding the LoRA logits need for kernels
        """
        return 256

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.

        Returned as a dotted import path string.
        """
        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """
        Returns whether the current platform supports MX types.
        """
        return False

    @classmethod
    def supports_fp8(cls) -> bool:
        """
        Returns whether the current platform supports FP8 types.
        """
        return False

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """
        Returns whether the preferred FP8 type is FNUZ on the current platform.

        There are two representations of FP8, OCP FP8 and FNUZ FP8.
        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.

        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
        hardware has converged on the OCP FP8 standard.
        """
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.

        See the documentation for is_fp8_fnuz for details.
        """
        return torch.float8_e4m3fn

    @classmethod
    def use_all_gather(cls) -> bool:
        """
        Whether to use allgather in LogitsProcessor to gather the logits.

        True when V1 is enabled or the distributed executor backend is
        "external_launcher".
        """
        # Imported lazily; presumably to avoid an import cycle with
        # vllm.config at module load time — TODO confirm.
        import vllm.envs as envs
        from vllm.config import get_current_vllm_config

        parallel_config = get_current_vllm_config().parallel_config
        return (envs.VLLM_USE_V1
                or parallel_config.distributed_executor_backend
                == "external_launcher")

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return False

    @classmethod
    def default_v1(cls, model_config: ModelConfig) -> bool:
        """
        Returns whether the current platform supports v1 by default.

        Defaults to whatever `supports_v1` reports.
        """
        return cls.supports_v1(model_config)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
        Returns if custom allreduce is supported on the current platform
        """
        return False

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""

    def __getattr__(self, key: str):
        """Fall back to the torch device module for unknown attributes.

        Python only calls this after normal attribute lookup has failed. The
        lookup is retried on `getattr(torch, self.device_type)`. If that
        module or attribute does not exist, a warning is logged and None is
        returned.

        Raises:
            AttributeError: if `device_type` itself is not defined (e.g. on
                the bare base class, which only annotates it).
        """
        # Without this guard, a missing `device_type` makes the
        # `self.device_type` read below re-enter `__getattr__("device_type")`
        # forever and end in RecursionError.
        if key == "device_type":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                "'device_type'")
        device = getattr(torch, self.device_type, None)
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        else:
            logger.warning(
                "Current platform %s does not have '%s'"
                " attribute.", self.device_type, key)
            return None

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """
        Returns the total number of compute units (CU) on single GPU.
        """
        raise NotImplementedError

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """
        Get piecewise backend class for piecewise graph.

        Returned as a dotted import path string.
        """
        return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """
        Init platform-specific torch distributed process group.

        The base implementation supports no backend and always raises
        RuntimeError.
        """
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

549
550
551

class UnspecifiedPlatform(Platform):
    """Placeholder platform tagged `PlatformEnum.UNSPECIFIED`.

    `device_type` is the empty string, so `Platform.__getattr__` finds no
    matching torch device module and unknown attributes resolve to None
    (with a logged warning).
    """

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""