# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import Future
from typing import Callable, Union

import torch
import torch.distributed as dist

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
    UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput

# Type of the hook passed to Executor.register_failure_callback: a callable
# taking no arguments and returning None, meant to be run if the executor
# enters a permanent failed state.
FailureCallback = Callable[[], None]


class Executor(ExecutorBase):
    """Abstract base class for v1 executors.

    Defines the methods specific to the v1 engine; methods shared by v0 and
    v1 belong in ``ExecutorBase``. The default implementations below forward
    each call to the workers through ``collective_rpc``.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by ``vllm_config``.

        ``parallel_config.distributed_executor_backend`` is either a class,
        which is validated and returned as-is, or one of the strings
        ``"ray"``, ``"mp"``, ``"uni"`` or ``"external_launcher"``.

        Raises:
            TypeError: if a class is given that is not a subclass of
                ``Executor``.
            ValueError: if the backend string is not recognized.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            # This method promises a type[Executor]. The v1-specific methods
            # defined on this class (determine_available_memory,
            # get_kv_cache_specs, ...) are not guaranteed to exist on an
            # arbitrary ExecutorBase subclass, so validate against Executor.
            if not issubclass(distributed_executor_backend, Executor):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"Executor. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported lazily so the module is only loaded when selected.
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.

        Issues two worker RPCs in order: ``initialize_from_config`` with
        ``kv_cache_configs``, then ``compile_or_warm_up_model``.
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """
        Register a function to be called if the executor enters a permanent
        failed state.

        The default implementation ignores ``callback`` (it is never stored
        or called); executors able to detect permanent failure should
        override this.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the ``determine_available_memory`` result of each worker,
        in bytes, as collected by ``collective_rpc``."""
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Return each worker's KV cache spec mapping.

        Note the worker-side RPC is named ``get_kv_cache_spec`` (singular).
        """
        return self.collective_rpc("get_kv_cache_spec")

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Run ``execute_model`` on the workers with ``scheduler_output``.

        Only the first worker's result (index 0 of the ``collective_rpc``
        output) is returned; the others are discarded.
        """
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches this executor may have in flight at once.

        The default is 1; subclasses may override it.
        """
        return 1

    def profile(self, is_start: bool = True):
        """Forward ``profile(is_start)`` to the workers.

        Presumably ``True`` starts and ``False`` stops profiling; the
        semantics are defined by the worker-side ``profile`` method.
        """
        self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 executor for the ``"uni"`` backend.

    Adds no code of its own: it combines the v0 ``UniProcExecutor``
    implementation (imported as ``UniProcExecutorV0``) with the v1
    ``Executor`` interface. Because ``UniProcExecutorV0`` is listed first,
    its methods take precedence over ``Executor``'s in the MRO.
    """


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 executor for the ``"external_launcher"`` backend.

    Combines the v0 ``ExecutorWithExternalLauncher`` (imported as
    ``ExecutorWithExternalLauncherV0``) with the v1 ``Executor`` interface,
    and makes all ranks agree on the available memory.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return available memory, reduced to the minimum across all ranks.

        Same as determine_num_available_blocks in v0: every rank first
        measures its own value(s) through the base implementation, then an
        all-reduce with ``MIN`` over the world group's CPU group gives every
        rank the same, smallest value.

        Returns:
            A list with one entry per locally reported value, each the
            minimum of that entry across all ranks.
        """
        memory = super().determine_available_memory()
        # Local import kept inside the method (presumably to avoid a
        # module-level import cycle -- TODO confirm).
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        # Keep the tensor 1-D (one slot per local value) so the reduction is
        # elementwise and converts straight back to a list of ints. Every
        # rank must report the same number of values for all_reduce.
        memory_tensor = torch.tensor(memory, device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return memory_tensor.tolist()