# tpu.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
from vllm.inputs import ProcessorInputs, PromptType
10
from vllm.logger import init_logger
11
from vllm.sampling_params import SamplingParams, SamplingType
12
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
13
14

from .interface import Platform, PlatformEnum, _Backend
15

16
if TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime.
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    # Runtime placeholders: bind the annotation-only names to None so that
    # signatures and expressions referencing them still evaluate without
    # importing the modules above.
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
24

25
26
# Module-level logger, created via init_logger with this module's name.
logger = init_logger(__name__)

27
28
29

class TpuPlatform(Platform):
    """Platform implementation for TPU devices.

    Declares the TPU-specific constants (device names, dispatch key,
    distributed backend, environment variables, supported quantization
    methods) and the classmethod hooks that select the attention backend,
    adjust a ``VllmConfig`` in place, and validate incoming requests.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]

    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted path of the attention backend class for TPU.

        A ``selected_backend`` other than ``PALLAS`` / ``PALLAS_VLLM_V1`` is
        only logged; the Pallas V1 backend is returned regardless.

        Raises:
            ValueError: if ``use_v1`` is false.
        """
        if selected_backend not in (_Backend.PALLAS,
                                    _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip name>"`` for the local chips.

        ``device_id`` is accepted for interface compatibility but not used.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False``: async output is reported as unsupported."""
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the TPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``(min, max)`` finite values of ``dtype`` per ``torch.finfo``.

        Finite extremes are returned rather than ``-inf`` / ``+inf``.
        """
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls) -> bool:
        """Return ``False``: in-place updates are reported as unsupported."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocabulary padding size used on TPU (1)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference: ``torch.no_grad()``."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and rewrite it in place for TPU.

        Steps, in order: default the cache block size, force the
        ``DYNAMO_ONCE`` compilation level and default the compilation
        backend, reject speculative decoding, replace float16/float32 with
        bfloat16, take the block size from the Pallas backend, pick the TPU
        worker class, force ``disable_chunked_mm_input`` for multimodal
        models, and disable chunked prefill when MLA is in use.

        Raises:
            AssertionError: if a speculative config is present.
            NotImplementedError: if multi-step scheduling is requested while
                the worker class is ``"auto"``.
        """
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single speculative-decoding guard. A second, weaker
        # `assert not speculative_config` used to follow further down; it
        # could never fire once this check passed, so it was removed.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", vllm_config.model_config.dtype)
            vllm_config.model_config.dtype = torch.bfloat16

        # The page size chosen by the Pallas backend overrides whatever block
        # size was set above.
        # NOTE(review): cache_config is dereferenced here without the
        # None-guard used earlier in this method - confirm it is never None.
        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
        cache_config.block_size = PallasAttentionBackend.get_page_size(
            vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                raise NotImplementedError(
                    "Multi-step scheduling is not supported (and not "
                    "needed) on vLLM V1. Please launch without "
                    "--num-scheduler-steps.")
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if vllm_config.model_config and vllm_config.model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill, the batch budget is raised to at
            # least max_model_len.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return ``False``: pinned memory is unsupported."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return ``True``: all-gather is used on this platform."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Return ``True`` for every model config."""
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: if ``params`` is a ``SamplingParams`` whose sampling
                type is ``SamplingType.RANDOM_SEED`` (per-request seed).
        """
        if (isinstance(params, SamplingParams)
                and params.sampling_type == SamplingType.RANDOM_SEED):
            raise ValueError("Torch XLA does not support per-request seed.")
181
182
183
184
185
186
187
188


# Prefer the TpuPlatform shipped by the optional `tpu_commons` package when it
# can be imported; otherwise keep the class defined in this module.
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    # Import succeeded: rebind the module-level name to the external class.
    TpuPlatform = TpuCommonsPlatform  # type: ignore