# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
from typing import TYPE_CHECKING, Optional, cast

import torch
from tpu_info import device

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from typing import TypeAlias

    from vllm.config import VllmConfig
    from vllm.config.cache import BlockSize
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    ParamsType: TypeAlias = SamplingParams | PoolingParams
else:
    # Runtime placeholders: within this module these names appear only in
    # annotations and ``cast(...)`` calls, so ``None`` is sufficient when the
    # real types are not being imported for static type checking.
    BlockSize = None
    VllmConfig = None
    PoolingParams = None
    ParamsType = None

logger = init_logger(__name__)

# Flipped to True later in this module when the optional ``tpu_inference``
# package is importable and its platform class replaces the one defined here.
USE_TPU_INFERENCE = False

class TpuPlatform(Platform):
    """vLLM platform definition for TPUs driven through PyTorch/XLA.

    Declares the TPU device identifiers and supported quantization methods,
    selects the Pallas attention backend, and rewrites the user's
    ``VllmConfig`` into a form the TPU code paths accept
    (see ``check_and_update_config``).
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    # TPU topology variables; presumably propagated to workers by the base
    # Platform machinery — confirm against interface.Platform.
    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

    @classmethod
    def import_kernels(cls) -> None:
        """Import the optional custom-op extensions used on TPU.

        ``vllm._C`` is intentionally skipped; ``vllm._moe_C`` is imported if
        present and a missing extension is silently tolerated.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on TPU.

        Pallas is the only attention backend here: any other requested backend
        is ignored (with an info log) and the Pallas path is returned.

        Raises:
            NotImplementedError: if sparse attention is requested.
        """
        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        # Only report a rejected backend when one was actually requested;
        # callers may pass None when no backend was chosen — avoids logging
        # "Cannot use None backend on TPU."
        if (
            selected_backend is not None
            and selected_backend != AttentionBackendEnum.PALLAS
        ):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        logger.info("Using Pallas V1 backend.")
        return AttentionBackendEnum.PALLAS.get_path()

    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the attention backends allowed for ViT (vision) attention."""
        return [
            AttentionBackendEnum.PALLAS,
        ]

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: Optional["AttentionBackendEnum"] = None,
    ) -> "AttentionBackendEnum":
        """Pick the ViT attention backend.

        An explicitly requested ``backend`` is returned if it is supported;
        otherwise Pallas is used. ``head_size`` and ``dtype`` do not affect
        the choice on TPU.

        Raises:
            AssertionError: if ``backend`` is given but not supported.
        """
        if backend is not None:
            assert backend in cls.get_supported_vit_attn_backends(), (
                f"Backend {backend} is not supported for vit attention. "
                f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
            )
            logger.info_once(f"Using backend {backend} for vit attention.")
            return backend

        logger.info_once(
            f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
        )
        return AttentionBackendEnum.PALLAS

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE(review): stock PyTorch has no ``torch.tpu`` namespace; this
        # presumably relies on a TPU plugin providing it — confirm.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type>"`` for the chips visible on this host.

        ``device_id`` is ignored: the name comes from the local chip type
        reported by ``tpu_info``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU; always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU LoRA (Punica) wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return the finite (min, max) of ``dtype`` used in place of -inf/+inf."""
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        """TPU tensors are reported as not updatable in place."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """LoRA vocab padding granularity on TPU (no extra padding)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the grad-disabling context used for inference on TPU.

        Uses ``torch.no_grad()`` rather than ``torch.inference_mode()``.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Rewrite ``vllm_config`` in place so it is runnable on TPU.

        Adjustments:
          * a missing KV-cache block size defaults to 16, and is then replaced
            by the Pallas backend's page size;
          * compilation mode is forced to ``DYNAMO_TRACE_ONCE``;
          * CUDA graphs are disabled (``CUDAGraphMode.NONE``);
          * an empty compilation backend becomes ``"openxla"``;
          * float16/float32 model dtypes are replaced by bfloat16;
          * an ``"auto"`` worker class becomes ``TPUWorker``;
          * chunked multimodal input is disabled for multimodal models;
          * for MLA models chunked prefill is disabled and
            ``max_num_batched_tokens`` is set to
            ``max(max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)``.

        Raises:
            AssertionError: if speculative decoding is configured.
        """
        from vllm.config import CompilationMode, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_TRACE_ONCE compilation mode
        if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
            # Implicit concatenation instead of a backslash continuation, which
            # embedded the next line's indentation inside the message.
            logger.info(
                "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and "
                "disabling cudagraph."
            )
            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE

        if (
            compilation_config.cudagraph_mode is None
            or compilation_config.cudagraph_mode.max_cudagraph_mode()
            != CUDAGraphMode.NONE
        ):
            logger.info(
                "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single check for speculative decoding; a second, redundant assert
        # further down was removed.
        assert vllm_config.speculative_config is None, (
            "TPU does not support speculative decoding"
        )

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (
            torch.float16,
            torch.float32,
        ):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.",
                model_config.dtype,
            )
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend

        # The Pallas kernels dictate the KV-cache page size. Guarded the same
        # way as the default above so a missing cache_config is not
        # dereferenced here.
        if cache_config is not None:
            cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (
            scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "TPU does not support running Multimodal models"
                " without setting `--disable_chunked_mm_input`. "
                "Forcing --disable_chunked_mm_input."
            )
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            # NOTE(review): the message above says prefix caching is disabled,
            # but nothing here changes a prefix-caching setting — confirm
            # whether the message or the code is intended.
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is unavailable on TPU; logs a warning each call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: ParamsType,
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: if sampling params request a per-request random seed.
        """
        from vllm.sampling_params import SamplingParams, SamplingType

        if (
            isinstance(params, SamplingParams)
            and params.sampling_type == SamplingType.RANDOM_SEED
        ):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into ``dst_cache[dst_block_indices]``.

        The selected source blocks are moved to ``dst_cache``'s device first.
        Compiled with the ``openxla`` backend.
        """
        # Marks dst_cache as an XLA buffer donor; presumably lets XLA reuse its
        # buffer for the in-place indexed update.
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU cache blocks ``src_cache[src_block_indices]`` into the CPU
        tensor ``dst_cache[dst_block_indices]``."""
        # NOTE(review): this donates src_cache (the tensor being read), whereas
        # insert_blocks_to_device donates its destination — confirm intended.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """TPU loads weights with the synchronous loader."""
        return True

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.

        Logs a warning that the model's default length is being used and
        returns ``max_model_len`` unchanged.
        """
        logger.warning(
            "--max-model-len is not specified, "
            "it's currently using model's default length %d, "
            "which might be too large. "
            "Please input with --max-model-len based on your "
            "request input length and output length, to avoid "
            "unnecessary degradation.",
            max_model_len,
        )
        return max_model_len
try:
    from tpu_inference.platforms import (
        TpuPlatform as TpuInferencePlatform,
    )

    # Prefer the implementation shipped by the optional ``tpu_inference``
    # package: it replaces the TpuPlatform class defined in this module.
    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True
except ImportError:
    logger.info("tpu_inference not found, using vLLM's TpuPlatform")