protocol.py 4.63 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
from abc import ABC, abstractmethod
6
from collections.abc import AsyncGenerator, Iterable, Mapping
7
from typing import Any
8

9
from vllm.config import ModelConfig, VllmConfig
10
from vllm.inputs.data import PromptType
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
14
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.tasks import SupportedTask
18
from vllm.transformers_utils.tokenizer import AnyTokenizer
19
from vllm.v1.engine import EngineCoreRequest
20
from vllm.v1.engine.processor import Processor
21

22
logger = init_logger(__name__)
23

24

25
26
27
28
29
class Device(enum.Enum):
    """Hardware device selector.

    Used in this module as the optional ``device`` argument of
    ``EngineClient.reset_prefix_cache``.
    """

    # Explicit values are exactly what ``enum.auto()`` assigns to the
    # members of a plain ``Enum`` (starting at 1, in declaration order).
    GPU = 1
    CPU = 2


30
class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Abstract interface shared by engine client implementations. Every
    method decorated with ``@abstractmethod`` must be overridden by a
    concrete subclass. The remaining coroutines (``scale_elastic_ep``,
    ``collective_rpc`` and ``get_supported_tasks``) are optional: their
    default implementations raise ``NotImplementedError``.
    """

    # Annotation-only declarations: no values are assigned here, so concrete
    # clients are expected to set these attributes themselves.
    vllm_config: VllmConfig
    model_config: ModelConfig
    processor: Processor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Report whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Report whether the engine is stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Report whether the engine is in an errored state."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return the exception describing the engine's dead/errored state."""
        ...

    # NOTE(review): ``generate``/``encode`` are declared with plain ``def`` but
    # typed as returning an ``AsyncGenerator``; presumably implementations are
    # ``async def`` generator functions — confirm against subclasses.
    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Args:
            prompt: Either a pre-built ``EngineCoreRequest`` or a
                ``PromptType`` prompt.
            sampling_params: Sampling parameters for this request.
            request_id: Identifier of the request.
            prompt_text: Optional prompt text (keyword-only, as are all
                following arguments).
            lora_request: Optional LoRA request to apply.
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.
            trace_headers: Optional tracing headers.
            priority: Request priority; defaults to ``0``.
            data_parallel_rank: Optional data-parallel rank (presumably the
                rank the request should be routed to — confirm).

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Unlike ``generate``, every parameter here may also be passed
        positionally (there is no keyword-only marker).

        Args:
            prompt: The prompt to encode.
            pooling_params: Pooling parameters for this request.
            request_id: Identifier of the request.
            lora_request: Optional LoRA request to apply.
            trace_headers: Optional tracing headers.
            priority: Request priority; defaults to ``0``.
            truncate_prompt_tokens: Optional prompt-token truncation limit.
            tokenization_kwargs: Optional extra keyword arguments for
                tokenization.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def get_tokenizer(self) -> AnyTokenizer:
        """Get the tokenizer."""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy; return ``None`` otherwise."""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine."""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine."""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache."""
        ...

    @abstractmethod
    async def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Reset the prefix cache.

        Args:
            device: Optional target ``Device``. How ``None`` (the default)
                is interpreted is left to the implementation.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine.

        Args:
            level: Sleep level; defaults to ``1``.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags; their meaning is defined by the
                implementation.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping."""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            A ``bool`` (presumably whether the adapter was loaded — confirm
            against implementations).
        """
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Optional: the default implementation raises ``NotImplementedError``.

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Drain timeout; defaults to ``300`` (units are not
                stated here — presumably seconds, confirm).

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given path.

        Optional: the default implementation raises ``NotImplementedError``.

        Args:
            method: Name/path of the method to invoke.
            timeout: Optional timeout for the call.
            args: Positional arguments forwarded to ``method``.
            kwargs: Optional keyword arguments forwarded to ``method``.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Optional: the default implementation raises ``NotImplementedError``.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError