# tpu_executor.py
from typing import Any, Dict, List, Optional, Set, Tuple
2
3
4
5
6
7

import torch

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
8
9
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
10
11
12
13
14
15
16
17
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

logger = init_logger(__name__)


class TPUExecutor(ExecutorBase):
    """Executor that drives a single TPU driver worker in-process (no Ray)."""

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Validate the configuration and bring up the driver worker.

        Chunked prefill and speculative decoding are rejected. A float16 or
        float32 model dtype is replaced by bfloat16 (with a warning). Then a
        single driver worker is created, its device is initialized, and the
        model is loaded onto it.
        """
        # NOTE: these are plain ``assert`` checks, so Python skips them when
        # run with ``-O``.
        assert not self.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not self.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")
        if self.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", self.model_config.dtype)
            # Mutates the shared model config so later consumers see bf16.
            self.model_config.dtype = torch.bfloat16

        # Instantiate the worker and load the model to the device.
        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return worker init args for a given rank.

        Args:
            local_rank: Rank of the worker on this host.
            rank: Global rank of the worker; rank 0 is the driver worker.
            distributed_init_method: Init-method string for distributed
                setup. When ``None``, one is built from this host's IP and
                a freshly chosen open port.

        Returns:
            Keyword arguments for the worker constructor.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=rank == 0,
        )

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ):
        """Construct the worker for ``rank``.

        A ``MultiStepTPUWorker`` is used when the scheduler is configured
        for multi-step scheduling, otherwise a plain ``TPUWorker``. Both
        receive the same kwargs from ``_get_worker_kwargs``.
        """
        # Worker modules are imported inside the branches so only the
        # selected implementation is imported.
        if self.scheduler_config.is_multi_step:
            from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
            worker_cls = MultiStepTPUWorker
        else:
            from vllm.worker.tpu_worker import TPUWorker
            worker_cls = TPUWorker
        return worker_cls(**self._get_worker_kwargs(
            local_rank, rank, distributed_init_method))

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Args:
            num_gpu_blocks: Number of device (here: TPU) KV-cache blocks.
            num_cpu_blocks: Number of CPU KV-cache blocks.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.

        Returns:
            The driver worker's ``determine_num_available_blocks`` result,
            passed through unchanged.
        """
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Forward ``execute_model_req`` to the driver worker and return its
        output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def pin_lora(self, lora_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def list_loras(self) -> Set[int]:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def add_prompt_adapter(self, prompt_adapter_request) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def list_prompt_adapters(self) -> Set[int]:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def check_health(self) -> None:
        """No-op health check."""
        # TPUExecutor will always be healthy as long as it's running.
        return


class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
    """Async variant of ``TPUExecutor`` exposing ``execute_model_async``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` on ``execute_model_req``.

        The blocking ``driver_worker.execute_model`` call is wrapped with
        ``make_async`` and awaited. Presumably ``make_async`` offloads it to
        an executor thread; confirm in ``vllm.utils``.

        The parameter is named ``execute_model_req`` (previously misspelled
        ``sexecute_model_req``) to match the synchronous ``execute_model``.
        """
        output = await make_async(
            self.driver_worker.execute_model)(execute_model_req)
        return output