tpu.py 8.34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
from vllm.inputs import ProcessorInputs, PromptType
10
from vllm.logger import init_logger
11
from vllm.sampling_params import SamplingParams, SamplingType
12
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
13
14

from .interface import Platform, PlatformEnum, _Backend
15

16
if TYPE_CHECKING:
17
    from vllm.config import BlockSize, ModelConfig, VllmConfig
18
    from vllm.pooling_params import PoolingParams
19
else:
20
    BlockSize = None
21
    ModelConfig = None
22
    VllmConfig = None
23
    PoolingParams = None
24

25
26
logger = init_logger(__name__)

# Becomes True at the end of this module when the optional ``tpu_commons``
# package can be imported and its TpuPlatform replaces the one defined here.
USE_TPU_COMMONS: bool = False

29
30
31

class TpuPlatform(Platform):
    """vLLM platform implementation for TPUs, running through PyTorch/XLA.

    The class attributes tell the rest of vLLM how to address the device
    (names, XLA dispatch key, distributed backend, visibility env var). The
    classmethods override generic ``Platform`` hooks with TPU behavior.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    dist_backend: str = "gloo"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    # Quantization methods accepted on this platform.
    supported_quantization: list[str] = [
        "fp8", "tpu_int8", "compressed-tensors"
    ]

    # Extra TPU topology environment variables; presumably forwarded to
    # worker processes by the base Platform machinery — TODO confirm.
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink) -> str:
        """Return the import path of the attention backend to use on TPU.

        Only the Pallas V1 backend is available. A request for any other
        backend is logged and then ignored: Pallas is returned regardless.

        Raises:
            ValueError: If ``use_v1`` is False (TPU only supports V1).
        """
        if (selected_backend != _Backend.PALLAS
                and selected_backend != _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if not use_v1:
            raise ValueError("TPU backend only supports V1.")
        logger.info("Using Pallas V1 backend.")
        return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Set the device for the current platform."""
        torch.tpu.set_device(device)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type>"`` for the chips on the local host.

        ``device_id`` is not used; the name comes from the chip type that
        ``tpu_info.device.get_local_chips()`` reports.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented on TPU.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return False: async output is not supported on TPU."""
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return the finite ``(min, max)`` of ``dtype``.

        These finite extremes stand in for ``(-inf, inf)`` on this platform.
        """
        info = torch.finfo(dtype)
        return info.min, info.max

    @classmethod
    def can_update_inplace(cls):
        """Return False: in-place updates are not supported on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return 1 as the LoRA vocabulary padding size."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the inference context manager, ``torch.no_grad()``."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place to satisfy TPU constraints.

        Changes applied:
          * default ``cache_config.block_size`` to 16 when unset, then set it
            to the Pallas page size;
          * force the ``DYNAMO_ONCE`` compilation level and disable CUDA
            graphs; default the compile backend to ``"openxla"``;
          * replace float16/float32 model dtypes with bfloat16;
          * select the V1 TPU worker when ``worker_cls`` is ``"auto"``;
          * force ``disable_chunked_mm_input`` for multimodal models;
          * for MLA models, disable chunked prefill and raise
            ``max_num_batched_tokens`` to at least ``max_model_len``.

        Raises:
            AssertionError: If a speculative decoding config is present.
        """
        from vllm.config import CompilationLevel, CUDAGraphMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config is not None and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config
        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
                        "disabling cudagraph.")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        cudagraph_mode = compilation_config.cudagraph_mode
        if (cudagraph_mode is None
                or cudagraph_mode.max_cudagraph_mode() != CUDAGraphMode.NONE):
            logger.info("[TPU] CUDA graph is not supported on TPU, "
                        "disabling cudagraphs.")
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Checked once here; the former second, weaker assertion further
        # down could never fire after this one passed.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        model_config = vllm_config.model_config
        if model_config is not None and model_config.dtype in (torch.float16,
                                                               torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", model_config.dtype)
            model_config.dtype = torch.bfloat16

        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
        # cache_config may be None (see the guard above); only override the
        # block size when there is a cache config to update.
        if cache_config is not None:
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if model_config and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is disabled below — confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        """Return False, logging a warning on every call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Return True for every model."""
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Raises:
            ValueError: For sampling requests with a per-request seed
                (``SamplingType.RANDOM_SEED``).
        """
        if (isinstance(params, SamplingParams)
                and params.sampling_type == SamplingType.RANDOM_SEED):
            raise ValueError("Torch XLA does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        """Return True for every KV-cache dtype."""
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy ``src_cache[src_block_indices]`` into
        ``dst_cache[dst_block_indices]``, moving data to ``dst_cache``'s
        device. Compiled with the ``openxla`` torch.compile backend.
        """
        # Marks dst_cache as an XLA buffer donor; presumably lets XLA
        # reuse its storage for the in-place update — TODO confirm.
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
            dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU blocks ``src_cache[src_block_indices]`` to the CPU blocks
        ``dst_cache[dst_block_indices]``. Compiled with ``openxla``.
        """
        # NOTE(review): the *source* cache is marked as buffer donor here,
        # unlike insert_blocks_to_device — confirm this is intended.
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

229
230
231
232

try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    # Optional package is unavailable: keep vLLM's own TpuPlatform class.
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    # Only the import can fail, so the override lives outside the try body.
    TpuPlatform = TpuCommonsPlatform  # type: ignore
    USE_TPU_COMMONS = True