# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from abc import ABC, abstractmethod
5
6
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
                    Union)
7

8
9
10
import torch.nn as nn
from typing_extensions import TypeVar

11
import vllm.platforms
12
from vllm.config import VllmConfig
13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
from vllm.model_executor.layers.sampler import SamplerOutput
16
from vllm.prompt_adapter.request import PromptAdapterRequest
17
18
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
19
from vllm.worker.worker_base import WorkerBase
20
21

logger = init_logger(__name__)

# Generic result type of worker RPCs. ``default=Any`` (supported by
# ``typing_extensions.TypeVar``) is used when the type cannot be inferred.
_R = TypeVar("_R", default=Any)


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on one device,
    or it can be a distributed executor that can execute the model on
    multiple devices.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        # Expose each sub-config as an attribute so that subclasses (and
        # _init_executor in particular) can use them directly.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self._init_executor()
        self.is_sleeping = False

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass hook called from ``__init__`` once the config attributes
        above have been populated."""
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        results = self.collective_rpc("determine_num_available_blocks")
        # Use the minimum over all workers so the chosen sizes fit on each.
        a = min(r[0] for r in results)
        b = min(r[1] for r in results)
        return a, b

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Logs the block counts and the resulting maximum concurrency, stores
        the counts on ``cache_config``, then forwards them to every worker.
        """
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info("# %s blocks: %d, # CPU blocks: %d",
                    vllm.platforms.current_platform.dispatch_key,
                    num_gpu_blocks, num_cpu_blocks)
        # How many max-length requests fit into the allocated device blocks.
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           self.model_config.max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    self.model_config.max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.
        """

        def rpc_func(worker: WorkerBase) -> _R:
            return func(worker.get_model())

        return self.collective_rpc(rpc_func)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Run one model step on all workers; return the first worker's
        output."""
        output = self.collective_rpc("execute_model",
                                     args=(execute_model_req, ))
        return output[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop."""
        return

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA on every worker; True only if all workers succeed."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request, )))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA on every worker; True only if all workers succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id, )))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA on every worker; True only if all workers succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id, )))

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids loaded on the workers (asserted identical)."""
        sets = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LORAs."
        return sets[0]

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add a prompt adapter on every worker; True only if all succeed."""
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("add_prompt_adapter",
                                args=(prompt_adapter_request, )))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove a prompt adapter on every worker; True only if all
        succeed."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("remove_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a prompt adapter on every worker; True only if all succeed."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("pin_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt adapter ids on the workers (asserted
        identical)."""
        sets = self.collective_rpc("list_prompt_adapters")
        for s in sets:
            assert (s == sets[0]
                    ), "All workers should have the same prompt adapters."
        return sets[0]

    def start_profile(self) -> None:
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1) -> None:
        """Ask all workers to sleep at ``level``; no-op (with a warning) if
        already sleeping."""
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        self.collective_rpc("sleep", kwargs=dict(level=level))
        self.is_sleeping = True

    def wake_up(self) -> None:
        """Wake all workers; no-op (with a warning) if not sleeping."""
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        self.collective_rpc("wake_up")
        self.is_sleeping = False

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward ``save_sharded_state`` with the given options to every
        worker."""
        self.collective_rpc("save_sharded_state",
                            kwargs=dict(path=path,
                                        pattern=pattern,
                                        max_size=max_size))

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        return

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Executes one model step on the given sequences."""
        # Run the blocking execute_model without blocking the event loop.
        output = await make_async(self.execute_model)(execute_model_req)
        return output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()


class DistributedExecutorBase(ExecutorBase):
    """Abstract superclass of distributed executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        # Set before super().__init__ because that calls _init_executor.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        super().__init__(*args, **kwargs)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run one model step, lazily starting the remote workers' loop."""
        # TODO: unify into collective_rpc
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote workers' execution loop if it is running."""
        if self.parallel_worker_tasks is None:
            return

        # Passing None tells the driver to stop the remote loop.
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Execute ``method`` on all workers via ``_run_workers``.

        NOTE: ``timeout`` is accepted for interface compatibility with the
        base class but is NOT forwarded to ``_run_workers``, so this
        implementation does not enforce it.
        """
        return self._run_workers(method, *args, **(kwargs or {}))

    # TODO: simplify and merge with collective_rpc
    @abstractmethod
    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        raise NotImplementedError

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async variant of execute_model; starts the remote loop as a task
        on first use."""
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Async variant of stop_remote_worker_execution_loop."""
        if self.parallel_worker_tasks is None:
            return

        # Calling with no request tells the driver to stop the remote loop.
        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or None of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError