executor_base.py 14.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
from abc import ABC, abstractmethod
7
from collections.abc import Awaitable, Callable
8
from functools import cached_property
9
from typing import Any
10

11
from typing_extensions import TypeVar
12

13
import vllm.platforms
14
from vllm.config import VllmConfig
15
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
from vllm.sequence import ExecuteModelRequest
19
from vllm.tasks import SupportedTask
20
from vllm.utils.async_utils import make_async
21
from vllm.v1.outputs import SamplerOutput
22
from vllm.v1.worker.worker_base import WorkerBase
23
24

# Module-scoped logger; its name is this module's `__name__`.
logger = init_logger(__name__)

# Result type of a callable dispatched through `collective_rpc`. The
# `default=Any` (PEP 696 style, via typing_extensions) is what a checker
# falls back to when the type parameter cannot be inferred.
_R = TypeVar("_R", default=Any)


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on one device,
    or it can be a distributed executor that can execute the model on
    multiple devices.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.
    supports_pp: bool = False  # whether the executor supports PP

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store the config sections and let the subclass set up workers.

        Args:
            vllm_config: Top-level config; its sub-configs are cached as
                attributes for convenient access by subclasses.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        # NOTE: `_init_executor()` runs before the attributes below are set,
        # so subclass initialization must not rely on them.
        self._init_executor()
        self.is_sleeping = False
        # Tags ("weights", "kv_cache") that are currently asleep.
        self.sleeping_tags: set[str] = set()
        self.kv_output_aggregator: KVOutputAggregator | None = None

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass hook that creates and initializes the workers."""
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError

    def determine_num_available_blocks(self) -> tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
        `num_gpu_blocks` are blocks that are "active" on the device and can be
        appended to.
        `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        results = self.collective_rpc("determine_num_available_blocks")
        # Take the minimum over workers so the chosen sizes fit on every one.
        num_gpu_blocks = min(r[0] for r in results)
        num_cpu_blocks = min(r[1] for r in results)
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.

        Logs the block counts and the resulting maximum concurrency, records
        the counts on `cache_config`, then forwards them to every worker.
        """
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info(
            "# %s blocks: %d, # CPU blocks: %d",
            vllm.platforms.current_platform.device_name,
            num_gpu_blocks,
            num_cpu_blocks,
        )
        # Number of full-length (max_model_len) requests whose tokens fit in
        # the device KV cache at once.
        max_concurrency = (
            num_gpu_blocks
            * self.cache_config.block_size
            / self.model_config.max_model_len
        )
        logger.info(
            "Maximum concurrency for %s tokens per request: %.2fx",
            self.model_config.max_model_len,
            max_concurrency,
        )

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

    @cached_property  # Avoid unnecessary RPC calls
    def supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Tasks supported by the model, as reported by the first worker."""
        output = self.collective_rpc("get_supported_tasks")
        return output[0]

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Run one model step on all workers and return the first result."""
        output = self.collective_rpc("execute_model", args=(execute_model_req,))
        assert output[0] is not None
        return output[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop."""
        return

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on every worker; True only if all succeed."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request,)))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from every worker; True only if all succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id,)))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter on every worker; True only if all succeed."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id,)))

    def list_loras(self) -> set[int]:
        """Return the loaded LoRA ids, asserting all workers agree on them."""
        sets = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LORAs."
        return sets[0]

    def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache in each worker."""
        self.collective_rpc("reset_mm_cache")

    def start_profile(self) -> None:
        """Start the profiler on every worker."""
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        """Stop the profiler on every worker."""
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1) -> None:
        """Put every worker to sleep.

        `level` is forwarded to the workers unchanged. Afterwards both the
        "weights" and "kv_cache" tags are recorded as sleeping. Logs a warning
        and does nothing if the executor is already sleeping.
        """
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        time_before_sleep = time.perf_counter()
        self.collective_rpc("sleep", kwargs=dict(level=level))
        time_after_sleep = time.perf_counter()
        self.sleeping_tags = {"weights", "kv_cache"}
        self.is_sleeping = True
        logger.info(
            "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
        )

    def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the workers up, fully or for a subset of sleeping tags.

        Args:
            tags: Tags to wake. `None` or an empty list wakes everything that
                is still sleeping. If any tag is not currently sleeping, a
                warning is logged and nothing is done. Duplicate tags are
                tolerated.
        """
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        # Treat an empty list like None ("wake everything") so the value sent
        # to the workers, the log line and the bookkeeping below all agree.
        # Previously [] cleared every tag locally while sending [] to workers.
        if not tags:
            tags = None
        if tags is not None:
            for tag in tags:
                if tag not in self.sleeping_tags:
                    logger.warning(
                        "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
                    )
                    return
        time_before_wakeup = time.perf_counter()
        self.collective_rpc("wake_up", kwargs=dict(tags=tags))
        time_after_wakeup = time.perf_counter()
        logger.info(
            "It took %.6f seconds to wake up tags %s.",
            time_after_wakeup - time_before_wakeup,
            tags if tags is not None else self.sleeping_tags,
        )
        if tags is None:
            self.sleeping_tags.clear()
        else:
            # difference_update tolerates repeated tags; repeated
            # set.remove() would raise KeyError on the second occurrence.
            self.sleeping_tags.difference_update(tags)
        if not self.sleeping_tags:
            self.is_sleeping = False

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Ask every worker to save its shard of the model state to `path`."""
        self.collective_rpc(
            "save_sharded_state",
            kwargs=dict(path=path, pattern=pattern, max_size=max_size),
        )

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        self.collective_rpc("shutdown")

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Executes one model step on the given sequences."""
        # Run the blocking execute_model via `make_async` so the event loop
        # is not blocked while waiting on the workers.
        return await make_async(self.execute_model)(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()

    def init_kv_output_aggregator(self, finished_count: int | None) -> None:
        """Init KVOutputAggregator.

        Args:
            finished_count: Expected number of finished reports; a falsy value
                (None or 0) falls back to `parallel_config.world_size`.
        """
        self.kv_output_aggregator = KVOutputAggregator(
            finished_count or self.parallel_config.world_size
        )

class DistributedExecutorBase(ExecutorBase):
    """Abstract superclass of distributed executor implementations."""

    def __init__(self, *args, **kwargs) -> None:
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        # It must be set before super().__init__(), which calls
        # _init_executor() on the subclass.
        self.parallel_worker_tasks: Any | Awaitable[Any] | None = None

        super().__init__(*args, **kwargs)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> list[SamplerOutput]:
        """Run one model step, lazily starting the remote worker loop.

        On the first call the execution loop is started (asynchronously) in
        the remote tensor-parallel workers only. The driver worker then
        executes the step and its output is returned.
        """
        # TODO: unify into collective_rpc
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
            )

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote workers' execution loop if it is running."""
        if self.parallel_worker_tasks is None:
            return

        # Passing None tells the driver to stop the loop in the remote workers.
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: ExecuteModelRequest | None
    ) -> list[SamplerOutput] | None:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[Any]:
        """Execute `method` on all workers via `_run_workers`.

        NOTE: `timeout` is accepted for interface compatibility with
        `ExecutorBase.collective_rpc` but is not forwarded to `_run_workers`,
        so it has no effect in this implementation.
        """
        return self._run_workers(method, *args, **(kwargs or {}))

    @abstractmethod
    def _run_workers(
        self,
        method: str | Callable,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: int | None = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        # TODO: simplify and merge with collective_rpc
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        raise NotImplementedError

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Async variant of `execute_model`.

        Schedules the remote worker execution loop as an asyncio task the
        first time it is called, then awaits the driver worker's output.
        """
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop()
            )

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Async variant of `stop_remote_worker_execution_loop`."""
        if self.parallel_worker_tasks is None:
            return

        # Calling with no request (None) stops the loop in the remote workers.
        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest | None = None,
    ) -> list[SamplerOutput] | None:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers. In this case, this
        method returns None; otherwise it returns the model output.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or none of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError