"lib/llm/src/kv_router/indexer/mod.rs" did not exist on "1d34af75ed36aacacf99ea99f83b16f5db0a32ed"
tpu.py 8.74 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
import vllm.envs as envs
10
from vllm.inputs import ProcessorInputs, PromptType
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams, SamplingType
13
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
14
15

from .interface import Platform, PlatformEnum, _Backend
16

17
# The vLLM config and pooling types are only used in annotations here. The
# real imports happen only under static type checking (presumably to avoid
# import cycles at runtime); at runtime the names are bound to None so the
# annotations that mention them still resolve to a defined name.
if TYPE_CHECKING:
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
25

26
27
# Becomes True at the bottom of this module when the optional
# ``tpu_commons`` package is importable and supplies TpuPlatform instead.
USE_TPU_COMMONS = False

# Module-level logger shared by everything below.
logger = init_logger(__name__)

30
31
32

class TpuPlatform(Platform):
    """vLLM platform description for TPUs (dispatch key ``XLA``).

    The class attributes describe the device to the platform layer. The
    classmethods pick the attention backend, patch the engine configuration
    for TPU constraints, validate requests, and copy KV-cache blocks between
    host and device.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    # Quantization methods accepted on this platform.
    supported_quantization: list[str] = [
        "fp8", "tpu_int8", "compressed-tensors"
    ]

    # TPU topology variables; presumably forwarded to worker processes by
    # the platform layer -- confirm against the Platform interface.
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink) -> str:
        """Return the import path of the Pallas attention backend class.

        Only Pallas is available on TPU: any other ``selected_backend`` is
        logged and then ignored. ``use_v1`` picks the V1 or legacy Pallas
        implementation; the remaining arguments are accepted for interface
        compatibility and are not used here.
        """
        if selected_backend not in (_Backend.PALLAS, _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # NOTE(review): ``torch.tpu`` is not part of upstream PyTorch; this
        # presumably relies on a TPU-specific torch extension being loaded,
        # otherwise the attribute lookup raises AttributeError -- confirm.
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type name>"`` for the chips on this host.

        ``device_id`` is ignored: the name comes from the chip type reported
        by ``tpu_info.device.get_local_chips()``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU; always raises NotImplementedError."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output is supported only when the V1 engine is disabled."""
        return not envs.VLLM_USE_V1

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU LoRA (Punica) wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return the finite ``(min, max)`` of ``dtype`` from torch.finfo.

        Finite extremes are returned instead of +/-inf.
        """
        info = torch.finfo(dtype)
        return info.min, info.max

    @classmethod
    def can_update_inplace(cls):
        """In-place updates are reported as unsupported on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """LoRA vocabularies are not padded on TPU (padding multiple 1)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the ``torch.no_grad()`` context used for inference."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for TPU constraints.

        - KV-cache block size defaults to 16, and is overridden with the
          Pallas page size when the V1 engine is enabled.
        - Compilation is forced to DYNAMO_ONCE with CUDA graphs disabled,
          and an empty compile backend defaults to ``"openxla"``.
        - float16/float32 model dtypes are replaced by bfloat16.
        - An ``"auto"`` worker class becomes the TPU worker.
        - Multimodal models get ``disable_chunked_mm_input`` forced on.
        - MLA models get chunked prefill disabled.

        Raises:
            AssertionError: if a speculative decoding config is present.
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
                        "disabling cudagraph.")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        cudagraph_mode = compilation_config.cudagraph_mode
        if (cudagraph_mode is None
                or cudagraph_mode.max_cudagraph_mode() != CUDAGraphMode.NONE):
            logger.info("[TPU] CUDA graph is not supported on TPU, "
                        "disabling cudagraphs.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (torch.float16,
                                                               torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", model_config.dtype)
            model_config.dtype = torch.bfloat16

        # Guarded like the default-block-size branch above so a missing
        # cache config is skipped instead of raising AttributeError.
        if envs.VLLM_USE_V1 and cache_config is not None:
            from vllm.v1.attention.backends.pallas import (
                PallasAttentionBackend)
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            # NOTE(review): the message says prefix caching is disabled as
            # well, but only chunked prefill is turned off here -- confirm.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Without chunking a whole prompt must fit in one batch.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned memory is unsupported on TPU; warns on every call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Always report that all-gather should be used on TPU."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Report V1 engine support for every model config."""
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Only sampling requests are checked:

        Raises:
            ValueError: if guided decoding is requested while the V1 engine
                is disabled, or if the request uses a per-request seed.
        """
        if not isinstance(params, SamplingParams):
            return
        if params.guided_decoding is not None and not envs.VLLM_USE_V1:
            raise ValueError("Structured output is not supported on "
                             f"{cls.device_name} V0.")
        if params.sampling_type == SamplingType.RANDOM_SEED:
            raise ValueError(
                "Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Accept every KV-cache dtype; no TPU-specific check is made."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into ``dst_cache``.

        The gathered blocks are moved to ``dst_cache.device`` and written at
        ``dst_block_indices``, mutating ``dst_cache``.
        """
        # Marks dst_cache as a donated XLA buffer, presumably so the scatter
        # can reuse its storage instead of allocating a copy -- confirm.
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
            dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU cache blocks into the host (CPU) cache.

        ``src_cache[src_block_indices]`` is moved to CPU and written into
        ``dst_cache`` at ``dst_block_indices``.
        """
        # NOTE(review): the donor flag is set on the *source* cache here
        # (the destination in insert_blocks_to_device) -- confirm intent.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

238
239
240
241

# Prefer the platform implementation from the optional ``tpu_commons``
# package when it can be imported; otherwise keep the class defined above.
# Note: an ImportError raised from inside an installed tpu_commons is also
# treated as "not found".
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
    TpuPlatform = TpuCommonsPlatform  # type: ignore
    USE_TPU_COMMONS = True
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")