# abstract.py (13.8 KB)
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import time
from abc import ABC, abstractmethod
5
from collections.abc import Callable
6
from concurrent.futures import Future
7
from functools import cached_property
8
from typing import TYPE_CHECKING, Literal, TypeVar, overload
9

10
from vllm.config import VllmConfig
11
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
12
13
14
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorHandshakeMetadata,
)
15
16
17
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
18
from vllm.tracing import instrument
19
from vllm.utils.import_utils import resolve_obj_by_qualname
20
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
21
from vllm.v1.engine import ReconfigureDistributedRequest
22
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
23
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
24
25
from vllm.v1.worker.worker_base import WorkerBase

26
27
28
if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase

logger = init_logger(__name__)

_R = TypeVar("_R")

# Type of the callbacks accepted by `Executor.register_failure_callback`:
# callables that take no arguments and return None.
FailureCallback = Callable[[], None]


class Executor(ABC):
    """Abstract base class for vLLM executors.

    An executor is responsible for executing the model on one device,
    or it can be a distributed executor that can execute the model on multiple devices.
    """

    uses_ray: bool = False  # whether the executor uses Ray for orchestration.
    supports_pp: bool = False  # whether the executor supports PP

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by the parallel config.

        ``parallel_config.distributed_executor_backend`` may be:

        - an ``Executor`` subclass, used as-is;
        - one of the built-in names ``"ray"``, ``"mp"``, ``"uni"`` or
          ``"external_launcher"``;
        - any other string, treated as a fully qualified name that must
          resolve to an ``Executor`` subclass.

        Raises:
            TypeError: If the backend (given directly or by qualified name)
                is not an ``Executor`` subclass.
            ValueError: If the backend is neither a type nor a string.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = parallel_config.distributed_executor_backend
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, Executor):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"Executor. Got {distributed_executor_backend}."
                )
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported lazily, only when this backend is selected.
            from vllm.v1.executor.ray_executor import RayDistributedExecutor

            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor

            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            from vllm.v1.executor.uniproc_executor import UniProcExecutor

            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            # (module-level name, imported at the bottom of this module)
            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            executor_class = resolve_obj_by_qualname(distributed_executor_backend)
            # Check isinstance first: issubclass() on a non-class raises its
            # own, much less helpful, TypeError.
            if not (
                isinstance(executor_class, type)
                and issubclass(executor_class, Executor)
            ):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"Executor. Got {executor_class}."
                )
        else:
            raise ValueError(
                f"Unknown distributed executor backend: {distributed_executor_backend}"
            )
        return executor_class

    @instrument(span_name="Executor init")
    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store the config sections and run the subclass ``_init_executor`` hook."""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        # Base-class state is set up *before* the subclass hook runs, so that
        # `_init_executor` can read it and anything it assigns to these
        # attributes is not silently overwritten afterwards.
        self.is_sleeping = False
        self.sleeping_tags: set[str] = set()
        self.kv_output_aggregator: KVOutputAggregator | None = None
        self._init_executor()

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass-specific initialization, called at the end of ``__init__``."""
        raise NotImplementedError

    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
        compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
        # Propagate compilation time from workers back to the main process.
        # With TP>1, compilation happens in worker processes, so the main
        # process config is never updated. Use max across workers since they
        # compile in parallel.
        if compilation_times:
            self.vllm_config.compilation_config.compilation_time = max(
                compilation_times
            )

    def register_failure_callback(self, callback: FailureCallback):  # noqa: B027
        """
        Register a function to be called if the executor enters a permanent
        failed state.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return each worker's answer to ``determine_available_memory``."""
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return each worker's KV cache spec."""
        return self.collective_rpc("get_kv_cache_spec")

    @overload
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: Literal[False] = False,
    ) -> list[_R]: ...

    @overload
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: Literal[True] = True,
    ) -> Future[list[_R]]: ...

    @abstractmethod
    def collective_rpc(
        self, method, timeout=None, args=(), kwargs=None, non_block: bool = False
    ):
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.
            non_block: If `True`, return without waiting for the results; the
                declared return type is then a `Future` of the result list.

        Returns:
            A list containing the results from each worker (or, with
            `non_block=True`, a `Future` for it).

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        raise NotImplementedError

    def get_kv_connector_handshake_metadata(
        self,
    ) -> list[dict[int, KVConnectorHandshakeMetadata]]:
        """Return each worker's KV connector handshake metadata."""
        return self.collective_rpc("get_kv_connector_handshake_metadata")

    @staticmethod
    def _first_result(output: list[_R] | Future[list[_R]]) -> _R | Future[_R]:
        """Return the first worker's result, preserving non-blocking calls.

        ``collective_rpc(..., non_block=True)`` is declared to return a
        ``Future`` of the per-worker result list, which cannot be indexed.
        For a ``Future`` input, a derived ``Future`` is returned that resolves
        to element ``[0]`` once ``output`` completes, or fails with the same
        exception. Any other input is indexed directly, exactly as before.
        """
        if not isinstance(output, Future):
            return output[0]

        first: Future[_R] = Future()

        def _propagate(done: Future) -> None:
            # Returns False if the consumer already cancelled `first`.
            if not first.set_running_or_notify_cancel():
                return
            try:
                value = done.result()[0]
            except Exception as exc:  # worker error, cancellation, bad shape
                first.set_exception(exc)
            else:
                first.set_result(value)

        output.add_done_callback(_propagate)
        return first

    @overload
    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
    ) -> ModelRunnerOutput | None: ...

    @overload
    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
    ) -> Future[ModelRunnerOutput | None]: ...

    def execute_model(
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Call ``execute_model`` on all workers; return the first worker's output.

        With ``non_block=True`` a ``Future`` for that output is returned.
        """
        output = self.collective_rpc(  # type: ignore[call-overload]
            "execute_model", args=(scheduler_output,), non_block=non_block
        )
        return self._first_result(output)

    @overload
    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
    ) -> ModelRunnerOutput: ...

    @overload
    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
    ) -> Future[ModelRunnerOutput]: ...

    def sample_tokens(
        self, grammar_output: GrammarOutput | None, non_block: bool = False
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Call ``sample_tokens`` on all workers; return the first worker's output.

        With ``non_block=True`` a ``Future`` for that output is returned.
        """
        output = self.collective_rpc(  # type: ignore[call-overload]
            "sample_tokens", args=(grammar_output,), non_block=non_block
        )
        return self._first_result(output)

    def execute_dummy_batch(self) -> None:
        """Call ``execute_dummy_batch`` on all workers."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the first worker's answer to ``take_draft_token_ids``."""
        output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches this executor runs concurrently (1 by default)."""
        return 1

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Forward a profiler start/stop request to all workers."""
        self.collective_rpc("profile", args=(is_start, profile_prefix))

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Ask all workers to ``save_sharded_state`` with the given options."""
        self.collective_rpc(
            "save_sharded_state",
            kwargs=dict(path=path, pattern=pattern, max_size=max_size),
        )

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        self.collective_rpc("shutdown")

    def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
        """Init KVOutputAggregator from ``connector`` for all ``world_size`` ranks."""
        self.kv_output_aggregator = KVOutputAggregator.from_connector(
            connector, self.parallel_config.world_size
        )

    @cached_property  # Avoid unnecessary RPC calls
    def supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Tasks supported by the model, as reported by the first worker."""
        output: list[tuple[SupportedTask, ...]]
        output = self.collective_rpc("get_supported_tasks")
        return output[0]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load ``lora_request`` on all workers; True only if every worker succeeds."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request,)))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove LoRA ``lora_id`` on all workers; True only if every worker succeeds."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id,)))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin LoRA ``lora_id`` on all workers; True only if every worker succeeds."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id,)))

    def list_loras(self) -> set[int]:
        """Return the LoRA ids loaded on the workers (asserted identical on all)."""
        sets: list[set[int]] = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LORAs."
        return sets[0]

    def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache in each worker."""
        self.collective_rpc("reset_mm_cache")

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache in each worker to clear cached encoder outputs."""
        self.collective_rpc("reset_encoder_cache")

    def sleep(self, level: int = 1):
        """Put all workers to sleep at ``level``.

        Only logs a warning if the executor is already sleeping. On success,
        both the ``"weights"`` and ``"kv_cache"`` tags are recorded as
        sleeping, whatever the ``level``.
        """
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        time_before_sleep = time.perf_counter()
        self.collective_rpc("sleep", kwargs=dict(level=level))
        time_after_sleep = time.perf_counter()
        self.sleeping_tags = {"weights", "kv_cache"}
        self.is_sleeping = True
        logger.info(
            "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep
        )

    def wake_up(self, tags: list[str] | None = None):
        """Wake the workers up for ``tags``.

        ``None`` or an empty list wakes everything. If any requested tag is
        not currently sleeping, a warning is logged and nothing is done. The
        executor counts as awake again once no tags remain sleeping.
        """
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        # An empty list has always meant "wake everything" for the bookkeeping
        # below; normalize it so the workers and the log line agree.
        if not tags:
            tags = None
        if tags is not None:
            for tag in tags:
                if tag not in self.sleeping_tags:
                    logger.warning(
                        "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags
                    )
                    return
        time_before_wakeup = time.perf_counter()
        self.collective_rpc("wake_up", kwargs=dict(tags=tags))
        time_after_wakeup = time.perf_counter()
        logger.info(
            "It took %.6f seconds to wake up tags %s.",
            time_after_wakeup - time_before_wakeup,
            tags if tags is not None else self.sleeping_tags,
        )
        if tags is None:
            self.sleeping_tags.clear()
        else:
            # Unlike set.remove, this tolerates repeated tags instead of
            # raising KeyError after the workers were already woken.
            self.sleeping_tags.difference_update(tags)
        if not self.sleeping_tags:
            self.is_sleeping = False

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Not supported by the base executor; subclasses may override."""
        raise NotImplementedError

    @classmethod
    def supports_async_scheduling(cls) -> bool:
        """
        Whether the executor supports async scheduling.
        """
        return False


# These re-exports live at the bottom of the module (hence `noqa: E402`),
# presumably because uniproc_executor needs `Executor` from this module and a
# top-of-file import would be circular — confirm before moving them.
from vllm.v1.executor.uniproc_executor import (  # noqa: E402
    ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher,
    UniProcExecutor as _UniProcExecutor,
)

# Re-exported under their historical names for backwards compatibility.
ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher
UniProcExecutor = _UniProcExecutor