tpu.py 7.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
4

5
import torch
6
from tpu_info import device
7

8
import vllm.envs as envs
9
from vllm.inputs import ProcessorInputs, PromptType
10
from vllm.logger import init_logger
11
from vllm.sampling_params import SamplingParams, SamplingType
12
13

from .interface import Platform, PlatformEnum, _Backend
14

15
if TYPE_CHECKING:
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
else:
    # At runtime these names are only used in annotations and in
    # ``typing.cast``, so a placeholder is enough; the real imports are
    # deferred to type checkers (presumably to avoid an import cycle with
    # vllm.config — confirm before changing).
    BlockSize = ModelConfig = VllmConfig = PoolingParams = None

# Module-level logger for this platform.
logger = init_logger(__name__)

26
27
28

class TpuPlatform(Platform):
    """Platform description and configuration hooks for TPU devices.

    The class attributes tell vLLM how to address TPUs (device name/type,
    the ``XLA`` dispatch key, the Ray resource key, the visibility
    environment variable and the default ``openxla`` compile backend) and
    which quantization methods are accepted on this platform.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    # Environment variable that (presumably) restricts which TPU chips are
    # visible to a process — confirm against the Platform base class.
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]

    # Extra TPU topology variables the platform declares; how they are
    # consumed is decided by the Platform base class / callers.
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the attention backend to use on TPU.

        Only the Pallas backends exist on TPU. A different
        ``selected_backend`` is reported in the log and otherwise ignored.
        ``head_size``, ``dtype``, ``kv_cache_dtype``, ``block_size`` and
        ``use_mla`` do not affect the result; only ``use_v1`` does.
        """
        if (selected_backend != _Backend.PALLAS
                and selected_backend != _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"TPU <chip type name>"`` for the local TPU chips.

        ``device_id`` is not used: the chip type is the first element
        returned by ``tpu_info.device.get_local_chips()``.
        """
        chip_type, _ = device.get_local_chips()
        return f"TPU {chip_type.name}"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output is reported as supported only when V1 is disabled.

        ``enforce_eager`` does not affect the answer.
        """
        return not envs.VLLM_USE_V1

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the TPU Punica (LoRA) wrapper."""
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
        """Return the finite ``(min, max)`` of ``dtype`` as "infinities".

        The values come from ``torch.finfo``, so ``dtype`` must be a
        floating-point dtype.
        """
        info = torch.finfo(dtype)
        return info.min, info.max

    @classmethod
    def can_update_inplace(cls):
        """Return False; this platform opts out of in-place updates."""
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        """Return the LoRA vocabulary padding size used on TPU (1)."""
        return 1

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context for inference on TPU."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for TPU and adjust it in place.

        In order, this:
          * sets a default KV-cache block size of 16 when none is given;
          * forces the ``DYNAMO_ONCE`` compilation level and defaults the
            compile backend to ``"openxla"``;
          * rejects speculative decoding;
          * downcasts float16/float32 models to bfloat16;
          * on V1, picks the Pallas page size (raised to the backend's
            minimum if needed) as the block size;
          * resolves an ``"auto"`` worker class;
          * forces ``disable_chunked_mm_input`` for multimodal models.

        Raises:
            AssertionError: if a speculative config is present.
            NotImplementedError: if multi-step scheduling is requested on V1
                while ``worker_cls`` is ``"auto"``.
        """
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config
        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single up-front rejection of any speculative config.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", vllm_config.model_config.dtype)
            vllm_config.model_config.dtype = torch.bfloat16

        # On V1 the block size is dictated by the Pallas backend; guard the
        # cache config the same way the default-block-size step above does.
        if envs.VLLM_USE_V1 and cache_config is not None:
            from vllm.v1.attention.backends.pallas import (
                PallasAttentionBackend)
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]
            min_page_size = PallasAttentionBackend.get_min_page_size(
                vllm_config)
            if min_page_size > cache_config.block_size:
                logger.warning(
                    "Increase the page size from %s to %s to make sure "
                    "there's no SMEM OOM",
                    cache_config.block_size,
                    min_page_size,
                )
                cache_config.block_size = min_page_size  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.tpu_worker.TPUWorker"

        if (scheduler_config.is_multimodal_model
                and not scheduler_config.disable_chunked_mm_input):
            logger.warning("TPU does not support running Multimodal models"
                           " without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def is_pin_memory_available(cls):
        """Return False, logging a warning on every call."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the TPU device communicator."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True; this platform requests the all-gather path."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Report V1 support for every model (``model_config`` is unused)."""
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raise if this request is unsupported on this platform.

        Only ``SamplingParams`` are inspected; ``prompt`` and
        ``processed_inputs`` are not used.

        Raises:
            ValueError: if guided decoding is requested while V1 is off, or
                if the sampling type is ``RANDOM_SEED`` (per-request seeds).
        """
        if isinstance(params, SamplingParams):
            if params.guided_decoding is not None and not envs.VLLM_USE_V1:
                raise ValueError("Structured output is not supported on "
                                 f"{cls.device_name} V0.")
            if params.sampling_type == SamplingType.RANDOM_SEED:
                raise ValueError(
                    "Torch XLA does not support per-request seed.")