abstract.py 5.38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5
from concurrent.futures import Future
6
from typing import Any
7

8
9
10
import torch
import torch.distributed as dist

11
from vllm.config import VllmConfig
12
13
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
14
15
16
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0  # noqa
17
from vllm.utils import resolve_obj_by_qualname
18
from vllm.v1.core.sched.output import SchedulerOutput
19
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
20
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
21

22
23
# Signature of callbacks accepted by Executor.register_failure_callback:
# takes no arguments and returns nothing.
FailureCallback = Callable[[], None]

24

25
26
27
28
class Executor(ExecutorBase):
    """Abstract base class for v1 executors.

    Defines the methods specific to v1. Methods shared by v0 and v1 belong in
    ``ExecutorBase``. Subclasses must implement :meth:`collective_rpc`; most
    other methods here are thin wrappers that forward a named call to it.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by the parallel config.

        ``vllm_config.parallel_config.distributed_executor_backend`` may be:

        * a class, which must be a subclass of ``ExecutorBase``;
        * one of the built-in names ``"ray"``, ``"mp"``, ``"uni"`` or
          ``"external_launcher"``;
        * any other string, resolved as a fully qualified name whose target
          must be a subclass of ``ExecutorBase``.

        Raises:
            TypeError: if the given or resolved object is not a subclass of
                ``ExecutorBase``.
            ValueError: if the backend is neither a class nor a string.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = parallel_config.distributed_executor_backend
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}."
                )
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor,
            )

            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor

            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            executor_class = resolve_obj_by_qualname(distributed_executor_backend)
            # The qualname may presumably point at a non-class object (e.g. a
            # function). issubclass() raises its own generic TypeError for
            # non-classes, so check isinstance(..., type) first to report the
            # intended message instead.
            if not (
                isinstance(executor_class, type)
                and issubclass(executor_class, ExecutorBase)
            ):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {executor_class}."
                )
        else:
            raise ValueError(
                f"Unknown distributed executor backend: {distributed_executor_backend}"
            )
        return executor_class

    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
        # Compile/warm up only after the KV caches have been initialized.
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback) -> None:
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        The base implementation ignores ``callback``; subclasses that can fail
        permanently are expected to override this.
        """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory, in bytes, as reported via
        ``collective_rpc`` (presumably one entry per worker)."""
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return the KV cache specs gathered via ``collective_rpc``
        (presumably one dict per worker)."""
        return self.collective_rpc("get_kv_cache_spec")

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Invoke ``method`` on the underlying workers and return the results.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Execute one model step for ``scheduler_output``.

        Returns the first result from ``collective_rpc``. With
        ``non_block=True``, that result is whatever ``collective_rpc``
        returns for a non-blocking call (a ``Future``, per the annotation).
        """
        output = self.collective_rpc(
            "execute_model", args=(scheduler_output,), non_block=non_block
        )
        return output[0]

    def execute_dummy_batch(self) -> None:
        """Ask the workers to execute a dummy batch."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token ids from the first ``collective_rpc``
        result."""
        output = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Maximum number of batches in flight; 1 in this base class."""
        return 1

    def profile(self, is_start: bool = True):
        """Forward a profiler start (``is_start=True``) or stop request to the
        workers."""
        self.collective_rpc("profile", args=(is_start,))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 executor built from the v0 ``UniProcExecutor`` plus the v1
    ``Executor`` interface; chosen by ``Executor.get_class`` for the
    ``"uni"`` backend."""


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 executor built from the v0 ``ExecutorWithExternalLauncher`` plus the
    v1 ``Executor`` interface; chosen by ``Executor.get_class`` for the
    ``"external_launcher"`` backend."""

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the minimum available memory across all ranks, in bytes.

        Same as ``determine_num_available_blocks`` in v0: the locally
        reported value is reduced with MIN over the world group's CPU process
        group, so all ranks end up with the same result.
        """
        # same as determine_num_available_blocks in v0,
        # we need to get the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        # .item() requires the tensor to hold exactly one element, i.e. the
        # local result is expected to contain a single value.
        return [memory_tensor.item()]