protocol.py 6.43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import TYPE_CHECKING, Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
10
11
12
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
13
from vllm.inputs.data import PromptType, StreamingInput
14
from vllm.lora.request import LoRARequest
15
16
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
17
from vllm.pooling_params import PoolingParams
18
from vllm.renderers import BaseRenderer
19
from vllm.renderers.inputs import DictPrompt, TokPrompt
20
from vllm.sampling_params import SamplingParams
21
from vllm.tasks import SupportedTask
22
from vllm.v1.engine import EngineCoreRequest
23
from vllm.v1.engine.input_processor import InputProcessor
24

25
26
27
if TYPE_CHECKING:
    from vllm.v1.engine import PauseMode

28

29
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Concrete clients must implement every ``@abstractmethod`` member below.
    The remaining methods (``scale_elastic_ep``, ``collective_rpc``,
    ``get_supported_tasks``, ``init_weight_transfer_engine`` and
    ``update_weights``) are optional capabilities whose default
    implementations raise ``NotImplementedError``.
    """

    # Attributes that concrete clients are expected to provide. They are only
    # annotated here (never assigned), so subclasses must set them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    renderer: BaseRenderer
    io_processor: IOProcessor | None
    input_processor: InputProcessor

    # ------------------------------------------------------------------
    # Engine state, exposed as read-only abstract properties.
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def is_running(self) -> bool: ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool: ...

    @property
    @abstractmethod
    def errored(self) -> bool: ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException: ...

    # ------------------------------------------------------------------
    # Request submission / cancellation.
    # ------------------------------------------------------------------

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Args:
            prompt: The request input. May be an already-built
                ``EngineCoreRequest``, a prompt (``PromptType``,
                ``DictPrompt`` or ``TokPrompt``), or an async generator of
                ``StreamingInput`` items.
            sampling_params: Sampling configuration for the request.
            request_id: The unique id of the request.
            prompt_text: Optional prompt text (keyword-only).
            lora_request: Optional LoRA adapter to apply (keyword-only).
            tokenization_kwargs: Optional extra tokenization arguments.
            trace_headers: Optional tracing headers.
            priority: Request priority, default ``0``.
            data_parallel_rank: Optional data-parallel rank to target.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        # NOTE: unlike ``generate``, these parameters are not keyword-only;
        # they are kept positional so existing callers keep working.
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Args:
            prompt: The request input prompt.
            pooling_params: Pooling configuration for the request.
            request_id: The unique id of the request.
            lora_request: Optional LoRA adapter to apply.
            trace_headers: Optional tracing headers.
            priority: Request priority, default ``0``.
            tokenization_kwargs: Optional extra tokenization arguments.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    # ------------------------------------------------------------------
    # Observability and health.
    # ------------------------------------------------------------------

    @abstractmethod
    async def is_tracing_enabled(self) -> bool: ...

    @abstractmethod
    async def do_log_stats(self) -> None: ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    # ------------------------------------------------------------------
    # Cache management.
    # ------------------------------------------------------------------

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache.

        Args:
            reset_running_requests: Flag forwarded to the implementation;
                defaults to ``False``.
            reset_connector: Whether to also reset the configured connector
                cache; defaults to ``False``.

        Returns:
            A ``bool``; presumably whether the reset succeeded (TODO confirm
            against implementations).
        """
        ...

    # ------------------------------------------------------------------
    # Sleep / wake lifecycle.
    # ------------------------------------------------------------------

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine"""
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine"""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    # ------------------------------------------------------------------
    # Pause / resume.
    # ------------------------------------------------------------------

    @abstractmethod
    async def pause_generation(
        self,
        *,
        mode: "PauseMode" = "abort",
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  and return partial results with "abort" reason (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
            clear_cache: DEPRECATED. Whether to clear KV and prefix caches
                after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    # ------------------------------------------------------------------
    # Optional capabilities: default implementations raise
    # NotImplementedError so clients only override what they support.
    # ------------------------------------------------------------------

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Args:
            new_data_parallel_size: The target data-parallel size.
            drain_timeout: Timeout for draining, default ``300``
                (presumably seconds — TODO confirm).

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Perform a collective RPC call of the given ``method``.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout for the call.
            args: Positional arguments forwarded to ``method``.
            kwargs: Keyword arguments forwarded to ``method``; ``None`` means
                no keyword arguments.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError