# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast

import jinja2
from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.completion.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list

logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Handler for the OpenAI-style text completion endpoint.

    Builds on ``OpenAIServing``: it validates and preprocesses a
    ``CompletionRequest``, schedules one generation per prompt, and turns the
    engine outputs into either a single ``CompletionResponse`` or a stream of
    server-sent-event strings.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialise the handler.

        Args:
            engine_client: Forwarded to ``OpenAIServing``; this class reads
                its ``errored``/``dead_error`` attributes and calls its
                ``generate`` method.
            models: Forwarded to ``OpenAIServing``; used to resolve the
                model name reported in responses.
            request_logger: Forwarded to ``OpenAIServing``.
            return_tokens_as_token_ids: Forwarded to ``OpenAIServing``.
                ``self.return_tokens_as_token_ids`` is the fallback used by
                ``_create_completion_logprobs`` when the per-call flag is
                ``None`` (presumably stored by the base class).
            enable_prompt_tokens_details: When true, usage info also carries
                the cached prompt-token count if a non-zero one is available.
            enable_force_include_usage: Passed to ``should_include_usage``
                together with the request's ``stream_options``.
            log_error_stack: Forwarded to ``OpenAIServing``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        # Sampling defaults taken from the model config; they are handed to
        # get_max_tokens and to the request's sampling-parameter builders.
        self.default_sampling_params = self.model_config.get_diff_sampling_param()

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[TokPrompt] | ErrorResponse:
        """Render a completion request by validating and preprocessing inputs.

        Args:
            request: The incoming completion request.

        Returns:
            A list of engine_prompts on success,
            or an ErrorResponse on failure.

        Raises:
            The engine client's ``dead_error`` when the engine client reports
            that it has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        try:
            engine_prompts = await self._preprocess_completion(
                request,
                prompt_input=request.prompt,
                prompt_embeds=request.prompt_embeds,
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return engine_prompts

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The completion request to serve.
            raw_request: The underlying HTTP request, if any. When present,
                the request metadata is attached to ``raw_request.state`` and
                its headers are used to look up trace headers.

        Returns:
            An async generator of server-sent-event strings when
            ``request.stream`` is set, a ``CompletionResponse`` otherwise,
            or an ``ErrorResponse`` if validation or generation fails.
        """
        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result

        engine_prompts = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # The trace headers depend only on the incoming HTTP request, so
            # they are resolved once here rather than once per prompt.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text = self._extract_prompt_text(engine_prompt)

                max_tokens = get_max_tokens(
                    max_model_len,
                    request.max_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.default_sampling_params,
                    )

                # Each prompt of the request gets its own sub-request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    # Built per prompt on purpose: the same kwargs object is
                    # handed to two callees, so it is not shared across
                    # prompts.
                    tok_params = request.build_tok_params(self.model_config)
                    tokenization_kwargs = tok_params.get_encode_kwargs()

                    engine_request = self.input_processor.process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        tokenization_kwargs=tokenization_kwargs,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        # Yields (prompt_index, RequestOutput) pairs from all generators.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        tokenizer = self.renderer.tokenizer

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            # Keep only the latest output seen for each prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = self._extract_prompt_text(engine_prompt)

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion chunks as server-sent-event strings.

        Every yielded string has the form ``"data: <json>\\n\\n"``. One chunk
        is emitted per non-empty output delta, optionally followed by a final
        usage-only chunk, and the stream always ends with ``data: [DONE]``.
        Errors raised while streaming are converted into an error event
        instead of propagating.

        Args:
            request: The completion request being served.
            engine_prompts: The preprocessed prompts; used to recover the
                prompt text when an engine output carries none.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: Identifier reported in every chunk.
            created_time: Creation timestamp reported in every chunk.
            model_name: Model name reported in every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer used when building logprob entries.
            request_metadata: Receives the aggregated usage info once the
                stream has finished successfully.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice bookkeeping, indexed by
        # ``output.index + prompt_idx * num_choices``.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # The cached-token count is taken from the very first output
                # only.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_prompt)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # NOTE(review): when the prompt was echoed in this chunk,
                    # delta_text also contains the prompt text, but only the
                    # generated text length is accumulated here. Confirm
                    # whether later text offsets are meant to exclude the
                    # echoed prompt.
                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final engine outputs.

        One choice is produced for every output of every prompt, indexed in
        iteration order. Token usage is summed across all prompts and outputs
        and is also stored on ``request_metadata.final_usage_info``.

        Args:
            final_res_batch: Final ``RequestOutput`` of each prompt.
            request: The completion request being served.
            request_id: Identifier reported in the response.
            created_time: Creation timestamp reported in the response.
            model_name: Model name reported in the response.
            tokenizer: Tokenizer used when building logprob entries.
            request_metadata: Receives the aggregated usage info.

        Returns:
            The assembled ``CompletionResponse``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        last_final_res = None
        for final_res in final_res_batch:
            last_final_res = final_res
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Only the prompt is returned.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        # Prompt followed by the generated continuation.
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Use the clamped value computed above so the response
                    # never depends on whether clamping happened in place.
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # The cached-token count is read from the last prompt's output only.
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        # KV-transfer parameters are taken from the first prompt's output.
        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )
        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: For each position in ``token_ids``, the candidate
                logprobs keyed by token id, or ``None`` when no logprobs are
                available for that position.
            num_output_top_logprobs: Requested number of top logprobs; up to
                ``num_output_top_logprobs + 1`` candidates are reported per
                position.
            tokenizer: Used to decode token ids into strings.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Whether tokens are rendered as
                ``"token_id:<id>"`` instead of decoded text. ``None`` falls
                back to ``self.return_tokens_as_token_ids``.

        Returns:
            The ``CompletionLogProbs`` holding the parallel token, logprob,
            top-logprob and text-offset lists.

        Raises:
            VLLMValidationError: If a token without logprobs has to be
                decoded but ``tokenizer`` is ``None``.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        last_token_len = 0

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # No logprobs for this position: report the token alone.
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )

                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append(
                    {
                        # Convert float("-inf") to the
                        # JSON-serializable float that OpenAI uses
                        self._get_decoded_token(
                            candidate,
                            candidate_id,
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(candidate.logprob, -9999.0)
                        for rank, (candidate_id, candidate) in enumerate(
                            step_top_logprobs.items()
                        )
                        if num_output_top_logprobs >= rank
                    }
                )

            # Each token's offset is the previous offset plus the length of
            # the previously reported token string.
            if not out_text_offset:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )