# tpu.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
from typing import TYPE_CHECKING, cast
6

7
import torch
8
from tpu_info import device
9

10
from vllm.inputs import ProcessorInputs, PromptType
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams, SamplingType
13
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
14

15
from .interface import Platform, PlatformEnum
16

17
if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import VllmConfig
    from vllm.config.cache import BlockSize
    from vllm.pooling_params import PoolingParams
else:
    # These names are still referenced at runtime: as annotations evaluated
    # when the class body is executed, and in ``cast(BlockSize, ...)``. Bind
    # them to None placeholders instead of importing the real modules
    # (presumably to keep import cost / import cycles down — TODO confirm).
    BlockSize = None
    VllmConfig = None
    PoolingParams = None
    AttentionBackendEnum = None

logger = init_logger(__name__)

# Set to True at the bottom of this module when the ``tpu_inference``
# package is importable and its TpuPlatform replaces the one defined here.
USE_TPU_INFERENCE = False


class TpuPlatform(Platform):
    """vLLM ``Platform`` implementation for TPUs driven through torch_xla.

    Routes attention to the Pallas backend, forces XLA-compatible compilation
    settings, and rejects features this code path does not support
    (speculative decoding, sparse attention, per-request seeds).
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]

    additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]

    @classmethod
    def import_kernels(cls) -> None:
        """Load the optional native kernel extensions usable on TPU."""
        # Do not import vllm._C. The MoE extension is optional: a missing
        # build is tolerated silently.
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend to use on TPU.

        Pallas is the only backend offered here. Any other explicit selection
        is logged and overridden. Only ``selected_backend`` and ``use_sparse``
        are consulted; the remaining arguments are accepted but unused.

        Raises:
            NotImplementedError: if sparse attention is requested.
        """
        from vllm.attention.backends.registry import AttentionBackendEnum

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on TPU.")
        # ``None`` means nothing was explicitly selected, so there is no
        # rejected choice to report ("Cannot use None backend" is misleading).
        if (
            selected_backend is not None
            and selected_backend != AttentionBackendEnum.PALLAS
        ):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        logger.info("Using Pallas V1 backend.")
        return AttentionBackendEnum.PALLAS.get_path()

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE(review): stock PyTorch exposes no ``torch.tpu`` module; confirm
        # it is provided at runtime, otherwise this raises AttributeError.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type name>"`` for the local chips.

        ``device_id`` is ignored: the name comes from the chip type reported
        by ``tpu_info.device.get_local_chips()``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not supported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``dtype``'s finite (min, max) in place of (-inf, +inf)."""
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        """Report that in-place updates are not supported on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocab padding multiple (1, i.e. no padding)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return ``torch.no_grad()`` as the inference context manager."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Rewrite ``vllm_config`` in place to settings the TPU path supports.

        - forces ``DYNAMO_TRACE_ONCE`` compilation, disables CUDA graphs and
          defaults the compile backend to ``"openxla"``;
        - downgrades float16/float32 models to bfloat16;
        - sets the KV-cache block size to the Pallas page size;
        - selects the TPU worker when ``worker_cls`` is ``"auto"``;
        - disables chunked multimodal input and, for MLA models, chunked
          prefill.

        Asserts that speculative decoding is not configured.
        """
        from vllm.config import CompilationMode, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16 (overwritten further down with
        # the Pallas page size).
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_TRACE_ONCE compilation mode
        if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            logger.info(
                "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and "
                "disabling cudagraph."
            )
            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE

        if (
            compilation_config.cudagraph_mode is None
            or compilation_config.cudagraph_mode.max_cudagraph_mode()
            != CUDAGraphMode.NONE
        ):
            logger.info(
                "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Checked once here; nothing below modifies speculative_config.
        assert vllm_config.speculative_config is None, (
            "TPU does not support speculative decoding"
        )

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (
            torch.float16,
            torch.float32,
        ):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.",
                model_config.dtype,
            )
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend

        cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (
            scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "TPU does not support running Multimodal models"
                " without setting `--disable_chunked_mm_input`. "
                "Forcing --disable_chunked_mm_input."
            )
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is disabled here — confirm which is intended.
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Presumably so a full-length prompt fits in a single batch once
            # chunked prefill is off — TODO confirm.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def is_pin_memory_available(cls):
        """Report pinned host memory as unavailable (logs a warning per call)."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: for sampling requests that carry a per-request seed.
        """
        if (
            isinstance(params, SamplingParams)
            and params.sampling_type == SamplingType.RANDOM_SEED
        ):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into ``dst_cache`` in place.

        The source blocks are moved to ``dst_cache.device`` first;
        ``dst_cache`` is marked as an XLA buffer donor before being written.
        """
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU cache blocks into host (CPU) cache blocks.

        ``src_cache[src_block_indices]`` is moved to CPU and written to
        ``dst_cache[dst_block_indices]``; ``src_cache`` is marked as an XLA
        buffer donor beforehand.
        """
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """Request the synchronous weight loader on TPU."""
        return True

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        """
        Check max_model_len for the current platform.

        Warns that the model's default length is in use and returns it
        unchanged.
        """
        logger.warning(
            "--max-model-len is not specified, "
            "it's currently using model's default length %d, "
            "which might be too large. "
            "Please input with --max-model-len based on your "
            "request input length and output length, to avoid "
            "unnecessary degradation.",
            max_model_len,
        )
        return max_model_len


try:
    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
except ImportError:
    # Optional package: fall back to the TpuPlatform defined in this module.
    logger.info("tpu_inference not found, using vLLM's TpuPlatform")
else:
    # Rebind the module-level name so importers of ``TpuPlatform`` from this
    # module receive the tpu_inference implementation instead.
    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True