# tpu_executor.py
from typing import Any, Dict, List, Optional, Set, Tuple
2
3
4
5
6
7

import torch

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
8
9
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
10
11
12
13
14
15
16
17
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

# Module-level logger; used by TPUExecutor for the dtype-fallback warning
# and the KV-cache block-count message.
logger = init_logger(__name__)


class TPUExecutor(ExecutorBase):
    """Executor that runs a single TPU worker (``self.driver_worker``).

    Every model, cache, and health call is forwarded to that one worker; no
    Ray-managed workers are used (``uses_ray`` is False). Chunked prefill,
    speculative decoding, LoRA and prompt adapters are not supported.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Validate the config for TPU, then create and load the worker.

        Raises:
            AssertionError: if chunked prefill or speculative decoding is
                enabled.
        """
        # NOTE(review): these are `assert`s and are skipped under `python -O`;
        # kept as-is because callers may rely on AssertionError.
        assert not self.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not self.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")
        # float16/float32 are rejected by this backend; fall back to bfloat16
        # with a warning instead of failing.
        if self.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", self.model_config.dtype)
            self.model_config.dtype = torch.bfloat16

        # Instantiate the worker and load the model to the device.
        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return worker init args for a given rank.

        If ``distributed_init_method`` is not given, one is built from this
        host's IP and a free port. Rank 0 is marked as the driver worker.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=rank == 0,
        )

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ):
        """Construct (without initializing) the TPU worker for ``rank``.

        Uses ``MultiStepTPUWorker`` when the scheduler is configured for
        multi-step scheduling, otherwise ``TPUWorker``.
        """
        # Worker modules are imported inside the branch, presumably so only
        # the worker class actually selected gets loaded.
        if self.scheduler_config.is_multi_step:
            from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
            worker_cls = MultiStepTPUWorker
        else:
            from vllm.worker.tpu_worker import TPUWorker
            worker_cls = TPUWorker
        return worker_cls(**self._get_worker_kwargs(
            local_rank, rank, distributed_init_method))

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker."""
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run one model step on the driver worker and return its output."""
        output = self.driver_worker.execute_model(execute_model_req)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def pin_lora(self, lora_id: int) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def list_loras(self) -> Set[int]:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def add_prompt_adapter(self, prompt_adapter_request) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def list_prompt_adapters(self) -> Set[int]:
        """Not supported on TPU; always raises NotImplementedError."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def check_health(self) -> None:
        """No-op health check."""
        # TPUExecutor will always be healthy as long as it's running.
        return


class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
    """Async variant of ``TPUExecutor``.

    The blocking ``driver_worker.execute_model`` call is wrapped with
    ``make_async`` so it can be awaited.
    """

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await one model step on the driver worker.

        The parameter name matches the synchronous ``execute_model`` (it was
        previously misspelled ``sexecute_model_req``).

        Args:
            execute_model_req: The request to execute.

        Returns:
            Whatever ``driver_worker.execute_model`` returns, the same value
            the synchronous ``execute_model`` returns.
        """
        output = await make_async(self.driver_worker.execute_model
                                  )(execute_model_req)
        return output