# tpu.py 8.52 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
from vllm.inputs import ProcessorInputs, PromptType
10
from vllm.logger import init_logger
11
from vllm.sampling_params import SamplingParams, SamplingType
12
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
13

14
from .interface import Platform, PlatformEnum
15

16
# These names are only needed for type annotations. Importing them for real
# is restricted to static type checking; at runtime each name is bound to
# None so that annotations referencing them still evaluate.
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
    from vllm.config.cache import BlockSize
    from vllm.pooling_params import PoolingParams
else:
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
    _Backend = None

logger = init_logger(__name__)

# Flipped to True at the bottom of this module when the optional
# `tpu_inference` package can be imported.
USE_TPU_INFERENCE = False
31

32
33
34

class TpuPlatform(Platform):
    """Platform implementation for TPU devices (``PlatformEnum.TPU``).

    Every method is a classmethod; the class holds no per-instance state.
    Attention is served by the Pallas V1 backend and compilation goes through
    the ``openxla`` backend.
    """

    # Static platform descriptors. Their exact meaning is defined by the
    # `Platform` base class and its consumers, not by this file.
    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

    @classmethod
    def import_core_kernels(cls) -> None:
        """Do nothing; this platform imports no core kernels here."""
        pass

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted path of the attention backend class for TPU.

        Only ``selected_backend``, ``use_v1`` and ``use_sparse`` are
        consulted; the remaining parameters are unused by this
        implementation.

        Args:
            selected_backend: Requested backend. Anything other than
                ``_Backend.PALLAS`` is logged and then ignored.
            use_v1: Must be truthy.
            use_sparse: Must be falsy.

        Returns:
            The import path of ``PallasAttentionBackend``.

        Raises:
            NotImplementedError: If ``use_sparse`` is truthy.
            ValueError: If ``use_v1`` is falsy.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            # Informational only: the Pallas backend is returned regardless.
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip name>"`` for the locally attached chips.

        The chip type comes from ``tpu_info.device.get_local_chips()``;
        ``device_id`` is not used.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the TPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``(torch.finfo(dtype).min, torch.finfo(dtype).max)``.

        These finite extremes stand in for -inf / +inf. ``dtype`` must be a
        dtype accepted by ``torch.finfo``.
        """
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls) -> bool:
        """Return ``False``: in-place updates are reported as unavailable."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocabulary padding size used on TPU (always 1)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context manager for inference."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and mutate it in place for TPU.

        Adjustments made, in order:

        * an unset cache block size is defaulted to 16, then later replaced
          by ``PallasAttentionBackend.get_page_size(vllm_config)``;
        * the compilation level is forced to ``DYNAMO_ONCE``;
        * the cudagraph mode is forced to ``CUDAGraphMode.NONE``;
        * an empty compilation backend becomes ``"openxla"``;
        * a ``float16`` / ``float32`` model dtype is replaced by ``bfloat16``;
        * a worker class of ``"auto"`` becomes the TPU worker;
        * chunked multimodal input is disabled for multimodal models;
        * with MLA enabled, chunked prefill is disabled and
          ``max_num_batched_tokens`` is raised to at least
          ``max_model_len`` / ``DEFAULT_MAX_NUM_BATCHED_TOKENS``.

        Raises:
            AssertionError: If a speculative decoding config is present.
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # Default an unset block size to 16 (the v0 default). It is
        # overwritten further down with the Pallas page size.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info(
                "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph."
            )
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if (
            compilation_config.cudagraph_mode is None
            or compilation_config.cudagraph_mode.max_cudagraph_mode()
            != CUDAGraphMode.NONE
        ):
            logger.info(
                "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        assert vllm_config.speculative_config is None, (
            "TPU does not support speculative decoding"
        )

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (
            torch.float16,
            torch.float32,
        ):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.",
                model_config.dtype,
            )
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend

        # NOTE(review): unlike the guarded default above, this assignment
        # assumes cache_config is not None -- confirm that always holds here.
        cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        # NOTE(review): this repeats the speculative-decoding assertion made
        # earlier in this method; kept as-is to preserve behaviour.
        assert not vllm_config.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend"
        )

        if (
            scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "TPU does not support running Multimodal models"
                " without setting `--disable_chunked_mm_input`. "
                "Forcing --disable_chunked_mm_input."
            )
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return ``False``: pinned memory is unsupported."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return ``True``: all-gather is reported as the collective to use."""
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform

        Only ``params`` is inspected; ``prompt`` and ``processed_inputs`` are
        unused by this implementation.

        Raises:
            ValueError: If ``params`` is a ``SamplingParams`` whose sampling
                type is ``SamplingType.RANDOM_SEED`` (a per-request seed).
        """
        if (
            isinstance(params, SamplingParams)
            and params.sampling_type == SamplingType.RANDOM_SEED
        ):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Return ``True`` for every KV-cache dtype; no check is performed."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of ``src_cache`` into ``dst_cache``.

        ``src_cache[src_block_indices]`` is moved to ``dst_cache``'s device
        and written to ``dst_cache[dst_block_indices]`` in place. The method
        is compiled with the ``openxla`` backend.
        """
        # Mark the destination as a buffer donor for XLA (presumably so its
        # storage can be reused by the update -- confirm in torch_xla docs).
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """tpu blocks to cpu blocks

        ``src_cache[src_block_indices]`` is copied to the CPU and written to
        ``dst_cache[dst_block_indices]`` in place. The method is compiled
        with the ``openxla`` backend.
        """
        # Here the *source* cache is the one marked as a buffer donor.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """Return ``True``: the synchronous weight loader is requested."""
        return True

256
257

# Prefer the platform implementation shipped by the optional `tpu_inference`
# package. When it is importable it replaces the TpuPlatform class defined
# above; otherwise the in-tree class stays in effect.
try:
    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
except ImportError:
    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
else:
    # Only reached when the import succeeded.
    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True