# serving_completion.py (22.9 KB)
# (web view navigation: Newer / Older)
# SPDX-License-Identifier: Apache-2.0

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast

import jinja2
from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators

# Module-level logger for this file.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Serve the OpenAI-compatible Completions API on top of a vLLM engine.

    Each prompt of a request becomes its own engine request. The results
    are then either streamed back as Server-Sent-Event ``data:`` lines or
    gathered into a single :class:`CompletionResponse`.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
        # Sampling defaults supplied by the model config. They are passed to
        # every request's to_sampling_params / to_beam_search_params below.
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info("Using default completion sampling params from %s: %s",
                        source, self.default_sampling_params)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying FastAPI request, if any. It is used
                to derive the request id and trace headers, and it receives
                the response metadata on ``raw_request.state``.

        Returns:
            One of:
            - an async generator of SSE ``data:`` lines when the response
              is streamed (or when streaming was requested but had to be
              emulated with a single event);
            - a :class:`CompletionResponse` otherwise;
            - an :class:`ErrorResponse` for model-check, validation, or
              preprocessing failures.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request, prompt_adapter_request = self._maybe_get_adapters(
                request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Without an explicit max_tokens, allow generation up to the
                # model's context length minus the prompt length.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params)

                # Each prompt gets its own engine request id. The same id is
                # logged here and handed to the engine in both branches
                # below, so logs and engine requests line up.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id_item,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Yields (prompt_index, RequestOutput) pairs from all prompts.
        result_generator = merge_async_iterators(*generators)

        model_name = self._get_model_name(request.model, lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Noting that best_of is only supported in V0. In addition,
        # we do not stream the results when use beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response
        final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = request_prompts[i]["prompt"]

            final_res_batch_checked = cast(list[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` lines for a streamed completion.

        ``result_generator`` yields ``(prompt_idx, RequestOutput)`` pairs
        interleaved across prompts. In the emitted chunks, choice ``k`` of
        prompt ``p`` is reported at flat index ``k + p * n``.

        If ``stream_options.include_usage`` is set, a final usage-only chunk
        is sent. If ``continuous_usage_stats`` is also set, each chunk
        carries running usage. Aggregate usage is stored on
        ``request_metadata.final_usage_info``. An ``Exception`` raised while
        iterating is reported as a streaming error event. Both normal
        completion and that error path end with ``data: [DONE]``.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice bookkeeping, indexed by the flat choice index.
        # previous_text_lens: characters already streamed for the choice
        # (including an echoed prompt). It becomes the text offset of the
        # next chunk's first token.
        previous_text_lens = [0] * num_choices * num_prompts
        # previous_num_tokens: completion tokens generated so far (usage).
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = (include_usage and
                                        stream_options.continuous_usage_stats)
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            # Prompt logprobs are only present when logprobs
                            # were requested. Mirror the non-streaming path
                            # instead of asserting on them unconditionally.
                            if request.logprobs is None:
                                out_logprobs = None
                            else:
                                assert prompt_logprobs is not None
                                out_logprobs = [
                                    *prompt_logprobs,
                                    *(output.logprobs or []),
                                ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        if (not delta_text and not delta_token_ids
                                and not previous_num_tokens[i]):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.
                            return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by what was actually sent in this chunk. For
                    # an echo chunk that includes the prompt text, so later
                    # text_offset values stay relative to the streamed text.
                    previous_text_lens[i] += len(delta_text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble a non-streaming :class:`CompletionResponse`.

        Choices are numbered consecutively across all prompts in batch
        order. With ``echo``, the prompt text, token ids and (when logprobs
        were requested) prompt logprobs are prepended to each choice. Usage
        is summed over every prompt and every generated choice and is also
        stored on ``request_metadata.final_usage_info``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Only the prompt is returned.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choices.append(
                    CompletionResponseChoice(
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason,
                        prompt_logprobs=final_res.prompt_logprobs,
                    ))

                # Usage counts generated tokens only, never echoed prompts.
                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
        return_as_token_id: Optional[bool] = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to describe, in order.
            top_logprobs: One entry per token in ``token_ids``. ``None``
                means no logprob information for that position.
            num_output_top_logprobs: Requested ``logprobs`` count. Up to
                ``num_output_top_logprobs + 1`` alternatives are reported
                per position.
            tokenizer: Used to decode tokens that have no logprob entry.
            initial_text_offset: Character offset of the first token.
            return_as_token_id: Render tokens as ``token_id:<id>``. ``None``
                falls back to the server-wide setting.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[Optional[float]] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[Optional[dict[str, float]]] = []

        should_return_as_token_id = (return_as_token_id
                                     if return_as_token_id is not None else
                                     self.return_tokens_as_token_ids)

        # Running character offset of the current token in the text.
        text_offset = initial_text_offset

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # Skip the decode when the text would be replaced anyway.
                token = (f"token_id:{token_id}" if should_return_as_token_id
                         else tokenizer.decode(token_id))
                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    self._get_decoded_token(
                        top_logprob,
                        top_token_id,
                        tokenizer,
                        return_as_token_id=should_return_as_token_id):
                    # Same -inf clamp as token_logprob above.
                    max(top_logprob.logprob, -9999.0)
                    for rank, (top_token_id, top_logprob) in enumerate(
                        step_top_logprobs.items())
                    if rank <= num_output_top_logprobs
                })

            out_text_offset.append(text_offset)
            text_offset += len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )