# protocol.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
10
11
12
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
13
from vllm.inputs.data import PromptType, StreamingInput
14
from vllm.lora.request import LoRARequest
15
16
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
17
from vllm.pooling_params import PoolingParams
18
from vllm.renderers import BaseRenderer
19
from vllm.renderers.inputs import DictPrompt, TokPrompt
20
from vllm.sampling_params import SamplingParams
21
from vllm.tasks import SupportedTask
22
from vllm.v1.engine import EngineCoreRequest
23
from vllm.v1.engine.input_processor import InputProcessor
24

25

26
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Every member decorated with ``@abstractmethod`` must be overridden by a
    concrete client before it can be instantiated. The remaining methods
    (``scale_elastic_ep``, ``collective_rpc``, ``get_supported_tasks``,
    ``init_weight_transfer_engine`` and ``update_weights``) are optional
    extensions whose default implementation raises ``NotImplementedError``.
    """

    # Annotated but not assigned here; concrete clients are expected to
    # provide these attributes.
    vllm_config: VllmConfig
    model_config: ModelConfig
    input_processor: InputProcessor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def renderer(self) -> BaseRenderer:
        """Return the ``BaseRenderer`` exposed by this client."""
        ...

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Return whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Return whether the engine is stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Return whether the engine has errored."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return the exception object associated with a dead engine.

        NOTE(review): how callers use this exception (raise vs. inspect) is
        not visible in this interface -- confirm against implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This is declared with a plain ``def`` and an ``AsyncGenerator`` return
        annotation, so the result is consumed with ``async for`` rather than
        by awaiting the call.

        Args:
            prompt: The request input: an ``EngineCoreRequest``, a prompt in
                one of the accepted prompt forms (``PromptType``,
                ``DictPrompt``, ``TokPrompt``), or an async generator of
                ``StreamingInput`` items.
            sampling_params: Sampling parameters for the request.
            request_id: Identifier of the request.
            prompt_text: Optional prompt text (keyword-only).
            lora_request: Optional LoRA request to use (keyword-only).
            tokenization_kwargs: Optional tokenization keyword arguments
                (keyword-only); presumably forwarded to the tokenizer --
                confirm against implementations.
            trace_headers: Optional tracing headers (keyword-only).
            priority: Priority of the request (keyword-only, default ``0``).
            data_parallel_rank: Optional data-parallel rank (keyword-only).

        Returns:
            An async generator of ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Unlike ``generate``, the optional parameters here are not
        keyword-only and may be passed positionally.

        Args:
            prompt: The request input in one of the accepted prompt forms.
            pooling_params: Pooling parameters for the request.
            request_id: Identifier of the request.
            lora_request: Optional LoRA request to use.
            trace_headers: Optional tracing headers.
            priority: Priority of the request (default ``0``).
            tokenization_kwargs: Optional tokenization keyword arguments.

        Returns:
            An async generator of ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache

        Args:
            reset_running_requests: Defaults to ``False``.
                NOTE(review): presumably also resets state held by running
                requests -- confirm against implementations.
            reset_connector: Defaults to ``False``; when set, the configured
                connector cache is reset as well.

        Returns:
            A bool; presumably whether the reset succeeded -- TODO confirm.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine

        Args:
            level: Sleep level (default ``1``); the meaning of each level is
                defined by the implementation.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine

        Args:
            tags: Optional list of tags (default ``None``); their meaning is
                defined by the implementation.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Args:
            lora_request: Description of the LoRA adapter to load.

        Returns:
            A bool; presumably whether the adapter was loaded -- TODO confirm.
        """
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight requests
                to finish before pausing. When ``False`` (default), aborts in-flight
                requests immediately.
            clear_cache: Whether to clear KV and prefix caches after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    # The methods below are intentionally NOT abstract: a client that does not
    # support them simply inherits the NotImplementedError default.

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Drain timeout (default ``300``); presumably in
                seconds -- TODO confirm.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given method.

        Args:
            method: Name of the method to call.
            timeout: Optional timeout (default ``None``).
            args: Positional arguments for the call.
            kwargs: Optional keyword arguments for the call.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Args:
            init_request: The weight-transfer initialization request.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Args:
            request: The weight-transfer update request.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError