# tpu_executor.py
from typing import Any, Dict, List, Optional, Set, Tuple
import torch

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

logger = init_logger(__name__)


class TPUExecutor(ExecutorBase):
    """Executor that drives a single ``TPUWorker``.

    The worker is created in ``_init_executor`` and kept as
    ``self.driver_worker``. Every public method either forwards to that
    worker or raises ``NotImplementedError`` for a feature the TPU backend
    does not provide (LoRA, prompt adapters).
    """

    # Marks this executor as not using Ray.
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Validate the configuration and bring up the driver worker.

        A model dtype of float16 or float32 is replaced by bfloat16 (with a
        warning) before the worker is created. The worker is then
        constructed, its device initialized and the model loaded.

        Raises:
            AssertionError: If chunked prefill is enabled or a speculative
                decoding config is set.
        """
        # NOTE: These checks are plain asserts, so they are skipped when
        # Python runs with -O.
        assert not self.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not self.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")
        if self.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", self.model_config.dtype)
            # The shared model config is overwritten in place.
            self.model_config.dtype = torch.bfloat16

        # Instantiate the worker and load the model to the device.
        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return worker init args for a given rank.

        Args:
            local_rank: Forwarded unchanged as ``local_rank``.
            rank: Forwarded unchanged as ``rank``; rank 0 is flagged as the
                driver worker.
            distributed_init_method: Forwarded unchanged when given. When
                ``None``, one is built from ``get_ip()`` and
                ``get_open_port()``.

        Returns:
            Keyword arguments for constructing the worker.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return {
            "model_config": self.model_config,
            "parallel_config": self.parallel_config,
            "scheduler_config": self.scheduler_config,
            "device_config": self.device_config,
            "cache_config": self.cache_config,
            "load_config": self.load_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": rank == 0,
        }

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ):
        """Construct a ``TPUWorker`` from ``_get_worker_kwargs``.

        The arguments are passed straight through to
        ``_get_worker_kwargs``.
        """
        # Imported here rather than at module level -- presumably so that
        # TPU-only dependencies load only when this executor is used
        # (TODO confirm).
        from vllm.worker.tpu_worker import TPUWorker

        return TPUWorker(**self._get_worker_kwargs(local_rank, rank,
                                                   distributed_init_method))

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Args:
            num_gpu_blocks: Device block count; logged as "TPU blocks".
            num_cpu_blocks: CPU block count.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker."""
        return self.driver_worker.determine_num_available_blocks()

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run the request on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def pin_lora(self, lora_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def list_loras(self) -> Set[int]:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is currently not supported by the TPU backend.")

    def add_prompt_adapter(self, prompt_adapter_request) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def list_prompt_adapters(self) -> Set[int]:
        """Unsupported on TPU; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Soft prompt is currently not supported by the TPU backend.")

    def check_health(self) -> None:
        """No-op health check."""
        # TPUExecutor will always be healthy as long as it's running.
        return


class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
    """Async variant of ``TPUExecutor``; adds ``execute_model_async``."""

    async def execute_model_async(
        self,
        sexecute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Awaitable counterpart of ``TPUExecutor.execute_model``.

        Wraps ``self.driver_worker.execute_model`` with ``make_async`` and
        awaits the result, so the return value is whatever the worker call
        returns (annotated to match ``TPUExecutor.execute_model``).

        Args:
            sexecute_model_req: The request forwarded to the worker.
        """
        # NOTE(review): `sexecute_model_req` looks like a typo for
        # `execute_model_req`; the name is kept so existing callers that
        # pass it by keyword keep working -- confirm against the base class.
        run_in_background = make_async(self.driver_worker.execute_model)
        return await run_in_background(sexecute_model_req)