# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.inputs import EngineInput, PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor

if TYPE_CHECKING:
    # PauseMode is only referenced in quoted ("PauseMode") annotations in this
    # module, so it is imported for static type checkers only.
    from vllm.v1.engine import PauseMode

36
@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with generate() to support multi-turn streaming sessions
    where inputs are provided via an async generator.
    """

    # The input for one step of the streaming session.
    prompt: EngineInput
    # Optional sampling parameters for this step; defaults to None.
    # NOTE(review): presumably None means "reuse the request-level
    # sampling params" — confirm in the engine implementation.
    sampling_params: SamplingParams | None = None


class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Defines the interface a client uses to submit generation and pooling
    requests to an engine and to control it: abort, pause/resume,
    sleep/wake-up, cache resets, profiling, LoRA loading, and shutdown.

    Methods decorated with ``@abstractmethod`` must be implemented by every
    concrete client. The remaining methods are optional capabilities whose
    default implementations raise ``NotImplementedError``.
    """

    # Declared without values: concrete implementations are expected to set
    # these attributes.
    vllm_config: VllmConfig
    model_config: ModelConfig
    renderer: BaseRenderer
    io_processor: IOProcessor | None
    input_processor: InputProcessor

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine is stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine has errored."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception describing the engine failure.

        NOTE(review): presumably only meaningful once ``errored`` is True —
        confirm against the concrete implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Args:
            prompt: The request input. Passing an async generator of
                :class:`StreamingInput` supports multi-turn streaming
                sessions where inputs are provided incrementally.
            sampling_params: Sampling parameters for the request.
            request_id: The unique id of the request.
            prompt_text: Optional prompt text.
            lora_request: Optional LoRA adapter request.
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            trace_headers: Optional tracing headers.
            priority: Request priority (default 0).
            data_parallel_rank: Optional data-parallel rank for the request.
            reasoning_ended: Optional flag; its semantics are defined by the
                concrete implementation.

        Returns:
            An async generator yielding :class:`RequestOutput` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Unlike :meth:`generate`, every parameter here may also be passed
        positionally.

        Args:
            prompt: The request input.
            pooling_params: Pooling parameters for the request.
            request_id: The unique id of the request.
            lora_request: Optional LoRA adapter request.
            trace_headers: Optional tracing headers.
            priority: Request priority (default 0).
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            reasoning_ended: Optional flag; its semantics are defined by the
                concrete implementation.

        Returns:
            An async generator yielding :class:`PoolingRequestOutput` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache.

        Args:
            reset_running_requests: Whether running requests are also reset
                (default False); exact semantics are implementation-defined.
            reset_connector: Whether to also reset the configured connector
                cache (default False).

        Returns:
            A bool. NOTE(review): presumably True when the reset succeeded —
            confirm against the concrete implementations.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
        """Sleep the engine.

        Args:
            level: Sleep level (default 1); its meaning is defined by the
                concrete implementation.
            mode: A ``PauseMode`` value (default ``"abort"``), the same type
                accepted by :meth:`pause_generation`.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags. NOTE(review): presumably ``None``
                wakes everything — confirm against the implementation.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            A bool. NOTE(review): presumably True when the adapter was
            loaded — confirm against the concrete implementations.
        """
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        mode: "PauseMode" = "abort",
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  and return partial results with "abort" reason (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
            clear_cache: DEPRECATED. Whether to clear KV and prefix caches
                after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    @abstractmethod
    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown the engine with optional timeout."""
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Optional capability; the default implementation raises
        ``NotImplementedError``.

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Drain timeout (default 300). NOTE(review): units
                presumably seconds — confirm against the implementation.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given path.

        Optional capability; the default implementation raises
        ``NotImplementedError``.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout.
            args: Positional arguments for the call.
            kwargs: Optional keyword arguments for the call.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Optional capability; the default implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Optional capability; the default implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Optional capability; the default implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError