"vllm/vscode:/vscode.git/clone" did not exist on "3f175f18a2e5d430ffa17fcb96759a758cc3ec05"
abstract.py 5.02 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from concurrent.futures import Future
5
from typing import Callable, Optional, Union
6

7
8
9
import torch
import torch.distributed as dist

10
from vllm.config import VllmConfig
11
12
13
14
15
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
    UniProcExecutor as UniProcExecutorV0)
16
from vllm.utils import resolve_obj_by_qualname
17
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
18
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
19

20
21
# Type of a zero-argument, no-result callable that can be handed to
# ``Executor.register_failure_callback``.
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
    """
    Abstract base class for v1 executors.

    Only v1-specific methods are defined here; methods shared by v0 and v1
    belong in ``ExecutorBase``.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """
        Resolve the executor class selected by
        ``vllm_config.parallel_config.distributed_executor_backend``.

        The backend may be:

        * a class, which must be a subclass of ``ExecutorBase``;
        * one of the built-in names ``"ray"``, ``"mp"``, ``"uni"`` or
          ``"external_launcher"``;
        * any other string, treated as a fully qualified name and resolved
          with ``resolve_obj_by_qualname``. The resolved object must be a
          class that subclasses ``ExecutorBase``.

        Raises:
            TypeError: if the given or resolved backend is not a subclass of
                ``ExecutorBase``.
            ValueError: if the backend is neither a class nor a string.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported lazily, presumably so the Ray dependency is only
            # needed when this backend is actually selected.
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        elif isinstance(distributed_executor_backend, str):
            executor_class = resolve_obj_by_qualname(
                distributed_executor_backend)
            # Check for a class first: issubclass() raises its own, less
            # helpful TypeError when given a non-class (e.g. a function).
            if not (isinstance(executor_class, type)
                    and issubclass(executor_class, ExecutorBase)):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {executor_class}.")
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.

        Sends ``initialize_from_config`` with the given KV cache configs to
        the workers, then ``compile_or_warm_up_model``, in that order.
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback) -> None:
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        This base implementation is a no-op: ``callback`` is neither stored
        nor invoked.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """
        Return the available memory in bytes as reported by the workers.

        The list is the raw result of the ``determine_available_memory``
        collective RPC.
        """
        output = self.collective_rpc("determine_available_memory")
        return output

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """
        Return the KV cache specs reported by the workers.

        The list is the raw result of the ``get_kv_cache_spec`` collective
        RPC.
        """
        output = self.collective_rpc("get_kv_cache_spec")
        return output

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """
        Run one model step on the workers and return the first result.

        Only element 0 of the ``execute_model`` RPC results is returned.
        """
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return the first worker result of ``take_draft_token_ids``."""
        output = self.collective_rpc("take_draft_token_ids")
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be in flight at once (1 here)."""
        return 1

    def profile(self, is_start: bool = True) -> None:
        """
        Forward a profiler request to the workers.

        ``is_start=True`` (the default) requests start; ``False`` presumably
        requests stop.
        """
        self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """
    V1 counterpart of the V0 ``UniProcExecutor``.

    Reuses the V0 implementation, which comes first in the MRO, together with
    the v1 ``Executor`` methods. It adds no behavior of its own.
    """


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """
    V1 counterpart of the V0 ``ExecutorWithExternalLauncher``.

    Reuses the V0 implementation, which comes first in the MRO, together with
    the v1 ``Executor`` methods. Only ``determine_available_memory`` is
    overridden.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """
        Return the available memory in bytes, reduced to the minimum across
        all ranks of the world group.

        This follows ``determine_num_available_blocks`` in v0. Each rank
        computes its local values, then a MIN all-reduce over the world CPU
        group gives every rank the same, smallest value. The likely purpose
        is that all ranks size their caches identically.
        """
        # Local values on this rank (a list per the base-class annotation).
        memory = super().determine_available_memory()
        # Deferred import kept as in the original code; possibly to avoid an
        # import cycle at module load time (TODO confirm).
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        # Build a flat 1-D tensor. `memory` is already a list, so wrapping it
        # again would create a 2-D tensor. reshape(-1) also turns a scalar
        # into a single-element vector. tolist() then returns a list of
        # Python ints of the same length; unlike .item(), it does not
        # require exactly one value.
        memory_tensor = torch.tensor(memory, device="cpu",
                                     dtype=torch.int64).reshape(-1)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return memory_tensor.tolist()