tpu.py 4.09 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import TYPE_CHECKING, Optional
4

5
6
import torch

7
import vllm.envs as envs
8
9
10
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend
11

12
13
14
15
# ``VllmConfig`` is only imported for static type checking. At runtime the
# name is bound to None so that annotations evaluated at function-definition
# time (e.g. ``check_and_update_config(cls, vllm_config: VllmConfig)``) still
# resolve. NOTE(review): presumably this avoids an import cycle with
# ``vllm.config`` — confirm.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
    """Platform hooks for running vLLM on Google TPUs (XLA).

    Attention always uses a Pallas backend, and ``check_and_update_config``
    rewrites an incoming ``VllmConfig`` in place into a TPU-compatible form.
    """

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    # Name of the environment variable used for device-visibility control.
    # NOTE(review): presumably consumed by the ``Platform`` base class — confirm.
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"

    # Quantization methods accepted on this platform.
    supported_quantization: list[str] = [
        "tpu_int8", "compressed-tensors", "compressed_tensors"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted import path of the attention backend class.

        Only the Pallas backends are available on TPU: any other explicitly
        requested backend is reported and then ignored. ``use_v1`` selects
        between the V1 and the legacy Pallas implementation. The remaining
        arguments are part of the ``Platform`` interface and are unused here.
        """
        # ``selected_backend`` may be None when the caller did not force a
        # backend (NOTE(review): confirm against the attention selector);
        # only a real, non-Pallas request is worth reporting.
        if selected_backend is not None and selected_backend not in (
                _Backend.PALLAS, _Backend.PALLAS_VLLM_V1):
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; always ``"tpu"`` regardless of id."""
        return "tpu"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented on TPU.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "get_device_total_memory is not implemented for TPU.")

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is supported only with the V0 engine."""
        return not envs.VLLM_USE_V1

    @classmethod
    def inference_mode(cls):
        """Return the context manager used to wrap inference.

        ``torch.no_grad()`` is used rather than ``torch.inference_mode()``.
        NOTE(review): presumably because inference mode is not supported by
        the XLA backend — confirm.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place so it can run on TPU.

        - default the KV-cache block size to 16 when unset;
        - force the DYNAMO_ONCE compilation level and default the compile
          backend to ``"openxla"``;
        - downcast float16/float32 model dtypes to bfloat16;
        - resolve an ``"auto"`` worker class to a TPU worker;
        - disable prefix caching under the V1 engine.

        Raises:
            AssertionError: if a speculative decoding config is present.
        """
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # Single speculative-decoding check (the original asserted this twice).
        assert vllm_config.speculative_config is None, \
            "TPU does not support speculative decoding"

        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", vllm_config.model_config.dtype)
            vllm_config.model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.tpu_worker.TPUWorker"
            elif scheduler_config.is_multi_step:
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.tpu_worker.TPUWorker"

        # Adjust scheduler config for V1
        # TODO: Add support for these
        # Reuse the None-guarded ``cache_config`` from above instead of
        # dereferencing ``vllm_config.cache_config`` unconditionally.
        if (envs.VLLM_USE_V1 and cache_config
                and cache_config.enable_prefix_caching):
            logger.warning("[V1][TPU] Disable prefix caching")
            cache_config.enable_prefix_caching = False

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is not supported on TPU; always False."""
        logger.warning("Pin memory is not supported on TPU.")
        return False