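# executor_base.py (4.58 KB) -- header text left over from a web file viewer
# ("Newer Older"); kept as a comment so it cannot be parsed as code.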
1
from abc import ABC, abstractmethod
2
from typing import List, Optional, Set, Tuple
3

4
from vllm.config import EngineConfig
5
from vllm.lora.request import LoRARequest
6
from vllm.model_executor.layers.sampler import SamplerOutput
7
from vllm.prompt_adapter.request import PromptAdapterRequest
8
from vllm.sequence import ExecuteModelRequest
9
10
11
12
13
14
15
16
17
18


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on a specific device
    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
    that can execute the model on multiple devices.

    Construction stores the engine configuration on the instance and then
    calls the subclass hook ``_init_executor()``.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.

    def __init__(
        self,
        vllm_config: EngineConfig,
    ) -> None:
        """Store the configuration and run the subclass initialization hook.

        ``vllm_config`` is kept as a whole and each of its sub-configs is
        also exposed as an attribute of the same name on the executor.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Must run last: the hook may read any of the attributes set above.
        self._init_executor()

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass hook invoked at the end of ``__init__``."""
        pass

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure the
        selected cache sizes are compatible with all workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences."""
        raise NotImplementedError

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop.

        The base implementation does nothing; subclasses may override it.
        """
        return

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA described by ``lora_request``.

        NOTE(review): the meaning of the returned bool is defined by the
        implementations (presumably success) -- confirm.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA identified by ``lora_id``."""
        raise NotImplementedError  # type: ignore

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return a set of LoRA ids (ints) known to the executor."""
        raise NotImplementedError

    @abstractmethod
    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add the prompt adapter described by ``prompt_adapter_request``."""
        raise NotImplementedError

    @abstractmethod
    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove the prompt adapter identified by ``prompt_adapter_id``."""
        raise NotImplementedError

    @abstractmethod
    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin the prompt adapter identified by ``prompt_adapter_id``."""
        raise NotImplementedError  # type: ignore

    @abstractmethod
    def list_prompt_adapters(self) -> Set[int]:
        """Return a set of prompt adapter ids (ints) known to the executor."""
        raise NotImplementedError

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor.

        The base implementation does nothing; subclasses may override it.
        """
        return

    def __del__(self):
        """Call ``shutdown()`` when the instance is finalized."""
        self.shutdown()

122
123
124
125
126

class ExecutorAsyncBase(ExecutorBase):
    """Executor base class that adds coroutine variants of the executor API.

    Subclasses must implement ``execute_model_async``; the other coroutines
    have default implementations defined here.
    """

    @abstractmethod
    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Executes one model step on the given sequences."""
        raise NotImplementedError

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop.

        The base implementation does nothing; subclasses may override it.
        """
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception.

        The default implementation calls the synchronous ``check_health()``
        directly.
        """
        self.check_health()