"vscode:/vscode.git/clone" did not exist on "1cb8c6c5fe96ad12fcb79691fdc099f7ffc03c9f"
distributed_gpu_executor.py 8.08 KB
Newer Older
1
import asyncio
2
from abc import abstractmethod
3
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
4
5
6
7
8

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
9
10
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
11
12
13
14
15
16
17

# Module-level logger for this file, created via vLLM's init_logger and named
# after this module.
logger = init_logger(__name__)


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations.

    Implements the orchestration shared by distributed executors (KV cache
    sizing, the remote-worker execution loop, LoRA management and sharded
    checkpoint saving) on top of three abstract hooks that subclasses must
    provide: `_run_workers`, `_driver_execute_model` and
    `_wait_for_tasks_completion`.
    """

    def __init__(self, *args, **kwargs):
        # This is non-None while the execute-model loop is running in the
        # parallel workers. In the sync path it holds whatever `_run_workers`
        # returned for `start_worker_execution_loop`; in
        # DistributedGPUExecutorAsync it holds an asyncio.Task.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        # Both attributes are assigned before the base initializer runs,
        # presumably because GPUExecutor.__init__ may already reach code that
        # reads them — TODO confirm.
        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and
        # CPU: one (num_gpu_blocks, num_cpu_blocks) pair per worker.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers.

        Logs the chosen block counts and the resulting concurrency estimate,
        records the counts on `self.cache_config`, then forwards them to every
        worker's `initialize_cache`.
        """

        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        # Total GPU token slots divided by the per-request maximum: an estimate
        # of how many max-length requests fit in the GPU KV cache at once.
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           self.model_config.max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    self.model_config.max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run one model step.

        On first use this starts `start_worker_execution_loop` on the remote
        tensor-parallel workers only (asynchronously, see `_run_workers`).
        The step itself is then driven through the driver worker, whose output
        is the only one returned.
        """
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote workers' execution loop if it is running.

        Passing None to the driver signals the remote workers to leave their
        loop; the handles from `_run_workers` are then waited on so that an
        unclean exit surfaces as an exception here.
        """
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    # NOTE: `_run_workers` returns one result per worker (as consumed in
    # `determine_num_available_blocks`). The bool-returning LoRA methods below
    # collapse those with `all(...)`: a bare list such as `[False]` is truthy
    # and would otherwise hide a failure on any worker. The `assert` checks
    # are kept as-is (callers may rely on AssertionError).

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on every worker.

        Returns:
            True only if every worker returned a truthy result.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(
            self._run_workers(
                "add_lora",
                lora_request=lora_request,
            ))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from every worker.

        Returns:
            True only if every worker returned a truthy result.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(
            self._run_workers(
                "remove_lora",
                lora_id=lora_id,
            ))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter on every worker.

        Returns:
            True only if every worker returned a truthy result.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(
            self._run_workers(
                "pin_lora",
                lora_id=lora_id,
            ))

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the workers."""
        # NOTE(review): `_run_workers` yields one result per worker, so this
        # likely returns a per-worker collection rather than the annotated
        # Set[int]. Left unchanged because the intended merge (driver's set,
        # intersection, union?) is not visible here — TODO confirm.
        return self._run_workers("list_loras")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Ask every worker to save its model state under `path`.

        `pattern` and `max_size` are forwarded to the workers unchanged.
        """
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments for the worker method.
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
            max_concurrent_workers: Implementation-defined; presumably caps
                how many workers are invoked at once — TODO confirm.
            **kwargs: Keyword arguments for the worker method.

        Returns:
            Implementation-defined. Callers in this class treat the result as
            an iterable with one entry per worker.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from `_run_workers()` with
        `async_run_tensor_parallel_workers_only=True` to complete."""
        raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
    """Asyncio flavour of DistributedGPUExecutor.

    The remote workers' execution loop runs as an asyncio task, and each
    model step is driven by awaiting the driver worker.
    """

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step, starting the remote worker loop if needed.

        Only the driver worker's output is returned.
        """
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote workers' execution loop, if it is running.

        Calls the driver with no request to signal the stop, then awaits the
        loop task so that an unclean exit raises here.
        """
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or none of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop_async`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError