protocol.py 6.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any
8

9
from vllm.config import ModelConfig, VllmConfig
10
11
12
13
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
14
from vllm.inputs import EngineInput, PromptType
15
from vllm.lora.request import LoRARequest
16
from vllm.outputs import PoolingRequestOutput, RequestOutput
17
from vllm.pooling_params import PoolingParams
18
from vllm.renderers import BaseRenderer
19
from vllm.sampling_params import SamplingParams
20
from vllm.tasks import SupportedTask
21
from vllm.v1.engine import EngineCoreRequest
22
from vllm.v1.engine.input_processor import InputProcessor
23

24
25
26
if TYPE_CHECKING:
    from vllm.v1.engine import PauseMode

27

28
29
30
31
32
33
34
35
@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with ``EngineClient.generate()`` to support multi-turn
    streaming sessions where inputs are provided via an async generator
    (``AsyncGenerator[StreamingInput, None]``).

    Attributes:
        prompt: The engine input for this item of the session.
        sampling_params: Optional sampling parameters for this item.
            Defaults to ``None``; presumably the request-level
            ``sampling_params`` passed to ``generate()`` apply when this is
            left unset — confirm against the engine implementation.
    """

    prompt: EngineInput
    sampling_params: SamplingParams | None = None


40
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Subclasses must implement every ``@abstractmethod`` below. The
    non-abstract methods at the end of the class (elastic scaling,
    collective RPC, supported tasks, weight transfer) raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    # Configuration and input-handling components exposed by the client.
    vllm_config: VllmConfig
    model_config: ModelConfig
    renderer: BaseRenderer
    input_processor: InputProcessor

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Return whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Return whether the engine is stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Return whether the engine has entered an error state."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return an exception describing why the engine is unusable.

        NOTE(review): presumably meaningful only when ``errored`` is true —
        confirm against implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Args:
            prompt: The request input: a pre-built ``EngineCoreRequest``, a
                prompt, an ``EngineInput``, or an async generator of
                ``StreamingInput`` items for a multi-turn streaming session.
            sampling_params: Sampling parameters for the request.
            request_id: The unique id of the request (the same id accepted
                by :meth:`abort`).
            prompt_text: Optional prompt text accompanying ``prompt``.
            lora_request: Optional LoRA adapter to apply to this request.
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            trace_headers: Optional tracing headers for the request.
            priority: Request priority (default ``0``).
            data_parallel_rank: Optional data-parallel rank to target.
            reasoning_ended: Optional reasoning-state flag; semantics are
                defined by the implementation.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Args:
            prompt: The request input (a prompt or an ``EngineInput``).
            pooling_params: Pooling parameters for the request.
            request_id: The unique id of the request (the same id accepted
                by :meth:`abort`).
            lora_request: Optional LoRA adapter to apply to this request.
            trace_headers: Optional tracing headers for the request.
            priority: Request priority (default ``0``).
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            reasoning_ended: Optional reasoning-state flag; semantics are
                defined by the implementation.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether request tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache.

        Args:
            reset_running_requests: Implementation-defined flag controlling
                how running requests are treated during the reset.
            reset_connector: Whether to also reset the configured connector
                cache.

        Returns:
            A boolean status; presumably ``True`` when the reset succeeded —
            confirm against implementations.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
        """Sleep the engine.

        Args:
            level: Sleep level (default ``1``); the meaning of each level is
                defined by the implementation.
            mode: A ``PauseMode`` value, the same type (and default) that
                :meth:`pause_generation` accepts.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags; presumably selects what to wake up
                (``None`` meaning everything) — confirm against
                implementations.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            A boolean status; presumably ``True`` on success — confirm
            against implementations.
        """
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        mode: "PauseMode" = "abort",
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  and return partial results with "abort" reason (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
            clear_cache: DEPRECATED. Whether to clear KV and prefix caches
                after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    @abstractmethod
    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown the engine with optional timeout."""
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine to a new data-parallel size.

        Not implemented by default; subclasses that support elastic
        scaling override this.

        Args:
            new_data_parallel_size: The target data-parallel size.
            drain_timeout: Timeout for draining (default ``300``);
                presumably in seconds — confirm against implementations.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call invoking the given method.

        Not implemented by default; subclasses override this.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout for the call.
            args: Positional arguments forwarded to ``method``.
            kwargs: Optional keyword arguments forwarded to ``method``.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get the tasks supported by the engine.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError