tpu.py 8.48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
from vllm.inputs import ProcessorInputs, PromptType
10
from vllm.logger import init_logger
11
from vllm.sampling_params import SamplingParams, SamplingType
12
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
13

14
from .interface import Platform, PlatformEnum
15

16
if TYPE_CHECKING:
    # Imported for static type checking only; none of these imports run
    # when the module is executed normally.
    from vllm.attention.backends.registry import _Backend
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    # Runtime placeholders so that the names used in annotations and in
    # `cast(...)` calls still resolve without importing the modules above.
    BlockSize = ModelConfig = VllmConfig = PoolingParams = _Backend = None
26

27
28
# Module-level logger obtained from vLLM's `init_logger` for this module's name.
logger = init_logger(__name__)

29
30
# False by default; the try/except at the bottom of this module sets it to True
# when `tpu_commons.platforms.TpuPlatform` can be imported and replaces the
# TpuPlatform class defined in this file.
USE_TPU_COMMONS = False

31
32
33

class TpuPlatform(Platform):
    """Platform description and hooks for running vLLM on TPU.

    Every hook is a classmethod. Attention always resolves to the Pallas V1
    backend, and `check_and_update_config` rewrites a `VllmConfig` in place
    so that it only contains settings this platform accepts.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

    @classmethod
    def import_core_kernels(cls) -> None:
        """Do nothing; this platform imports no kernels in this hook."""
        pass

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink,
        use_sparse,
    ) -> str:
        """Return the dotted path of the attention backend class to use.

        The Pallas V1 backend is always returned. A `selected_backend` other
        than `_Backend.PALLAS` is only reported through an info log; it does
        not change the result. `head_size`, `dtype`, `kv_cache_dtype`,
        `block_size`, `use_mla` and `has_sink` are not read here.

        Raises:
            NotImplementedError: if `use_sparse` is truthy.
            ValueError: if `use_v1` is falsy.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE: the `device` parameter shadows the module-level
        # `tpu_info.device` import inside this method only.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return "TPU <chip type name>" for the local chips.

        `device_id` is accepted for interface compatibility but is not used;
        the chip type comes from `tpu_info.device.get_local_chips()`.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU; always raises `NotImplementedError`."""
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the TPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return `(torch.finfo(dtype).min, torch.finfo(dtype).max)`.

        These are the most negative and most positive finite values of
        `dtype`, returned as the platform's (negative, positive) "infinity".
        """
        limits = torch.finfo(dtype)
        return limits.min, limits.max

    @classmethod
    def can_update_inplace(cls) -> bool:
        """Return False: in-place updates are reported as unsupported."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocabulary padding size, which is 1 here."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return a `torch.no_grad()` context manager."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust `vllm_config` in place to settings supported on TPU.

        In order, this method:
          * defaults `cache_config.block_size` to 16 when it is unset;
          * forces the DYNAMO_ONCE compilation level and disables CUDA graphs;
          * defaults an empty compilation backend to "openxla";
          * asserts that speculative decoding is not configured;
          * replaces float16/float32 model dtypes with bfloat16;
          * sets `cache_config.block_size` to the Pallas page size;
          * resolves an "auto" worker class to the V1 TPU worker;
          * forces `disable_chunked_mm_input` for multimodal models;
          * disables chunked prefill and raises `max_num_batched_tokens`
            when the model uses MLA.

        Raises:
            AssertionError: if `vllm_config.speculative_config` is not None
                (only when assertions are enabled).
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info(
                "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph."
            )
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if (
            compilation_config.cudagraph_mode is None
            or compilation_config.cudagraph_mode.max_cudagraph_mode()
            != CUDAGraphMode.NONE
        ):
            logger.info(
                "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single speculative-decoding check. A second, weaker assertion used
        # to follow later in this method; it could never fail once this one
        # had passed, so it was dropped.
        assert vllm_config.speculative_config is None, (
            "TPU does not support speculative decoding"
        )

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (
            torch.float16,
            torch.float32,
        ):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.",
                model_config.dtype,
            )
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend

        # `cache_config` is treated as optional by the guard at the top of
        # this method, so keep this write behind the same kind of check
        # instead of dereferencing a possible None.
        if cache_config is not None:
            cache_config.block_size = PallasAttentionBackend.get_page_size(  # type: ignore[assignment]
                vllm_config
            )

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (
            scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "TPU does not support running Multimodal models"
                " without setting `--disable_chunked_mm_input`. "
                "Forcing --disable_chunked_mm_input."
            )
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            # NOTE(review): the message also mentions prefix caching, but
            # nothing in this branch touches a prefix-caching setting -
            # confirm whether that is handled elsewhere.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # With chunked prefill off, the token budget is raised to at
            # least `max_model_len`.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return False: pinned memory is unsupported."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True: this platform reports that all-gather should be used."""
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform

        Only `params` is inspected; `prompt` and `processed_inputs` are not
        read.

        Raises:
            ValueError: if `params` is a `SamplingParams` whose
                `sampling_type` is `SamplingType.RANDOM_SEED`.
        """
        if (
            isinstance(params, SamplingParams)
            and params.sampling_type == SamplingType.RANDOM_SEED
        ):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Return True for every `kv_cache_dtype`; neither argument is read."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy selected blocks of `src_cache` into `dst_cache` on its device.

        The blocks at `src_block_indices` are moved to `dst_cache.device`
        and written to `dst_cache` at `dst_block_indices`. The function is
        compiled with the "openxla" backend.
        """
        # Marks `dst_cache` as a buffer donor for XLA - presumably so the
        # update can reuse its storage; confirm against torch_xla docs.
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """tpu blocks to cpu blocks

        The blocks at `src_block_indices` of `src_cache` are copied to CPU
        and written to `dst_cache` at `dst_block_indices`. The function is
        compiled with the "openxla" backend.
        """
        # Marks `src_cache` as a buffer donor for XLA - see torch_xla docs
        # for the exact semantics.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """Return True: this platform asks for the synchronous weight loader."""
        return True

255
256
257

# Prefer the TpuPlatform implementation shipped by `tpu_commons` when that
# package is importable; otherwise keep the TpuPlatform defined in this module.
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    # Import succeeded: rebind the public name and record the override.
    TpuPlatform = TpuCommonsPlatform  # type: ignore
    USE_TPU_COMMONS = True