protocol.py 5.95 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
10
11
12
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
13
from vllm.inputs.data import PromptType, StreamingInput
14
from vllm.lora.request import LoRARequest
15
16
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
17
from vllm.pooling_params import PoolingParams
18
from vllm.renderers import BaseRenderer
19
from vllm.sampling_params import SamplingParams
20
from vllm.tasks import SupportedTask
21
from vllm.v1.engine import EngineCoreRequest
22
from vllm.v1.engine.input_processor import InputProcessor
23

24

25
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Concrete clients must implement every property and method marked with
    ``@abstractmethod``. The coroutines at the end of the class
    (``scale_elastic_ep``, ``collective_rpc``, ``get_supported_tasks``,
    ``init_weight_transfer_engine``, ``update_weights``) are optional
    capabilities: their default implementations raise ``NotImplementedError``.
    """

    # Annotation-only attributes: no values are assigned here, so concrete
    # clients are responsible for setting them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    input_processor: InputProcessor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def renderer(self) -> BaseRenderer:
        """Return the renderer used by this client."""
        ...

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Return whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Return whether the engine is stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Return whether the engine has errored."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return the exception describing the engine's failure.

        NOTE(review): presumably only meaningful when ``errored`` is true;
        confirm against concrete implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Declared as a plain ``def`` (not ``async def``) because the result is
        itself an ``AsyncGenerator``. Implementations may be written as
        ``async def`` generator functions.

        Args:
            prompt: An ``EngineCoreRequest``, a ``PromptType``, or an async
                generator of ``StreamingInput`` items.
            sampling_params: Sampling parameters for the request.
            request_id: Identifier of the request; the same kind of id that
                ``abort`` accepts.
            prompt_text: Optional prompt text.
            lora_request: Optional LoRA adapter request.
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            trace_headers: Optional tracing headers.
            priority: Request priority (default ``0``).
            data_parallel_rank: Optional data-parallel rank for the request.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Args:
            prompt: The prompt to encode.
            pooling_params: Pooling parameters for the request.
            request_id: Identifier of the request; the same kind of id that
                ``abort`` accepts.
            lora_request: Optional LoRA adapter request.
            trace_headers: Optional tracing headers.
            priority: Request priority (default ``0``).
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy."""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine."""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine."""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache."""
        ...

    @abstractmethod
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache."""
        ...

    @abstractmethod
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache and optionally any configured connector cache.

        Args:
            reset_running_requests: Flag forwarded to the implementation
                (default ``False``). NOTE(review): exact semantics for
                in-flight requests are implementation-defined; confirm.
            reset_connector: Whether to also reset the configured connector
                cache (default ``False``).

        Returns:
            A boolean result reported by the implementation. NOTE(review):
            presumably whether the reset succeeded; confirm.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine.

        Args:
            level: Sleep level (default ``1``).
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags. NOTE(review): ``None`` presumably
                wakes everything; confirm against implementations.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping."""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight requests
                to finish before pausing. When ``False`` (default), aborts in-flight
                requests immediately.
            clear_cache: Whether to clear KV and prefix caches after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    # ------------------------------------------------------------------
    # Optional capabilities: not abstract; the defaults raise
    # NotImplementedError so clients only override what they support.
    # ------------------------------------------------------------------

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Timeout for draining (default ``300``).
                NOTE(review): units presumably seconds; confirm.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given path.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout for the call.
            args: Positional arguments for the call.
            kwargs: Optional keyword arguments for the call.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def init_weight_transfer_engine(
        self, init_request: WeightTransferInitRequest
    ) -> None:
        """Initialize weight transfer for RL training.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """Batched weight update for RL training.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError