protocol.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
from abc import ABC, abstractmethod
6
from typing import AsyncGenerator, Iterable, Mapping, Optional, Union
7

8
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
9
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
10
from vllm.core.scheduler import SchedulerOutputs
11
from vllm.inputs.data import PromptType, TokensPrompt
12
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
13
from vllm.inputs.preprocess import InputPreprocessor
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
16
from vllm.model_executor.layers.sampler import SamplerOutput
17
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
18
from vllm.pooling_params import PoolingParams
19
from vllm.sampling_params import BeamSearchParams, SamplingParams
20
from vllm.transformers_utils.tokenizer import AnyTokenizer
21
from vllm.utils import Device, collect_from_async_generator, random_uuid
22

23
logger = init_logger(__name__)
24

25
26

class EngineClient(ABC):
    """Protocol class for Clients to Engine.

    Concrete clients implement the abstract properties and methods below.
    `beam_search` is the only non-trivial method implemented here; it is
    built on top of the abstract `generate` and `get_input_preprocessor`
    methods.
    """

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Abstract flag; concrete clients report whether they are running."""
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        """Abstract flag; concrete clients report whether they are stopped."""
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        """Abstract flag; concrete clients report whether they have errored."""
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        """Abstract property yielding an exception object.

        NOTE(review): presumably the error surfaced once the engine is
        dead -- confirm against the concrete implementations.
        """
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search on top of `generate` and yield one final output.

        Each step issues one single-token `generate` request per live beam
        (asking for ``2 * beam_width`` logprobs), expands every beam with
        the returned candidate tokens, moves candidates that produced the
        EOS token into the completed set (unless ``params.ignore_eos``),
        and keeps the ``beam_width`` best remaining candidates according
        to the key built by `create_sort_beams_key_function`.

        Args:
            prompt: The prompt to extend. Explicit encoder/decoder prompts
                and prompts that preprocess to embeddings are rejected.
            request_id: Id reported on the yielded `RequestOutput`. The
                per-step `generate` calls use separate internal ids.
            params: Beam search settings (width, max tokens, temperature,
                length penalty, EOS handling).
            lora_request: Optional LoRA request forwarded to every
                `generate` call.

        Yields:
            A single finished `RequestOutput` holding up to
            ``params.beam_width`` completions, best first.

        Raises:
            NotImplementedError: For explicit encoder/decoder prompts or
                prompts whose preprocessed type is ``"embeds"``.
            ValueError: If no prompt token ids can be resolved.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        preprocessor = await self.get_input_preprocessor()
        tokenizer_group = preprocessor.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async()
        eos_token_id = tokenizer.eos_token_id

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError
        processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

        if processed_inputs["type"] == "embeds":
            raise NotImplementedError

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.
        #
        # A plain `str` prompt has no `.get`, so only read the raw fields
        # when the prompt is a dict.
        raw_prompt = prompt if isinstance(prompt, dict) else {}
        prompt_token_ids = raw_prompt.get("prompt_token_ids")
        multi_modal_data = raw_prompt.get("multi_modal_data")
        if prompt_token_ids is None:
            # Text prompts carry no token ids of their own; fall back to
            # the preprocessed ones instead of failing on `len(None)`.
            # NOTE(review): for text prompts with multimodal data these ids
            # may already contain expanded placeholders (see problem 2).
            prompt_token_ids = processed_inputs.get("prompt_token_ids")
        if prompt_token_ids is None:
            raise ValueError(
                "beam_search requires a prompt that resolves to token ids")

        prompt_text = processed_inputs.get("prompt")
        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            eos_token_id, length_penalty)

        # One token per step; 2 * beam_width logprobs give every beam
        # enough candidate continuations to choose from.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs,
                               lora_request=lora_request)
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Nothing left to extend; avoids unpacking an empty batch.
                break

            # Use a dedicated name for the internal batch id so that the
            # caller's `request_id` is still intact for the final output.
            batch_request_id = f"beam_search-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(
                            TokensPrompt(
                                prompt_token_ids=beam.tokens,
                                multi_modal_data=beam.multi_modal_data,
                                mm_processor_kwargs=beam.mm_processor_kwargs),
                            beam_search_params,
                            f"{batch_request_id}-{i}",
                            lora_request=beam.lora_request)))
                for i, beam in enumerate(all_beams)
            ]

            gathered = await asyncio.gather(*tasks)
            # Each task collects every output of one request; only the
            # first one is used.
            step_results = [outputs[0] for outputs in gathered]

            new_beams = []
            for current_beam, result in zip(all_beams, step_results):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue

                logprobs = step_logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    cum_logprob = (current_beam.cum_logprob +
                                   logprob_obj.logprob)
                    if token_id == eos_token_id and not ignore_eos:
                        if include_stop_str_in_output:
                            finished_tokens = current_beam.tokens + [token_id]
                        else:
                            finished_tokens = current_beam.tokens
                        completed.append(
                            BeamSearchSequence(
                                tokens=finished_tokens,
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=cum_logprob,
                                finish_reason="stop",
                                stop_reason=eos_token_id))
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=cum_logprob,
                                multi_modal_data=current_beam.
                                multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs))

            all_beams = sorted(new_beams, key=sort_beams_key,
                               reverse=True)[:beam_width]

        # Beams still alive after the last step compete with the ones that
        # already finished on EOS.
        completed.extend(all_beams)
        best_beams = sorted(completed, key=sort_beams_key,
                            reverse=True)[:beam_width]

        for beam in best_beams:
            if (beam.tokens and beam.tokens[-1] == eos_token_id
                    and not ignore_eos):
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output

    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model."""
        ...

    @abstractmethod
    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request,
                        or an iterable of such ids.
        """
        ...

    @abstractmethod
    async def get_vllm_config(self) -> VllmConfig:
        """Get the vllm configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input processor of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        """Abstract check; concrete clients report whether tracing is on."""
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        """Abstract hook for logging stats; both arguments are optional."""
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine"""
        ...

    @abstractmethod
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake up the engine"""
        ...

    @abstractmethod
    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        ...

    async def scale_elastic_ep(self,
                               new_data_parallel_size: int,
                               drain_timeout: int = 300) -> None:
        """Scale the engine.

        Not abstract: the default raises `NotImplementedError`, so clients
        that do not override it reject the call.
        """
        raise NotImplementedError

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """Perform a collective RPC call to the given method.

        Not abstract: the default raises `NotImplementedError`, so clients
        that do not override it reject the call.
        """
        raise NotImplementedError