serving.py 26.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
from vllm.entrypoints.openai.completion.protocol import (
16
17
18
19
20
21
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
22
23
)
from vllm.entrypoints.openai.engine.protocol import (
24
25
26
27
28
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
29
from vllm.entrypoints.openai.engine.serving import (
30
31
32
33
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
34
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
35
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
36
from vllm.exceptions import VLLMValidationError
37
from vllm.inputs.data import ProcessorInputs
38
from vllm.logger import init_logger
39
from vllm.logprobs import Logprob
40
from vllm.outputs import RequestOutput
41
from vllm.sampling_params import BeamSearchParams, SamplingParams
42
from vllm.tokenizers import TokenizerLike
43
44
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
45
46
47
48
49

# Module-level logger; the handlers below use it to report preprocessing,
# request-setup, and streaming failures via logger.exception(...).
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Serves the OpenAI-style Completion API on top of an engine client.

    Validates and preprocesses completion requests, schedules one engine
    request per prompt, and converts engine outputs into either a single
    ``CompletionResponse`` or a stream of server-sent-event strings.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialize the completion handler.

        Args:
            engine_client: Engine used to schedule generation requests.
            models: Model registry used to resolve the served model name.
            request_logger: Optional request logger, forwarded to the base
                class.
            return_tokens_as_token_ids: Default for rendering logprob tokens
                as ``token_id:<id>`` strings when a request does not say.
            enable_prompt_tokens_details: When true, usage info includes
                ``prompt_tokens_details`` with the cached-token count (only
                if that count is non-zero).
            enable_force_include_usage: Passed to ``should_include_usage``
                when deciding whether streamed responses carry usage.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        self.default_sampling_params = self.model_config.get_diff_sampling_param()

        # Server-side max-token override handed to get_max_tokens(). When
        # generation_config is neither "auto" nor "vllm", use the default
        # sampling params' "max_tokens"; otherwise fall back to
        # override_generation_config["max_new_tokens"] (None if absent).
        mc = self.model_config
        self.override_max_tokens = (
            self.default_sampling_params.get("max_tokens")
            if mc.generation_config not in ("auto", "vllm")
            else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
        )

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[ProcessorInputs] | ErrorResponse:
        """Render a completion request by validating and preprocessing inputs.

        Args:
            request: The incoming completion request.

        Returns:
            A list of engine_prompts on success,
            or an ErrorResponse on failure.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        try:
            engine_prompts = await self._preprocess_completion(
                request,
                prompt_input=request.prompt,
                prompt_embeds=request.prompt_embeds,
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        return engine_prompts

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns:
            An async generator of SSE strings when streaming (or when
            streaming was requested together with beam search), a
            ``CompletionResponse`` otherwise, or an ``ErrorResponse`` on
            failure.
        """
        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result

        engine_prompts = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)
        except (ValueError, TypeError, RuntimeError) as e:
            logger.exception("Error preparing request components")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers depend only on the raw request, not on the
            # individual prompt, so resolve them once instead of per prompt.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                max_tokens = get_max_tokens(
                    max_model_len,
                    request.max_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,
                    self.override_max_tokens,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.default_sampling_params,
                    )

                # Each prompt gets its own engine request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        # Yields (prompt_index, RequestOutput) pairs from all prompts.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        tokenizer = self.renderer.tokenizer

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = self._extract_prompt_text(engine_prompt)

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[ProcessorInputs],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Turn merged engine outputs into server-sent-event strings.

        Emits one ``data: <CompletionStreamResponse JSON>`` event per non-empty
        output delta. Choice indices are ``output.index + prompt_idx * n``.
        When usage is requested, a final usage-only event follows. Errors are
        reported as an error event. The stream always ends with
        ``data: [DONE]``.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice running state, indexed by the global choice index.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        # Per-prompt state, indexed by prompt_idx.
        num_prompt_tokens = [0] * num_prompts
        # Cached-token count per prompt, taken from the first output that
        # reports one for that prompt. Summed so it matches the summed
        # prompt_tokens in the final usage.
        num_cached_tokens: list[int | None] = [None] * num_prompts

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if num_cached_tokens[prompt_idx] is None:
                    num_cached_tokens[prompt_idx] = res.num_cached_tokens

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_prompt)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted in this chunk. On
                    # the first echo chunk that includes the prompt, so later
                    # logprob text offsets stay aligned with the concatenated
                    # streamed text, as in the non-streaming path.
                    previous_text_lens[i] += len(delta_text)
                    # Completion-token count only; the prompt is tracked
                    # separately in num_prompt_tokens.
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            total_cached_tokens = sum(c or 0 for c in num_cached_tokens)
            if self.enable_prompt_tokens_details and total_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=total_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming ``CompletionResponse`` from final outputs.

        Creates one choice per output of every prompt, numbered in order.
        Usage is summed over all prompts, including cached prompt tokens when
        ``enable_prompt_tokens_details`` is set. The aggregate usage is also
        stored on ``request_metadata.final_usage_info``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        num_cached_tokens = 0
        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            # Replaces -inf prompt logprobs with a JSON-serializable value.
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Use the clamped prompt logprobs computed above.
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)
            # Sum cached tokens per prompt so the detail matches the summed
            # prompt_tokens; None counts as zero.
            num_cached_tokens += final_res.num_cached_tokens or 0

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        if self.enable_prompt_tokens_details and num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        # Only the first prompt's KV-transfer params are surfaced.
        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )
        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to describe, one entry per position.
            top_logprobs: Per-position logprob dicts, indexed in parallel with
                ``token_ids``. A ``None`` entry yields a ``None`` logprob.
            num_output_top_logprobs: Number of requested top logprobs; up to
                ``num_output_top_logprobs + 1`` entries are kept per position.
            tokenizer: Used to decode tokens when no logprob entry exists.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not ``None``.

        Raises:
            VLLMValidationError: If a token must be decoded but no tokenizer
                is available.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        last_token_len = 0

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )

                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the JSON-serializable float
                # that OpenAI uses.
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append(
                    {
                        # Convert float("-inf") to the
                        # JSON-serializable float that OpenAI uses
                        self._get_decoded_token(
                            top_lp[1],
                            top_lp[0],
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(top_lp[1].logprob, -9999.0)
                        for rank, top_lp in enumerate(step_top_logprobs.items())
                        if num_output_top_logprobs >= rank
                    }
                )

            # Each token's offset is the previous offset plus the previous
            # token's decoded length.
            if not out_text_offset:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )