# executor_base.py
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
import time
5
from abc import ABC, abstractmethod
6
7
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
                    Union)
8

9
10
11
import torch.nn as nn
from typing_extensions import TypeVar

12
import vllm.platforms
13
from vllm.config import VllmConfig
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
16
from vllm.model_executor.layers.sampler import SamplerOutput
17
from vllm.prompt_adapter.request import PromptAdapterRequest
18
19
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
20
from vllm.worker.worker_base import WorkerBase
21
22

# Generic result type for worker RPCs; ``default=Any`` means an
# unparametrized use of ``_R`` is treated as ``Any``.
_R = TypeVar("_R", default=Any)

# Module-level logger used by the executor classes below.
logger = init_logger(__name__)

class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on one device,
    or it can be a distributed executor that can execute the model on
    multiple devices.

    Subclasses must implement :meth:`_init_executor`, :meth:`collective_rpc`
    and :meth:`check_health`. Most other operations defined here are thin
    wrappers that broadcast a call to every worker via
    :meth:`collective_rpc` and combine the per-worker results.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store the configuration, then run the subclass init hook.

        Args:
            vllm_config: Top-level configuration; each of its sub-configs is
                also exposed as an attribute on the executor.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self._init_executor()
        # Assigned only after `_init_executor()` returns, so this attribute
        # does not exist yet while subclass initialization is running.
        self.is_sleeping = False

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass-specific initialization hook.

        Invoked at the end of ``__init__``, after all config attributes
        have been stored.
        """
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        results = self.collective_rpc("determine_num_available_blocks")
        # Use the per-worker minimum so the chosen sizes work on every worker.
        num_gpu_blocks = min(r[0] for r in results)
        num_cpu_blocks = min(r[1] for r in results)
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Records the block counts on ``cache_config`` and forwards them to
        every worker's ``initialize_cache``.

        Args:
            num_gpu_blocks: Number of KV-cache blocks on the device.
            num_cpu_blocks: Number of swappable KV-cache blocks in CPU memory.
        """
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info("# %s blocks: %d, # CPU blocks: %d",
                    vllm.platforms.current_platform.device_name,
                    num_gpu_blocks, num_cpu_blocks)
        # Total token capacity of the device cache divided by the longest
        # allowed request, i.e. how many max-length requests fit at once.
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           self.model_config.max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    self.model_config.max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.
        """

        def rpc_func(worker: WorkerBase) -> _R:
            return func(worker.get_model())

        return self.collective_rpc(rpc_func)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Run one model step on all workers and return the first worker's
        output."""
        output = self.collective_rpc("execute_model",
                                     args=(execute_model_req, ))
        return output[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop.

        No-op in the base class.
        """
        return

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA on every worker; True only if all workers succeed."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request, )))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA on every worker; True only if all workers succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id, )))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA on every worker; True only if all workers succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id, )))

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids loaded on the workers.

        Every worker is expected to report the same set; the first worker's
        set is returned.
        """
        sets = self.collective_rpc("list_loras")
        assert all(s == sets[0]
                   for s in sets), "All workers should have the same LORAs."
        return sets[0]

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Load a prompt adapter on every worker; True only if all succeed."""
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("add_prompt_adapter",
                                args=(prompt_adapter_request, )))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove a prompt adapter on every worker; True only if all succeed.
        """
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("remove_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a prompt adapter on every worker; True only if all succeed."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("pin_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt-adapter ids loaded on the workers.

        Every worker is expected to report the same set; the first worker's
        set is returned.
        """
        sets = self.collective_rpc("list_prompt_adapters")
        assert all(
            s == sets[0]
            for s in sets), "All workers should have the same prompt adapters."
        return sets[0]

    def start_profile(self) -> None:
        """Forward ``start_profile`` to every worker."""
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        """Forward ``stop_profile`` to every worker."""
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1) -> None:
        """Put every worker to sleep at ``level`` and log the time taken.

        Logs a warning and does nothing if the executor is already sleeping.
        """
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        time_before_sleep = time.perf_counter()
        self.collective_rpc("sleep", kwargs=dict(level=level))
        time_after_sleep = time.perf_counter()
        self.is_sleeping = True
        logger.info("It took %.6f seconds to fall asleep.",
                    time_after_sleep - time_before_sleep)

    def wake_up(self) -> None:
        """Wake every worker up and log the time taken.

        Logs a warning and does nothing if the executor is not sleeping.
        """
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        time_before_wakeup = time.perf_counter()
        self.collective_rpc("wake_up")
        time_after_wakeup = time.perf_counter()
        self.is_sleeping = False
        logger.info("It took %.6f seconds to wake up.",
                    time_after_wakeup - time_before_wakeup)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward ``path``, ``pattern`` and ``max_size`` to every worker's
        ``save_sharded_state``."""
        self.collective_rpc("save_sharded_state",
                            kwargs=dict(path=path,
                                        pattern=pattern,
                                        max_size=max_size))

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        return

    def __del__(self):
        # Delegate cleanup to `shutdown()` when the executor is collected.
        self.shutdown()

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Executes one model step on the given sequences.

        Wraps the synchronous :meth:`execute_model` with ``make_async`` so
        it can be awaited; returns exactly what ``execute_model`` returns.
        """
        output = await make_async(self.execute_model)(execute_model_req)
        return output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()


class DistributedExecutorBase(ExecutorBase):
    """Abstract superclass of distributed executor implementations.

    The remote (non-driver) workers run an execution loop that is started
    lazily on the first ``execute_model`` / ``execute_model_async`` call.
    Each step is executed on the driver worker, whose output is returned to
    the caller. The loop is ended with
    ``stop_remote_worker_execution_loop[_async]``.
    """

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running in the
        # parallel workers. In the async path it holds the asyncio.Task
        # created by `execute_model_async`.
        # Set before `super().__init__()` so the attribute already exists
        # while `_init_executor()` runs.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        super().__init__(*args, **kwargs)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run one model step, starting the remote worker loop if needed.

        Returns:
            The driver worker's output; it must not be None here.
        """
        # TODO: unify into collective_rpc
        if self.parallel_worker_tasks is None:
            # Start the loop on the remote TP workers only (not the driver).
            # `_run_workers` returns without blocking on their results.
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote worker loop, if running, and wait for it to exit.
        """
        if self.parallel_worker_tasks is None:
            return

        # Passing None signals the remote workers to leave their loop.
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Run ``method`` on all workers via :meth:`_run_workers`.

        ``args`` are passed positionally and ``kwargs`` (if any) as keyword
        arguments.

        NOTE: ``timeout`` is currently NOT forwarded to ``_run_workers`` and
        therefore has no effect for distributed executors.
        """
        return self._run_workers(method, *args, **(kwargs or {}))

    @abstractmethod
    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        # TODO: simplify and merge with collective_rpc
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        raise NotImplementedError

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async variant of :meth:`execute_model`."""
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Async variant of :meth:`stop_remote_worker_execution_loop`."""
        if self.parallel_worker_tasks is None:
            return

        # Calling with no request signals the remote workers to stop.
        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or none of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError