protocol.py 4.52 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections.abc import AsyncGenerator, Iterable, Mapping
6
from typing import Any
7

8
from vllm.config import ModelConfig, VllmConfig
9
from vllm.inputs.data import PromptType
10
from vllm.logger import init_logger
11
from vllm.lora.request import LoRARequest
12
13
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams
16
from vllm.tasks import SupportedTask
17
from vllm.transformers_utils.tokenizer import AnyTokenizer
18
from vllm.utils import Device
19
from vllm.v1.engine import EngineCoreRequest
20
from vllm.v1.engine.processor import Processor
21

22
# Module-level logger created via vLLM's init_logger, keyed by this module's
# name (__name__).
logger = init_logger(__name__)
23

24
25

class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    This is an abstract interface. Every member decorated with
    ``@abstractmethod`` must be overridden before a subclass can be
    instantiated. ``scale_elastic_ep``, ``collective_rpc`` and
    ``get_supported_tasks`` are optional: their default bodies raise
    ``NotImplementedError``.
    """

    # Attributes implementations are expected to expose. They are only
    # declared (annotated) here; nothing in this class assigns them.
    vllm_config: VllmConfig
    model_config: ModelConfig
    processor: Processor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Report whether the engine is running (implementation-defined)."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Report whether the engine is stopped (implementation-defined)."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Report whether the engine has errored (implementation-defined)."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Return the exception associated with a dead engine.

        NOTE(review): presumably the error callers should raise once
        ``errored`` is true -- confirm against the concrete clients.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Declared as a plain ``def`` (not a coroutine): the call itself is
        annotated as returning the async generator, so callers iterate the
        result with ``async for`` rather than awaiting the call.

        Args:
            prompt: The request input, either an ``EngineCoreRequest`` or a
                ``PromptType``.
            sampling_params: Sampling parameters for this request.
            request_id: Identifier of the request.
            prompt_text: Optional prompt text (keyword-only).
            lora_request: Optional LoRA request to use (keyword-only).
            tokenization_kwargs: Optional keyword arguments for tokenization
                (keyword-only).
            trace_headers: Optional tracing headers (keyword-only).
            priority: Request priority, ``0`` by default (keyword-only).
                Ordering semantics are implementation-defined.
            data_parallel_rank: Optional data-parallel rank (keyword-only).
                Presumably pins the request to that rank -- confirm.

        Returns:
            An async generator yielding ``RequestOutput`` objects.
        """
        ...

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Like ``generate``, this is a plain ``def`` annotated as returning an
        async generator. Unlike ``generate``, the optional parameters are
        NOT keyword-only, so their positional order is part of the
        interface.

        Args:
            prompt: The request input.
            pooling_params: Pooling parameters for this request.
            request_id: Identifier of the request.
            lora_request: Optional LoRA request to use.
            trace_headers: Optional tracing headers.
            priority: Request priority, ``0`` by default. Ordering semantics
                are implementation-defined.
            tokenization_kwargs: Optional keyword arguments for
                tokenization.

        Returns:
            An async generator yielding ``PoolingRequestOutput`` objects.
        """
        ...

    @abstractmethod
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def get_tokenizer(self) -> AnyTokenizer:
        """Get the tokenizer."""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled (implementation-defined)."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics (what is logged is implementation-defined)."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy; return ``None`` otherwise."""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine."""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine."""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache."""
        ...

    @abstractmethod
    async def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Reset the prefix cache.

        Args:
            device: Optional device selector, ``None`` by default.
                Presumably restricts the reset to that device -- confirm.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine.

        Args:
            level: Sleep level, ``1`` by default. The meaning of each level
                is implementation-defined.
        """
        ...

    @abstractmethod
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake up the engine.

        Args:
            tags: Optional list of tags, ``None`` by default. Presumably
                selects what to wake up -- confirm against implementations.
        """
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping."""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Args:
            lora_request: Describes the adapter to load.

        Returns:
            A boolean reported by the implementation (presumably whether the
            adapter was loaded -- confirm).
        """
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine.

        Optional hook: not abstract, so subclasses may leave it out.

        Args:
            new_data_parallel_size: Target data-parallel size.
            drain_timeout: Drain timeout, ``300`` by default (units
                presumably seconds -- confirm).

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """Perform a collective RPC call to the given method.

        Optional hook: not abstract, so subclasses may leave it out. The
        return type is intentionally unannotated here.

        Args:
            method: Name of the method to invoke.
            timeout: Optional timeout, ``None`` by default.
            args: Positional arguments forwarded to the method.
            kwargs: Optional keyword arguments forwarded to the method.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks.

        Optional hook: not abstract, so subclasses may leave it out.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError