abstract.py 5.55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from concurrent.futures import Future
5
from typing import Any, Callable, Optional, Union
6

7
8
9
import torch
import torch.distributed as dist

10
from vllm.config import VllmConfig
11
12
13
14
15
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
    UniProcExecutor as UniProcExecutorV0)
16
from vllm.utils import resolve_obj_by_qualname
17
from vllm.v1.core.sched.output import SchedulerOutput
18
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
19
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
20

21
22
# Zero-argument callable with no return value; this is the callback type
# accepted by Executor.register_failure_callback.
FailureCallback = Callable[[], None]

23

24
25
26
27
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase.

    The worker-facing methods below are thin wrappers around
    `collective_rpc`, which this class leaves unimplemented.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """
        Pick the executor class selected by
        `vllm_config.parallel_config.distributed_executor_backend`.

        The backend may be a class (used as-is), one of the built-in names
        "ray", "mp", "uni" or "external_launcher", or any other string,
        which is handed to `resolve_obj_by_qualname`.

        Raises:
            TypeError: if the given class, or the object a custom string
                resolves to, is not a subclass of ExecutorBase.
            ValueError: if the backend is neither a class nor a string.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported only when this backend is actually selected.
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            # Imported only when this backend is actually selected.
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            executor_class = resolve_obj_by_qualname(
                distributed_executor_backend)
            # issubclass() raises its own, unrelated TypeError when its first
            # argument is not a class; test for that first so a name that
            # resolves to a non-class gets the same clear message.
            if not (isinstance(executor_class, type)
                    and issubclass(executor_class, ExecutorBase)):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {executor_class}.")
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.

        Issues two collective RPCs in order: "initialize_from_config" with
        `kv_cache_configs`, then "compile_or_warm_up_model".
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback) -> None:
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        This base implementation does nothing with `callback`.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """
        Return the result of the "determine_available_memory" collective
        RPC (values in bytes).
        """
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return the result of the "get_kv_cache_spec" collective RPC."""
        return self.collective_rpc("get_kv_cache_spec")

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False) -> list[Any]:
        """
        RPC entry point that every other method of this class relies on.

        Not implemented here; always raises NotImplementedError, so
        concrete executors are expected to provide it.
        """
        raise NotImplementedError

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """
        Issue the "execute_model" collective RPC with `scheduler_output`
        and return the first entry of its result.

        `non_block` is forwarded unchanged to `collective_rpc`.
        """
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ),
                                     non_block=non_block)
        return output[0]

    def execute_dummy_batch(self) -> None:
        """Issue the "execute_dummy_batch" collective RPC."""
        self.collective_rpc("execute_dummy_batch")

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """
        Issue the "take_draft_token_ids" collective RPC and return the
        first entry of its result.
        """
        output = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Always 1 in this base class."""
        return 1

    def profile(self, is_start: bool = True) -> None:
        """Issue the "profile" collective RPC, passing `is_start` along."""
        self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """
    v1 executor built purely by combining UniProcExecutorV0 with the v1
    Executor interface; it adds no members of its own.
    """


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """
    v1 executor combining ExecutorWithExternalLauncherV0 with the v1
    Executor interface; only the available-memory query is overridden.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """
        Return the available memory (in bytes) as a one-element list.

        Same idea as determine_num_available_blocks in v0: the value
        obtained from the parent implementation is reduced with MIN over
        the world group's CPU process group, since the minimum across all
        ranks is what is needed.
        """
        local_memory = super().determine_available_memory()

        # Imported inside the method, after the parent call.
        from vllm.distributed.parallel_state import get_world_group

        reduced = torch.tensor([local_memory],
                               dtype=torch.int64,
                               device="cpu")
        # In-place reduction: `reduced` holds the minimum afterwards.
        dist.all_reduce(reduced,
                        op=dist.ReduceOp.MIN,
                        group=get_world_group().cpu_group)
        return [reduced.item()]