1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import time
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from concurrent.futures import Future
7
from functools import cached_property
8
from typing import TYPE_CHECKING, Literal, TypeVar, overload
9

10
from vllm.config import VllmConfig
11
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
12
13
14
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorHandshakeMetadata,
)
15
16
17
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
18
from vllm.utils.import_utils import resolve_obj_by_qualname
19
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
20
from vllm.v1.engine import ReconfigureDistributedRequest
21
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
22
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
23
24
from vllm.v1.worker.worker_base import WorkerBase

25
26
27
if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase

28
29
30
# Generic per-worker result type produced by ``Executor.collective_rpc``.
_R = TypeVar("_R")

# Callback type accepted by ``Executor.register_failure_callback``: it takes
# no arguments and returns nothing.
FailureCallback = Callable[[], None]

logger = init_logger(__name__)
class Executor(ABC):
    """Abstract base class for vLLM executors.

    An executor is responsible for executing the model on one device,
    or it can be a distributed executor that can execute the model on multiple
    devices.

    Concrete subclasses must implement :meth:`_init_executor`,
    :meth:`collective_rpc` and :meth:`check_health`. Most other methods in
    this class forward a named worker method to every worker through
    :meth:`collective_rpc`.
    """

    uses_ray: bool = False  # whether the executor uses Ray for orchestration.
    supports_pp: bool = False  # whether the executor supports PP

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by the parallel config.

        ``parallel_config.distributed_executor_backend`` may be an ``Executor``
        subclass, one of the built-in backend names (``"ray"``, ``"mp"``,
        ``"uni"``, ``"external_launcher"``), or the fully qualified name of an
        ``Executor`` subclass.

        Raises:
            TypeError: If the backend (given directly or by qualified name) is
                not an ``Executor`` subclass.
            ValueError: If the backend value is of an unrecognized kind.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = parallel_config.distributed_executor_backend
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, Executor):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"Executor. Got {distributed_executor_backend}."
                )
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_executor import RayDistributedExecutor

            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor

            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            from vllm.v1.executor.uniproc_executor import UniProcExecutor

            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            from vllm.v1.executor.uniproc_executor import (
                ExecutorWithExternalLauncher,
            )

            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            resolved = resolve_obj_by_qualname(distributed_executor_backend)
            # issubclass() raises its own, less helpful TypeError when given a
            # non-class, so check that the resolved object is a class first.
            if not (isinstance(resolved, type) and issubclass(resolved, Executor)):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"Executor. Got {resolved}."
                )
            executor_class = resolved
        else:
            raise ValueError(
                f"Unknown distributed executor backend: {distributed_executor_backend}"
            )
        return executor_class

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store the config sections and let the subclass set itself up."""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        # Default state is set *before* _init_executor() so that a subclass
        # assigning any of these during its own initialization (e.g. through
        # init_kv_output_aggregator) is not silently reset afterwards.
        self.is_sleeping = False
        self.sleeping_tags: set[str] = set()
        self.kv_output_aggregator: KVOutputAggregator | None = None
        self._init_executor()

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass-specific initialization, called from ``__init__``."""
        raise NotImplementedError

    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):  # noqa: B027
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        The base implementation ignores the callback; subclasses that can
        detect permanent failures override this.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory reported by each worker, in bytes."""
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return each worker's KV cache spec mapping."""
        return self.collective_rpc("get_kv_cache_spec")

    @overload
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: Literal[False] = False,
    ) -> list[_R]: ...

    @overload
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: Literal[True] = True,
    ) -> Future[list[_R]]: ...

    @abstractmethod
    def collective_rpc(
        self, method, timeout=None, args=(), kwargs=None, non_block: bool = False
    ):
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.
            non_block: If `True`, return immediately with a `Future` for the
                results instead of waiting for them.

        Returns:
            A list containing the results from each worker (or, with
            `non_block=True`, a `Future` resolving to that list).

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError

    def get_kv_connector_handshake_metadata(
        self,
    ) -> list[dict[int, KVConnectorHandshakeMetadata]]:
        """Collect KV-connector handshake metadata from every worker."""
        return self.collective_rpc("get_kv_connector_handshake_metadata")

    @staticmethod
    def _first_worker_output(
        output: list[_R] | Future[list[_R]],
    ) -> _R | Future[_R]:
        """Extract the first worker's result from a ``collective_rpc`` output.

        A blocking call yields a list, from which element 0 is returned. The
        ``non_block`` overload yields a ``Future`` of a list, which cannot be
        indexed. In that case a new ``Future`` is returned that resolves to
        element 0 and mirrors the source's exception or cancellation.
        """
        if not isinstance(output, Future):
            return output[0]

        first: Future = Future()

        def _propagate(done: Future) -> None:
            if done.cancelled():
                first.cancel()
                return
            try:
                first.set_result(done.result()[0])
            except Exception as exc:  # worker error or empty result list
                first.set_exception(exc)

        output.add_done_callback(_propagate)
        return first

    @overload
    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
    ) -> ModelRunnerOutput | None: ...

    @overload
    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
    ) -> Future[ModelRunnerOutput | None]: ...

    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the workers' ``execute_model`` and return the first worker's result.

        Args:
            scheduler_output: Forwarded to each worker's ``execute_model``.
            non_block: If ``True``, return a ``Future`` for the result instead
                of waiting for it.
        """
        output = self.collective_rpc(  # type: ignore[call-overload]
            "execute_model", args=(scheduler_output,), non_block=non_block
        )
        return self._first_worker_output(output)

    @overload
    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
    ) -> ModelRunnerOutput: ...

    @overload
    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
    ) -> Future[ModelRunnerOutput]: ...

    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the workers' ``sample_tokens`` and return the first worker's result.

        Args:
            grammar_output: Forwarded to each worker's ``sample_tokens``.
            non_block: If ``True``, return a ``Future`` for the result instead
                of waiting for it.
        """
        output = self.collective_rpc(  # type: ignore[call-overload]
            "sample_tokens", args=(grammar_output,), non_block=non_block
        )
        return self._first_worker_output(output)

    def execute_dummy_batch(self) -> None:
        """Ask every worker to run ``execute_dummy_batch``."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token IDs reported by the first worker."""
        output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Maximum number of concurrent batches; the base class returns 1."""
        return 1

    def profile(self, is_start: bool = True):
        """Forward a profiler request (``is_start`` flag) to all workers."""
        self.collective_rpc("profile", args=(is_start,))

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Forward ``save_sharded_state(path, pattern, max_size)`` to all workers."""
        self.collective_rpc(
            "save_sharded_state",
            kwargs=dict(path=path, pattern=pattern, max_size=max_size),
        )

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor by forwarding ``shutdown`` to all workers."""
        self.collective_rpc("shutdown")

    def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
        """Create the KV output aggregator for ``connector``.

        The aggregator is sized with the parallel world size.
        """
        self.kv_output_aggregator = KVOutputAggregator.from_connector(
            connector, self.parallel_config.world_size
        )

    @cached_property  # Avoid unnecessary RPC calls
    def supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Tasks supported by the model, as reported by the first worker.

        Computed once per executor instance and then cached.
        """
        output: list[tuple[SupportedTask, ...]]
        output = self.collective_rpc("get_supported_tasks")
        return output[0]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on all workers; True only if every worker succeeded."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request,)))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from all workers; True only if all succeeded."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id,)))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter on all workers; True only if all succeeded."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id,)))

    def list_loras(self) -> set[int]:
        """Return the LoRA IDs loaded on the workers.

        Asserts that every worker reports the same set.
        """
        sets: list[set[int]] = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LORAs."
        return sets[0]

    def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache in each worker."""
        self.collective_rpc("reset_mm_cache")

    def sleep(self, level: int = 1):
        """Put all workers to sleep at ``level``.

        After the RPC returns, both the ``"weights"`` and ``"kv_cache"`` tags
        are recorded as sleeping. Logs a warning and does nothing if the
        executor is already sleeping.
        """
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        time_before_sleep = time.perf_counter()
        self.collective_rpc("sleep", kwargs=dict(level=level))
        time_after_sleep = time.perf_counter()
        self.sleeping_tags = {"weights", "kv_cache"}
        self.is_sleeping = True
        logger.info(
            "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
        )

    def wake_up(self, tags: list[str] | None = None):
        """Wake the workers up, optionally only for the given ``tags``.

        If ``tags`` is falsy, every sleeping tag is treated as woken. If any
        requested tag is not currently sleeping, a warning is logged and
        nothing is woken. The executor stops being "sleeping" once no tags
        remain.
        """
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        if tags:
            for tag in tags:
                if tag not in self.sleeping_tags:
                    logger.warning(
                        "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
                    )
                    return
        time_before_wakeup = time.perf_counter()
        self.collective_rpc("wake_up", kwargs=dict(tags=tags))
        time_after_wakeup = time.perf_counter()
        logger.info(
            "It took %.6f seconds to wake up tags %s.",
            time_after_wakeup - time_before_wakeup,
            tags if tags is not None else self.sleeping_tags,
        )
        if tags:
            # difference_update tolerates duplicate entries in ``tags``; a
            # per-tag set.remove() would raise KeyError on the repeat, after
            # the workers had already been woken up.
            self.sleeping_tags.difference_update(tags)
        else:
            self.sleeping_tags.clear()
        if not self.sleeping_tags:
            self.is_sleeping = False

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Not supported by the base executor; subclasses may override."""
        raise NotImplementedError


from vllm.v1.executor.uniproc_executor import (  # noqa: E402
    ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher,
)
from vllm.v1.executor.uniproc_executor import (  # noqa: E402
    UniProcExecutor as _UniProcExecutor,
)

# Re-export the uniprocess executors under their historical names so code that
# imports them from this module keeps working (backwards compatibility).
UniProcExecutor, ExecutorWithExternalLauncher = (
    _UniProcExecutor,
    _ExecutorWithExternalLauncher,
)