protocol.py 5.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import Any
7

8
from vllm.config import ModelConfig, RendererConfig, VllmConfig
9
from vllm.inputs.data import PromptType
10
from vllm.lora.request import LoRARequest
11
12
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams
15
from vllm.tasks import SupportedTask
16
from vllm.tokenizers import TokenizerLike
17
from vllm.v1.engine import EngineCoreRequest
18
from vllm.v1.engine.input_processor import InputProcessor
19

20

21
class EngineClient(ABC):
    """Abstract base class describing what an engine client must provide.

    Subclasses must implement every member decorated with
    ``@abstractmethod``; ``ABC`` refuses to instantiate a subclass that
    leaves any of them unimplemented. The three coroutines at the bottom
    (``scale_elastic_ep``, ``collective_rpc`` and ``get_supported_tasks``)
    are optional hooks whose default bodies raise ``NotImplementedError``.
    """

    # Declared (not assigned) here: concrete clients are expected to set
    # these attributes themselves.
    vllm_config: VllmConfig
    renderer_config: RendererConfig
    model_config: ModelConfig
    input_processor: InputProcessor
    io_processor: IOProcessor | None

    # --- engine status ---------------------------------------------------

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running; defined by the subclass."""

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine is stopped; defined by the subclass."""

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine is in an error state; defined by the subclass."""

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception object describing the engine failure.

        NOTE(review): presumably consulted when ``errored`` is True;
        confirm against the concrete implementations.
        """

    # --- request submission ----------------------------------------------

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Stream ``RequestOutput`` objects for one generation request.

        The result is an async generator, so callers consume it with
        ``async for``. Every argument after ``request_id`` is keyword-only.
        """

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Stream ``PoolingRequestOutput`` objects for a pooling-model request.

        Unlike ``generate``, every parameter here may be passed positionally.
        """

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort one or more requests.

        Args:
            request_id: A single request id, or an iterable of request ids.
        """

    @abstractmethod
    async def get_tokenizer(self) -> TokenizerLike:
        """Return the client's tokenizer."""

    # --- observability and health ------------------------------------------

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled."""

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log statistics; what gets logged is up to the subclass."""

    @abstractmethod
    async def check_health(self) -> None:
        """Return normally when healthy; raise an exception otherwise."""

    @abstractmethod
    async def start_profile(self) -> None:
        """Begin profiling the engine."""

    @abstractmethod
    async def stop_profile(self) -> None:
        """End profiling the engine."""

    # --- caches ----------------------------------------------------------

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Clear the multi-modal cache."""

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Clear the prefix cache and, optionally, any configured connector cache.

        NOTE(review): the meaning of the returned ``bool`` is not stated at
        this level; confirm against the implementations.
        """

    # --- sleep / wake ------------------------------------------------------

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at the given ``level``."""

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the engine up."""

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Report whether the engine is sleeping."""

    # --- LoRA ----------------------------------------------------------------

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter so that later requests can use it."""

    # --- pause / resume ------------------------------------------------------

    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Stop accepting new generation and encoding requests.

        Args:
            wait_for_inflight_requests: If ``True``, let in-flight requests
                finish before pausing; if ``False`` (the default), abort them
                right away.
            clear_cache: Whether the KV and prefix caches are cleared once
                draining is done.
        """

    @abstractmethod
    async def resume_generation(self) -> None:
        """Start accepting generation and encoding requests again."""

    @abstractmethod
    async def is_paused(self) -> bool:
        """Report whether the engine is paused right now."""

    # --- optional hooks: default bodies raise NotImplementedError ------------

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine (target size given by ``new_data_parallel_size``).

        Not abstract: only subclasses that support scaling override it.
        """
        raise NotImplementedError()

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to ``method``; unsupported by default."""
        raise NotImplementedError()

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks; unsupported by default."""
        raise NotImplementedError()