# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from typing import TYPE_CHECKING, Optional, Union, cast
5

6
import torch
7
from tpu_info import device
8

9
import vllm.envs as envs
10
from vllm.inputs import ProcessorInputs, PromptType
11
from vllm.logger import init_logger
12
from vllm.sampling_params import SamplingParams, SamplingType
13
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
14
15

from .interface import Platform, PlatformEnum, _Backend
16

17
if TYPE_CHECKING:
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    # These names are only needed for annotations and ``cast`` targets in
    # this module. Bind placeholders at runtime so signatures that mention
    # them still evaluate without importing vllm.config at module import
    # time (NOTE(review): presumably to avoid an import cycle — confirm).
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
    """Platform hooks for running vLLM on TPU devices through PyTorch/XLA.

    Supplies the TPU-specific identifiers, attention backend, worker class
    and configuration adjustments that vLLM queries via the ``Platform``
    interface.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]

    # Extra TPU topology environment variables declared for this platform
    # (NOTE(review): presumably forwarded to workers by the base Platform —
    # confirm against the interface).
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the attention backend class to use.

        Only the Pallas backends run on TPU. Any other ``selected_backend``
        is logged and then ignored: the Pallas backend matching the engine
        version (``use_v1``) is always returned. The remaining arguments
        are accepted for interface compatibility and are not used here.
        """
        if selected_backend not in (_Backend.PALLAS,
                                    _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type name>"`` for the local TPU chips.

        ``device_id`` is ignored; the name comes from the chip type reported
        by ``tpu_info.device.get_local_chips()``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Report async output support: only on the V0 engine.

        ``enforce_eager`` is not consulted.
        """
        return not envs.VLLM_USE_V1

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """Return ``(min, max)`` of ``dtype``'s finite range.

        Finite extremes are returned instead of -inf/+inf
        (NOTE(review): presumably because infinities are problematic under
        XLA — confirm).
        """
        return torch.finfo(dtype).min, torch.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        """Report that in-place updates are not supported on TPU."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocabulary padding size used on TPU (1)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return the context manager used to run inference.

        ``torch.no_grad()`` is used rather than ``torch.inference_mode()``
        (NOTE(review): presumably for PyTorch/XLA compatibility — confirm).
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for TPU and adjust it in place.

        - Defaults the KV-cache block size to 16 when unset. On the V1
          engine, replaces it with the Pallas page size, raised to the
          backend's minimum page size if needed.
        - Forces the ``DYNAMO_ONCE`` compilation level and, if no compile
          backend is set, the ``"openxla"`` backend.
        - Switches float16/float32 models to bfloat16.
        - Picks the TPU worker class when ``worker_cls == "auto"``.
        - Forces ``disable_chunked_mm_input`` for multimodal models.
        - Disables chunked prefill for MLA models.

        Raises:
            AssertionError: if a speculative decoding config is present.
            NotImplementedError: if multi-step scheduling is requested on
                the V1 engine.
        """
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config
        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single up-front check; a second, redundant assertion further down
        # was removed.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        # model_config may be absent; guard it here as the MLA check does.
        model_config = vllm_config.model_config
        if (model_config is not None
                and model_config.dtype in (torch.float16, torch.float32)):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", model_config.dtype)
            model_config.dtype = torch.bfloat16

        if envs.VLLM_USE_V1:
            from vllm.v1.attention.backends.pallas import (
                PallasAttentionBackend)
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]
            min_page_size = PallasAttentionBackend.get_min_page_size(
                vllm_config)
            if min_page_size > cache_config.block_size:
                # Trailing space fixes the previous "there'sno" rendering.
                logger.warning(
                    "Increase the page size from %s to %s to make sure "
                    "there's no SMEM OOM",
                    cache_config.block_size,
                    min_page_size,
                )
                cache_config.block_size = min_page_size  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the message says prefix caching is disabled too,
            # but only chunked prefill is turned off below — confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            # Presumably so a full-length prompt fits in one batch once
            # chunked prefill is off.
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls):
        """Return False; logs a warning on every call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True (all-gather is used on this platform)."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Return True for every model; ``model_config`` is not inspected."""
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform.

        Only ``SamplingParams`` requests are checked.

        Raises:
            ValueError: if guided decoding is requested on the V0 engine,
                or if the request uses a per-request random seed.
        """
        if not isinstance(params, SamplingParams):
            return
        if params.guided_decoding is not None and not envs.VLLM_USE_V1:
            raise ValueError("Structured output is not supported on "
                             f"{cls.device_name} V0.")
        if params.sampling_type == SamplingType.RANDOM_SEED:
            raise ValueError(
                "Torch XLA does not support per-request seed.")


# Prefer the out-of-tree TPU platform from ``tpu_commons`` when it can be
# imported; otherwise keep the in-tree TpuPlatform defined above.
try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
else:
    TpuPlatform = TpuCommonsPlatform  # type: ignore