protocol.py 5.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
from abc import ABC, abstractmethod
6
from collections.abc import AsyncGenerator, Iterable, Mapping
7
from typing import Any
8

9
from vllm.config import ModelConfig, VllmConfig
10
from vllm.inputs.data import PromptType
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
14
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.tasks import SupportedTask
18
from vllm.transformers_utils.tokenizer import AnyTokenizer
19
from vllm.v1.engine import EngineCoreRequest
20
from vllm.v1.engine.processor import Processor
21

22
logger = init_logger(__name__)
23

24

25
26
27
28
29
class Device(enum.Enum):
    """Enumeration of device kinds: ``GPU`` and ``CPU``.

    The member values are spelled out explicitly; they are the same integers
    ``enum.auto()`` would assign (numbering starts at 1, in declaration order).
    """

    GPU = 1
    CPU = 2


30
class EngineClient(ABC):
    """Abstract base class describing the client-side interface to an engine.

    Concrete clients must implement every member marked ``@abstractmethod``.
    ``scale_elastic_ep``, ``collective_rpc`` and ``get_supported_tasks`` are
    optional: their default implementations raise ``NotImplementedError``.
    """

    # Attributes implementations are expected to provide. They are only
    # declared (annotated) here; this base class never assigns them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    processor: Processor
    io_processor: IOProcessor | None

    # ------------------------------------------------------------------
    # Engine state (read-only properties).
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def is_running(self) -> bool: ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool: ...

    @property
    @abstractmethod
    def errored(self) -> bool: ...

    # NOTE(review): presumably the exception describing why the engine can no
    # longer serve requests — confirm against implementations.
    @property
    @abstractmethod
    def dead_error(self) -> BaseException: ...

    # ------------------------------------------------------------------
    # Request submission and cancellation.
    # ------------------------------------------------------------------

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Declared with ``def`` rather than ``async def`` because the return
        value is itself an async generator; an implementation written as an
        ``async def`` function containing ``yield`` satisfies this signature.

        Args:
            prompt: Either an already-built ``EngineCoreRequest`` or a
                ``PromptType`` input.
            sampling_params: Sampling parameters for the request.
            request_id: Identifier of the request; the same kind of id is
                accepted by ``abort``.
            prompt_text, lora_request, tokenization_kwargs, trace_headers,
            data_parallel_rank: Keyword-only options, all ``None`` by default.
            priority: Keyword-only request priority, ``0`` by default.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Unlike ``generate``, the optional parameters here are NOT keyword-only,
        so their order is part of the interface and must not change.

        Args:
            prompt: The ``PromptType`` input.
            pooling_params: Pooling parameters for the request.
            request_id: Identifier of the request.
            lora_request, trace_headers, truncate_prompt_tokens,
            tokenization_kwargs: Optional, ``None`` by default.
            priority: Request priority, ``0`` by default.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    # ------------------------------------------------------------------
    # Introspection, health and diagnostics.
    # ------------------------------------------------------------------

    @abstractmethod
    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer."""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool: ...

    @abstractmethod
    async def do_log_stats(self) -> None: ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy; return ``None`` otherwise."""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine."""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine."""
        ...

    # ------------------------------------------------------------------
    # Cache management.
    # ------------------------------------------------------------------

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache."""
        ...

    @abstractmethod
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache."""
        ...

    # ------------------------------------------------------------------
    # Sleep / wake-up.
    # ------------------------------------------------------------------

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine.

        Args:
            level: Sleep level, ``1`` by default. NOTE(review): the meaning of
                each level is defined by implementations — confirm there.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags, ``None`` by default. NOTE(review):
                presumably restricts what is woken up — confirm against
                implementations.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping."""
        ...

    # ------------------------------------------------------------------
    # LoRA adapters.
    # ------------------------------------------------------------------

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    # ------------------------------------------------------------------
    # Pause / resume of request intake.
    # ------------------------------------------------------------------

    @abstractmethod
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """Pause new generation/encoding requests.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight requests
                to finish before pausing. When ``False`` (default), aborts in-flight
                requests immediately.
            clear_cache: Whether to clear KV and prefix caches after draining.
        """
        ...

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume accepting generation/encoding requests."""
        ...

    @abstractmethod
    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        ...

    # ------------------------------------------------------------------
    # Optional capabilities: not abstract; the defaults raise
    # ``NotImplementedError`` so subclasses may opt in.
    # ------------------------------------------------------------------

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Timeout for draining, ``300`` by default.
                NOTE(review): unit not stated here — presumably seconds.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given path.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout, ``None`` by default.
            args: Positional arguments for the call, empty by default.
            kwargs: Keyword arguments for the call, ``None`` by default.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError