abstract.py 5.39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from concurrent.futures import Future
5
from typing import Any, Callable, Optional, Union
6

7
8
9
import torch
import torch.distributed as dist

10
from vllm.config import VllmConfig
11
12
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
13
14
15
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0  # noqa
16
from vllm.utils import resolve_obj_by_qualname
17
from vllm.v1.core.sched.output import SchedulerOutput
18
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
19
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
20

21
22
# Signature of callbacks passed to ``Executor.register_failure_callback``:
# takes no arguments and returns None. Per that method's docstring, it is
# meant to be called if the executor enters a permanent failed state.
FailureCallback = Callable[[], None]

23

24
25
26
27
class Executor(ExecutorBase):
    """Abstract base class for v1 executors.

    Holds the methods that only the v1 engine needs; methods shared by v0
    and v1 executors belong in ``ExecutorBase``. Most methods here forward a
    call to the underlying workers through ``collective_rpc``, which concrete
    subclasses must implement.
    """

    @staticmethod
    def _check_executor_class(candidate: Any) -> None:
        """Raise ``TypeError`` unless *candidate* is a subclass of ``ExecutorBase``.

        The ``isinstance(candidate, type)`` guard matters for values resolved
        from a qualname: ``issubclass`` raises its own, far less helpful
        ``TypeError`` when handed something that is not a class (e.g. a
        function), so we check that first to keep the error message uniform.
        """
        if not (isinstance(candidate, type) and issubclass(candidate, ExecutorBase)):
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {candidate}."
            )

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by ``vllm_config``.

        ``parallel_config.distributed_executor_backend`` may be:

        * a class, used as-is after checking it subclasses ``ExecutorBase``;
        * one of the built-in names ``"ray"``, ``"mp"``, ``"uni"`` or
          ``"external_launcher"``;
        * any other string, resolved with ``resolve_obj_by_qualname`` and
          then checked the same way as a class.

        Raises:
            TypeError: if a user-supplied class or qualname does not resolve
                to a subclass of ``ExecutorBase``.
            ValueError: if the backend is neither a class nor a string.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = parallel_config.distributed_executor_backend
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            Executor._check_executor_class(distributed_executor_backend)
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor,
            )

            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor

            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            executor_class = resolve_obj_by_qualname(distributed_executor_backend)
            Executor._check_executor_class(executor_class)
        else:
            raise ValueError(
                f"Unknown distributed executor backend: {distributed_executor_backend}"
            )
        return executor_class

    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.

        Forwards *kv_cache_configs* to the workers' ``initialize_from_config``
        and then asks them to ``compile_or_warm_up_model``, in that order.
        """
        self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        The base implementation ignores *callback*; subclasses may override
        this to actually store and invoke it.
        """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory, in bytes, reported by the workers.

        The list is the raw ``collective_rpc`` result, presumably one entry
        per worker.
        """
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return the KV cache spec mapping reported by each worker.

        NOTE(review): the ``str`` keys are presumably layer names — confirm
        against the worker's ``get_kv_cache_spec``.
        """
        return self.collective_rpc("get_kv_cache_spec")

    def collective_rpc(
        self,
        method: Union[str, Callable],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Run *method* on the workers and return their results as a list.

        Must be implemented by concrete executors.

        Args:
            method: name of a worker method, or a callable to run on the
                workers.
            timeout: optional time limit for the call (presumably seconds —
                confirm against the concrete executors).
            args: positional arguments forwarded to *method*.
            kwargs: keyword arguments forwarded to *method*, if any.
            non_block: when True, the returned entries are presumably
                futures rather than resolved results (``execute_model``
                declares a ``Future`` return for this case).
        """
        raise NotImplementedError

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model for *scheduler_output* and return a single result.

        Only the first entry of the ``collective_rpc`` result list is
        returned (presumably the driver / rank-0 worker's output). With
        ``non_block=True`` that entry is a ``Future`` of the output, per the
        declared return type.
        """
        output = self.collective_rpc(
            "execute_model", args=(scheduler_output,), non_block=non_block
        )
        return output[0]

    def execute_dummy_batch(self) -> None:
        """Ask the workers to run a dummy batch; results are discarded."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return the draft token ids from the first worker's result.

        May be ``None``, as allowed by the return type.
        """
        output = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches the executor runs concurrently (1 by default)."""
        return 1

    def profile(self, is_start: bool = True):
        """Forward a profiler request to the workers.

        ``is_start=True`` starts profiling; ``False`` presumably stops it.
        """
        self.collective_rpc("profile", args=(is_start,))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 counterpart of the v0 ``UniProcExecutor``.

    Selected by ``Executor.get_class`` for the ``"uni"`` backend. It only
    combines the v0 implementation with the v1 ``Executor`` interface and
    adds no behavior of its own.
    """


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 counterpart of the v0 ``ExecutorWithExternalLauncher``.

    Selected by ``Executor.get_class`` for the ``"external_launcher"``
    backend. It overrides only memory profiling, so that all ranks agree on
    the amount of available memory.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory, in bytes, as the minimum across ranks.

        Same idea as ``determine_num_available_blocks`` in v0. The local
        per-worker values are all-reduced with ``MIN`` over the world group's
        CPU process group, so every rank returns the same, most constrained
        value for each entry.
        """
        local_memory = super().determine_available_memory()
        # Local import kept from the original (presumably to avoid an import
        # cycle or import-time side effects — TODO confirm).
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        # ``local_memory`` is already a list, so build a 1-D tensor from it
        # directly. Wrapping it in another list produced a (1, N) tensor whose
        # ``.item()`` only worked for N == 1.
        memory_tensor = torch.tensor(local_memory, device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        # ``tolist`` on an int64 tensor yields plain Python ints.
        return memory_tensor.tolist()