protocol.py 6.48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import TYPE_CHECKING, Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
10
11
12
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
13
from vllm.inputs.data import PromptType, StreamingInput
14
from vllm.lora.request import LoRARequest
15
16
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
17
from vllm.pooling_params import PoolingParams
18
from vllm.renderers import BaseRenderer
19
from vllm.renderers.inputs import DictPrompt, TokPrompt
20
from vllm.sampling_params import SamplingParams
21
from vllm.tasks import SupportedTask
22
from vllm.v1.engine import EngineCoreRequest
23
from vllm.v1.engine.input_processor import InputProcessor
24

25
26
27
if TYPE_CHECKING:
    from vllm.v1.engine import PauseMode

28

29
class EngineClient(ABC):
    """Abstract interface implemented by clients that talk to the engine.

    Despite the historical "protocol" wording, this is an ``abc.ABC`` rather
    than a ``typing.Protocol``. Subclasses must inherit from it explicitly and
    override every ``@abstractmethod`` before they can be instantiated.

    The non-abstract coroutines at the end (``scale_elastic_ep``,
    ``collective_rpc``, ``get_supported_tasks`` and the weight-transfer hooks)
    are optional capabilities. Their default implementation raises
    ``NotImplementedError``.
    """

    # Annotation-only declarations: this base class never assigns these, so
    # concrete clients are responsible for setting them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    input_processor: InputProcessor
    # May legitimately be ``None`` (see the annotation).
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def renderer(self) -> BaseRenderer:
        """The ``BaseRenderer`` exposed by this client."""
        ...

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine has stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine is in an errored state."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception object associated with a dead engine.

        NOTE(review): presumably only meaningful when :attr:`errored` is
        true. Confirm this against the concrete implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This is declared as a plain ``def`` whose return type is an
        ``AsyncGenerator``. An ``async def`` generator in a subclass satisfies
        this signature.

        Args:
            prompt: A pre-built ``EngineCoreRequest``, a prompt in one of the
                accepted forms (``PromptType``, ``DictPrompt``,
                ``TokPrompt``), or an async generator of ``StreamingInput``
                items.
            sampling_params: Sampling parameters for the request.
            request_id: The unique id of the request.

        All parameters after ``request_id`` are keyword-only and optional.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        # NOTE: unlike ``generate``, these optional parameters are
        # positional-or-keyword. Keep their order for backward compatibility.
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Args:
            prompt: The prompt, in one of the accepted forms (``PromptType``,
                ``DictPrompt``, ``TokPrompt``).
            pooling_params: Pooling parameters for the request.
            request_id: The unique id of the request.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request, or an iterable of
                such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache.

        Args:
            reset_running_requests: Defaults to ``False``.
            reset_connector: When ``True``, also reset the configured
                connector cache. Defaults to ``False``.

        Returns:
            A ``bool``. NOTE(review): presumably ``True`` on success. Confirm
            this against the implementations.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine"""
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine"""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            A ``bool``. NOTE(review): presumably whether the adapter was
            added. Confirm this against the implementations.
        """
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        mode: "PauseMode" = "abort",
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  and return partial results with "abort" reason (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
            clear_cache: DEPRECATED. Whether to clear KV and prefix caches
                after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    # ------------------------------------------------------------------
    # Optional capabilities: not abstract, so subclasses may omit them.
    # The defaults raise NotImplementedError.
    # ------------------------------------------------------------------

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine to ``new_data_parallel_size``.

        Args:
            new_data_parallel_size: The target data-parallel size.
            drain_timeout: Defaults to ``300``. NOTE(review): the unit is
                presumably seconds. Confirm this against the implementations.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call of the given ``method``.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout. Defaults to ``None``.
            args: Positional arguments for the call. Defaults to ``()``.
            kwargs: Keyword arguments for the call. Defaults to ``None``.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the ``SupportedTask`` values supported by the engine.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError