abstract.py 4.11 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from concurrent.futures import Future
4
from typing import Union
5

6
7
8
import torch
import torch.distributed as dist

9
from vllm.config import VllmConfig
10
11
12
13
14
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import (  # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import (  # noqa
    UniProcExecutor as UniProcExecutorV0)
15
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
16
17
18
from vllm.v1.outputs import ModelRunnerOutput


19
20
21
22
class Executor(ExecutorBase):
    """
    Abstract base class for v1 executors.

    This class defines the methods that are specific to v1. Methods shared
    by both v0 and v1 executors belong in ``ExecutorBase`` instead.
    """

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> type["Executor"]:
        """Resolve the executor class selected by the parallel config.

        Args:
            vllm_config: Engine configuration. Its
                ``parallel_config.distributed_executor_backend`` is either a
                class (used as-is after validation) or one of the strings
                ``"ray"``, ``"mp"``, ``"uni"`` or ``"external_launcher"``.

        Returns:
            The executor class to instantiate.

        Raises:
            TypeError: If the backend is a class that does not subclass
                ``ExecutorBase``.
            ValueError: If the backend is not a recognized value.
        """
        executor_class: type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        # distributed_executor_backend must be set in VllmConfig.__post_init__
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            # Imported lazily so the backend's dependencies are only loaded
            # when that backend is actually selected.
            from vllm.v1.executor.ray_distributed_executor import (  # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize_from_config(self,
                               kv_cache_configs: list[KVCacheConfig]) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.

        Args:
            kv_cache_configs: KV cache configurations forwarded to the
                workers' ``initialize_from_config`` RPC.
        """
        self.collective_rpc("initialize_from_config",
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Collect the available memory (in bytes) reported by the workers.

        Returns:
            The results of the ``determine_available_memory`` collective RPC.
        """
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
        """Collect the KV cache specs reported by the workers.

        Returns:
            The results of the ``get_kv_cache_spec`` collective RPC.
        """
        return self.collective_rpc("get_kv_cache_spec")

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Run one model step on the workers.

        Args:
            scheduler_output: The scheduler's output for this step, forwarded
                unchanged to every worker.

        Returns:
            The output of the first worker. The annotation also admits a
            ``Future`` so that subclasses may execute asynchronously.
        """
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be in flight at once (always 1 here)."""
        return 1

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on all workers."""
        self.collective_rpc("profile", args=(is_start, ))


class UniProcExecutor(UniProcExecutorV0, Executor):
    """v1 single-process executor.

    Combines the v0 ``UniProcExecutor`` implementation with the v1
    ``Executor`` interface; it adds no behavior of its own.
    """


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """v1 executor for processes started by an external launcher.

    Reuses the v0 ``ExecutorWithExternalLauncher`` implementation and the v1
    ``Executor`` interface. It additionally reconciles the available-memory
    measurement across all ranks of the world group.
    """

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return the available memory (in bytes), agreed on by all ranks.

        Same as ``determine_num_available_blocks`` in v0: each rank may
        measure a different amount of free memory. We therefore take the
        element-wise minimum across all ranks, so every rank ends up with
        the same (smallest) value.

        Returns:
            One entry per value reported by the local workers, each reduced
            with ``MIN`` across the world group's CPU process group.
        """
        memory = super().determine_available_memory()
        # Local import, kept as in the original (presumably to avoid an
        # import cycle at module load time).
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        # ``memory`` is already a list. Build the tensor from it directly
        # rather than wrapping it in another list (shape ``(1, n)``). The
        # old ``.item()`` call only worked when exactly one value was
        # reported; this reduces every entry element-wise.
        memory_tensor = torch.tensor(memory, device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        # tolist() on an int64 tensor yields plain Python ints, matching the
        # previous ``[memory_tensor.item()]`` for the single-entry case.
        return memory_tensor.tolist()