# protocol.py 5.47 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
from vllm.inputs.data import PromptType, StreamingInput
10
from vllm.lora.request import LoRARequest
11
12
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
13
from vllm.pooling_params import PoolingParams
14
from vllm.renderers import BaseRenderer
15
from vllm.sampling_params import SamplingParams
16
from vllm.tasks import SupportedTask
17
from vllm.v1.engine import EngineCoreRequest
18
from vllm.v1.engine.input_processor import InputProcessor
19

20

21
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Abstract interface that engine clients implement. Every member decorated
    with ``@abstractmethod`` must be overridden before a subclass can be
    instantiated. ``scale_elastic_ep``, ``collective_rpc`` and
    ``get_supported_tasks`` are optional hooks: they are not abstract and
    their default implementations raise ``NotImplementedError``.
    """

    # Attributes declared on the interface (annotation only, no default);
    # concrete clients are expected to set them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    input_processor: InputProcessor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def renderer(self) -> BaseRenderer:
        """Return the ``BaseRenderer`` associated with this client."""
        ...

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Return the client's "running" flag (semantics set by implementations)."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Return the client's "stopped" flag (semantics set by implementations)."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Return the client's "errored" flag (semantics set by implementations)."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return an exception instance describing a dead engine.

        NOTE(review): presumably the error callers should raise once
        ``errored`` is true -- confirm against implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Declared as a plain ``def`` whose return type is an async generator
        of ``RequestOutput``.

        Args:
            prompt: The input; an ``EngineCoreRequest``, a ``PromptType``, or
                an async generator of ``StreamingInput`` items.
            sampling_params: Sampling parameters for the request.
            request_id: The unique id of the request.
            prompt_text: Optional prompt text (keyword-only).
            lora_request: Optional LoRA request (keyword-only).
            tokenization_kwargs: Optional tokenization keyword arguments
                (keyword-only).
            trace_headers: Optional tracing headers (keyword-only).
            priority: Request priority, ``0`` by default (keyword-only).
            data_parallel_rank: Optional data-parallel rank (keyword-only).

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Args:
            prompt: The input prompt.
            pooling_params: Pooling parameters for the request.
            request_id: The unique id of the request.
            lora_request: Optional LoRA request.
            trace_headers: Optional tracing headers.
            priority: Request priority, ``0`` by default.
            tokenization_kwargs: Optional tokenization keyword arguments.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics (what is logged is up to implementations)."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache

        Args:
            reset_running_requests: Flag forwarded to implementations;
                defaults to ``False``. NOTE(review): exact effect on running
                requests is not visible here -- confirm.
            reset_connector: Whether to also reset the configured connector
                cache; defaults to ``False``.

        Returns:
            A boolean result. NOTE(review): presumably whether the reset
            succeeded -- confirm against implementations.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine

        Args:
            level: Sleep level, ``1`` by default (meaning of each level is
                defined by implementations).
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine

        Args:
            tags: Optional list of tags, ``None`` by default (interpretation
                is defined by implementations).
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight requests
                to finish before pausing. When ``False`` (default), aborts in-flight
                requests immediately.
            clear_cache: Whether to clear KV and prefix caches after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine

        Optional hook: not abstract, so clients that do not support scaling
        inherit this default.

        Args:
            new_data_parallel_size: The data-parallel size to scale to.
            drain_timeout: Drain timeout, ``300`` by default.
                NOTE(review): unit is presumably seconds -- confirm.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Perform a collective RPC call of the given method.

        Optional hook: not abstract, so clients without RPC support inherit
        this default.

        Args:
            method: Name of the method to call.
            timeout: Optional timeout, ``None`` by default.
            args: Positional arguments for the call.
            kwargs: Optional keyword arguments for the call.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks

        Optional hook: not abstract.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError