# SPDX-License-Identifier: Apache-2.0

3
from typing import TYPE_CHECKING, Optional
4

5
6
import torch

7
import vllm.envs as envs
8
9
10
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend

# `ModelConfig`/`VllmConfig` are only needed for annotations; importing them
# at runtime is avoided (presumably to prevent an import cycle with
# vllm.config — confirm), so they are bound to None outside type checking.
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
else:
    ModelConfig = None
    VllmConfig = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
    """vLLM platform hooks for TPU devices driven through PyTorch/XLA.

    Attention is always served by a Pallas backend. Models are compiled once
    with Dynamo, using the ``openxla`` backend by default. float16/float32
    model dtypes are coerced to bfloat16 (see ``check_and_update_config``).
    """
    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    # PyTorch dispatch key for tensors placed on the TPU.
    dispatch_key: str = "XLA"
    # Presumably the Ray resource name used to request TPUs — confirm.
    ray_device_key: str = "TPU"
    # Presumably the env var that limits which chips a process sees — confirm.
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"

    # Quantization methods accepted on TPU. Both spellings of
    # compressed-tensors are listed so either user-facing name is accepted.
    supported_quantization: list[str] = [
        "tpu_int8", "compressed-tensors", "compressed_tensors"
    ]

    # Extra TPU topology variables; NOTE(review): presumably propagated to
    # worker processes by the base Platform — confirm.
    additional_env_vars: list[str] = [
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: Optional[_Backend],
                             head_size: int, dtype: torch.dtype,
                             kv_cache_dtype: Optional[str], block_size: int,
                             use_v1: bool, use_mla: bool) -> str:
        """Return the dotted import path of the attention backend class.

        Only the Pallas backends are available on TPU. An explicitly
        requested non-Pallas backend is reported at INFO level and ignored.
        ``selected_backend`` may be ``None`` (no backend forced), in which
        case nothing is reported. ``head_size``, ``dtype``,
        ``kv_cache_dtype``, ``block_size`` and ``use_mla`` are accepted for
        interface compatibility and do not affect the choice.
        """
        # Skip the warning when no backend was requested; otherwise every
        # run would log "Cannot use None backend on TPU."
        if (selected_backend is not None and selected_backend
                not in (_Backend.PALLAS, _Backend.PALLAS_VLLM_V1)):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"tpu"``; ``device_id`` is ignored."""
        return "tpu"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented for TPU.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is supported only on the V0 engine.

        ``enforce_eager`` does not affect the result.
        """
        return not envs.VLLM_USE_V1

    @classmethod
    def inference_mode(cls):
        """Return the context manager used to disable autograd on TPU.

        ``torch.no_grad()`` is returned rather than ``torch.inference_mode()``.
        NOTE(review): presumably because inference tensors are not supported
        under XLA — confirm.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for TPU and fill in TPU-specific defaults.

        Mutates ``vllm_config`` in place:
          * the cache block size defaults to 16 when unset;
          * the compilation level is forced to ``DYNAMO_ONCE``, and an empty
            compilation backend defaults to ``"openxla"``;
          * float16/float32 model dtypes are replaced with bfloat16;
          * ``parallel_config.worker_cls == "auto"`` is resolved to a concrete
            TPU worker class for the V0 or V1 engine.

        Raises:
            AssertionError: if speculative decoding is configured.
            NotImplementedError: if multi-step scheduling is requested while
                running on vLLM V1.
        """
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # NOTE(review): `assert` is stripped under `python -O`; it is kept as
        # an assert so callers continue to observe AssertionError.
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", vllm_config.model_config.dtype)
            vllm_config.model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.tpu_worker.TPUWorker"

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is unavailable on TPU; warns and returns False."""
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the TPU device communicator class."""
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        """Always True on TPU.

        NOTE(review): presumably selects all-gather over gather for
        collecting tensor-parallel outputs — confirm against the base
        Platform.
        """
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Report V1 engine support; true for every model on TPU."""
        # V1 support on TPU is experimental
        return True