# tpu.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
from typing import TYPE_CHECKING, cast
6

7
import torch
8
from tpu_info import device
9

10
from vllm.inputs import ProcessorInputs, PromptType
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams, SamplingType
13
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
14

15
from .interface import Platform, PlatformEnum
16

17
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
    from vllm.config.cache import BlockSize
    from vllm.pooling_params import PoolingParams
else:
    # Runtime placeholders: these names are only needed for annotations and
    # for ``cast(BlockSize, ...)``, so bind them to None instead of importing
    # the real modules at import time.
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
    _Backend = None

logger = init_logger(__name__)

# Set to True at the end of this module when the external ``tpu_inference``
# package is importable and its platform replaces ``TpuPlatform``.
USE_TPU_INFERENCE = False


class TpuPlatform(Platform):
    """vLLM platform definition for Google TPUs on the PyTorch/XLA path.

    Declares TPU device metadata, selects the Pallas attention backend and the
    TPU worker, and rewrites incoming configuration into settings the TPU path
    supports. When the external ``tpu_inference`` package is importable, this
    module rebinds ``TpuPlatform`` to that package's platform instead.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    # Quantization methods this platform declares as supported.
    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    # Additional environment variables declared by this platform; how they are
    # consumed is defined by the base ``Platform`` (not visible here).
    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

    @classmethod
    def import_kernels(cls) -> None:
        """Import optional compiled kernel modules.

        ``vllm._C`` is deliberately not imported on TPU; ``vllm._moe_C`` is
        imported only if it is available.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Return the import path of the attention backend class for TPU.

        Only the Pallas V1 backend is available. Any other requested backend
        is logged and then ignored.

        Raises:
            NotImplementedError: if sparse attention is requested.
            ValueError: if ``use_v1`` is False.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE(review): stock PyTorch builds may not expose a ``torch.tpu``
        # namespace, in which case this raises AttributeError. Confirm which
        # torch build provides it, or whether this path is ever reached on TPU.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a display name built from the local TPU chip type.

        ``device_id`` is accepted for interface compatibility but not used.
        The name comes from ``tpu_info.device.get_local_chips()``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU; always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``(min, max)`` finite values of ``dtype``.

        These stand in for -inf/+inf on this platform.
        """
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        """Always False on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return 1, i.e. no extra LoRA vocab padding multiple on TPU."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference.

        This is ``torch.no_grad()``, not ``torch.inference_mode()``.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Adjust ``vllm_config`` in place to settings the TPU path supports.

        Specifically, this method:
        - forces the DYNAMO_ONCE compilation level and the ``openxla``
          backend (when none is set), and disables CUDA graphs;
        - downcasts float16/float32 models to bfloat16;
        - sets the KV-cache block size to the Pallas page size;
        - selects the TPU worker;
        - turns off chunked multimodal input, and chunked prefill when MLA
          is enabled.

        Raises:
            AssertionError: if speculative decoding is configured.
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info(
                "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph."
            )
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if (
            compilation_config.cudagraph_mode is None
            or compilation_config.cudagraph_mode.max_cudagraph_mode()
            != CUDAGraphMode.NONE
        ):
            logger.info(
                "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single speculative-decoding check for the whole method.
        assert vllm_config.speculative_config is None, (
            "TPU does not support speculative decoding"
        )

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (
            torch.float16,
            torch.float32,
        ):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.",
                model_config.dtype,
            )
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend

        # Guard on cache_config the same way the default-block-size code above
        # does, instead of dereferencing it unconditionally.
        if cache_config:
            cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (
            scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "TPU does not support running Multimodal models"
                " without setting `--disable_chunked_mm_input`. "
                "Forcing --disable_chunked_mm_input."
            )
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            # NOTE(review): the message says prefix caching is disabled, but
            # this branch only changes scheduler settings. Confirm whether
            # cache_config.enable_prefix_caching should also be cleared here.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill, a whole prompt must fit in one batch.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def is_pin_memory_available(cls):
        """Always False on TPU; logs a warning on every call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Always True on TPU."""
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: if ``params`` is a SamplingParams with a per-request
                seed (``SamplingType.RANDOM_SEED``).
        """
        if (
            isinstance(params, SamplingParams)
            and params.sampling_type == SamplingType.RANDOM_SEED
        ):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """Accept every KV-cache dtype; always returns True."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into ``dst_cache[dst_block_indices]``.

        The selected source blocks are moved to ``dst_cache.device`` first.
        ``dst_cache`` is modified in place.
        """
        # Marks dst_cache as an XLA buffer donor; presumably this lets XLA
        # reuse its buffer for the in-place update (verify with torch_xla).
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU cache blocks to CPU cache blocks.

        ``src_cache[src_block_indices]`` is moved to CPU and written into
        ``dst_cache[dst_block_indices]`` in place.
        """
        # Marks src_cache as an XLA buffer donor (see torch_xla docs).
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """Always True on TPU."""
        return True

try:
    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform

    # When the external tpu_inference package is installed, its platform
    # replaces the TpuPlatform class defined above.
    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True
except ImportError:
    # Package not installed: keep vLLM's own TpuPlatform.
    logger.info("tpu_inference not found, using vLLM's TpuPlatform")