# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any, Optional, Union

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)


class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Subclasses must implement every ``@abstractmethod`` below. ``beam_search``
    is implemented here on top of ``generate`` and ``get_input_preprocessor``;
    the remaining non-abstract hooks (``get_io_processor``,
    ``scale_elastic_ep``, ``collective_rpc``, ``get_supported_tasks``) raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the engine is running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Whether the engine has stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Whether the engine is in an errored state."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Exception describing why the engine is no longer usable."""
        ...

    @abstractmethod
    def generate(
        self,
        prompt: Union[EngineCoreRequest, PromptType],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for ``prompt`` on top of :meth:`generate`.

        Each step sends one single-token ``generate`` sub-request per live
        beam, asking for ``2 * beam_width`` logprobs. Every beam is extended
        with each returned candidate token. Candidates equal to the EOS token
        finish their beam, unless ``params.ignore_eos`` is set. The remaining
        candidates are ranked with ``create_sort_beams_key_function`` and
        truncated to ``beam_width``.

        After at most ``params.max_tokens`` steps, finished and live beams are
        ranked together. A single finished ``RequestOutput`` holding the best
        ``beam_width`` completions is then yielded.

        Args:
            prompt: The prompt to complete.
            request_id: Id reported on the yielded ``RequestOutput``. The
                per-step sub-requests sent to ``generate`` use their own
                freshly generated ids.
            params: Beam search parameters (width, max tokens, temperature,
                length penalty, EOS handling).
            lora_request: Optional LoRA adapter used by every sub-request.

        Raises:
            NotImplementedError: For explicit encoder/decoder prompts and for
                prompts that preprocess to embeddings.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        preprocessor = await self.get_input_preprocessor()
        tokenizer = preprocessor.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError
        processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

        if processed_inputs["type"] == "embeds":
            raise NotImplementedError

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.
        if isinstance(prompt, str):
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")
            prompt_token_ids = prompt.get("prompt_token_ids", [])
            multi_modal_data = prompt.get("multi_modal_data")

        # A text prompt without explicit token ids would otherwise start
        # every beam from an empty context. With no multimodal data there are
        # no placeholders to double-expand, so the preprocessor's own
        # tokenization can be used directly.
        if not prompt_token_ids and multi_modal_data is None:
            prompt_token_ids = processed_inputs["prompt_token_ids"]

        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

        # Number of leading prompt tokens shared by every beam; everything
        # after this offset is generated output.
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Each sub-request generates exactly one token. It asks for
        # 2 * beam_width logprobs so there are spare candidates beyond
        # beam_width, e.g. when one of them is EOS.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            # Every live beam finished with EOS on the previous step, so there
            # is nothing left to expand. Unpacking an empty zip below would
            # raise ValueError.
            if not all_beams:
                break

            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []

            # Use a dedicated id for this step's sub-requests. Reassigning the
            # `request_id` parameter here would make the final output report
            # a generated id instead of the caller's.
            batch_request_id = f"beam_search-{random_uuid()}"
            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{batch_request_id}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                        )
                    )
                )
                tasks.append(task)

            output = await asyncio.gather(*tasks)

            # Only the first collected RequestOutput of each sub-request is
            # used (each sub-request asks for a single token).
            output = [x[0] for x in output]

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    # Logprobs of the single generated position:
                    # {candidate token id: Logprob}.
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        if token_id == eos_token_id and not ignore_eos:
                            # EOS finishes this continuation; it moves to
                            # `completed` and is not expanded further.
                            completed.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id]
                                    if include_stop_str_in_output
                                    else current_beam.tokens,
                                    logprobs=current_beam.logprobs + [logprobs],
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    finish_reason="stop",
                                    stop_reason=eos_token_id,
                                )
                            )
                        else:
                            new_beams.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )
                            )

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        # Beams still live after max_tokens steps compete with finished ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            generated = beam.tokens[tokenized_length:]
            # Only strip a trailing EOS that was actually generated. This
            # also avoids indexing an empty token list.
            if generated and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                generated = generated[:-1]
            beam.text = tokenizer.decode(generated)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    # Beams that never hit EOS ran out of max_tokens.
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model."""
        ...

    @abstractmethod
    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def get_vllm_config(self) -> VllmConfig:
        """Get the vllm configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input processor of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(self) -> AnyTokenizer:
        """Get the tokenizer"""
        ...

    async def get_io_processor(self) -> IOProcessor:
        """Get the IO processor plugin; not implemented by default."""
        raise NotImplementedError

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Whether request tracing is enabled."""
        ...

    @abstractmethod
    async def do_log_stats(self) -> None:
        """Log engine statistics."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:
        """Reset the prefix cache.

        ``device`` presumably selects which device's cache to reset; the
        meaning is implementation-defined.
        """
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine; ``level`` is interpreted by the implementation."""
        ...

    @abstractmethod
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake up the engine; ``tags`` is interpreted by the implementation."""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """Scale the engine; not implemented by default."""
        raise NotImplementedError

    async def collective_rpc(
        self,
        method: str,
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ):
        """Perform a collective RPC call to the given path.

        Not implemented by default.
        """
        raise NotImplementedError

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Get supported tasks; not implemented by default."""
        raise NotImplementedError