# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Optional, Union, cast

import torch
from tpu_info import device

from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    # Runtime placeholders: some annotations in this module (for example
    # ``vllm_config: VllmConfig``) are unquoted and therefore evaluated when
    # the methods are defined, so these names must exist at runtime even
    # though the real types are only imported for type checking.
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
    _Backend = None

logger = init_logger(__name__)

# Flipped to True at the end of this module when ``tpu_commons`` is
# importable and its ``TpuPlatform`` replaces the one defined here.
USE_TPU_COMMONS = False

class TpuPlatform(Platform):
    """vLLM platform description for Google TPUs (PyTorch/XLA dispatch).

    Declares the TPU device identifiers, pins attention to the Pallas V1
    backend, and rewrites the engine configuration in
    :meth:`check_and_update_config` into a form the TPU worker supports.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = [
        "fp8", "tpu_int8", "compressed-tensors"
    ]

    # Additional environment variables declared by this platform (TPU
    # chip/host bounds).
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def import_core_kernels(cls) -> None:
        # Intentionally a no-op: this platform imports no core kernels.
        pass

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink, use_sparse) -> str:
        """Return the fully qualified name of the attention backend class.

        Only the Pallas V1 backend is available on TPU. Any other
        ``selected_backend`` is logged and then ignored, not rejected.

        Raises:
            NotImplementedError: if sparse attention is requested.
            ValueError: if ``use_v1`` is False.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE(review): stock PyTorch does not ship a ``torch.tpu`` module,
        # so this presumably raises AttributeError unless something else
        # installs it. Confirm whether this path is reached on TPU.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type name>"`` for the chips on this host.

        ``device_id`` is accepted for interface compatibility but is not
        used. The name comes from ``tpu_info``'s local chip type.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented on TPU; always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the qualified name of the TPU LoRA (Punica) wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``(min, max)`` finite values of ``dtype``, used as +/-inf."""
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        """In-place updates are reported as unsupported on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """LoRA vocabulary size is padded to a multiple of 1 (no padding)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context (not ``inference_mode``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Rewrite ``vllm_config`` in place into a TPU-compatible form.

        The method makes these changes:

        * Forces the DYNAMO_ONCE compilation level and disables CUDA graphs.
        * Defaults the compile backend to ``openxla``.
        * Downcasts float16/float32 models to bfloat16.
        * Sets the KV-cache block size to the Pallas page size.
        * Selects the TPU worker.
        * Forces ``disable_chunked_mm_input`` for multimodal models.
        * Disables chunked prefill for MLA models.

        Raises:
            AssertionError: if a speculative decoding config is present.
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # Provisional default only: the block size is overwritten below
        # with the Pallas page size.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config
        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
                        "disabling cudagraph.")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if (compilation_config.cudagraph_mode is None
                or compilation_config.cudagraph_mode.max_cudagraph_mode()
                != CUDAGraphMode.NONE):
            logger.info("[TPU] CUDA graph is not supported on TPU, "
                        "disabling cudagraphs.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single check; a second, weaker (truthiness-based) assertion on the
        # same field was redundant with this one.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (torch.float16,
                                                               torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", model_config.dtype)
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
        cache_config.block_size = PallasAttentionBackend.get_page_size(
            vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            # NOTE(review): only chunked prefill is turned off here; prefix
            # caching is left as configured. The log message previously
            # claimed otherwise. Confirm whether prefix caching should also
            # be disabled for MLA on TPU.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill to be disabled.")
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Presumably so that a full-length prompt can still be scheduled
            # in one step now that it cannot be chunked.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is unsupported; warns and returns False."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the qualified name of the TPU device communicator."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Report that this platform uses all-gather (always True)."""
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: for sampling requests that carry a per-request seed
                (``SamplingType.RANDOM_SEED``).
        """
        if (isinstance(params, SamplingParams)
                and params.sampling_type == SamplingType.RANDOM_SEED):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Accept every KV-cache dtype; no check is done on TPU."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into ``dst_cache`` on device.

        The selected source blocks are moved to ``dst_cache.device`` and
        written into ``dst_cache`` at ``dst_block_indices``.
        """
        # Mark the destination as an XLA buffer donor, presumably so the
        # indexed write can reuse its storage. Verify against torch_xla docs.
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
            dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """ tpu blocks to cpu blocks"""
        # NOTE(review): this donates the *source* buffer, which is only
        # read here. Confirm that src_cache remains valid afterwards.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """Report that weights are loaded synchronously on TPU."""
        return True


# Prefer the out-of-tree ``tpu_commons`` implementation when it is installed;
# otherwise keep the TpuPlatform defined in this module.
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    # Only the import can raise ImportError, so the rebinding lives in
    # ``else`` to keep the ``try`` body minimal.
    TpuPlatform = TpuCommonsPlatform  # type: ignore
    USE_TPU_COMMONS = True