"vscode:/vscode.git/clone" did not exist on "35068264213da8192191716a91e2a8beca31e54a"
serving.py 25.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import TYPE_CHECKING, cast
9

10
from fastapi import Request
11

12
from vllm.engine.protocol import EngineClient
13
from vllm.entrypoints.logger import RequestLogger
14
from vllm.entrypoints.openai.completion.protocol import (
15
16
17
18
19
20
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
21
22
)
from vllm.entrypoints.openai.engine.protocol import (
23
24
25
26
27
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
28
from vllm.entrypoints.openai.engine.serving import (
29
30
31
32
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
33
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
34
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
35
from vllm.exceptions import VLLMValidationError
36
from vllm.inputs.data import ProcessorInputs
37
from vllm.logger import init_logger
38
from vllm.logprobs import Logprob
39
from vllm.outputs import RequestOutput
40
from vllm.sampling_params import BeamSearchParams, SamplingParams
41
from vllm.tokenizers import TokenizerLike
42
43
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
44

if TYPE_CHECKING:
    from vllm.entrypoints.serve.render.serving import OpenAIServingRender

# Module-level logger for the completion serving path.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Handler for OpenAI-compatible Completions API requests.

    Validates and renders incoming ``CompletionRequest`` objects, submits one
    engine request per rendered prompt, and converts the resulting
    ``RequestOutput`` objects into streamed or non-streamed OpenAI completion
    responses.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        openai_serving_render: "OpenAIServingRender",
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Set up the completion handler.

        Args:
            engine_client: Engine that generates the outputs (passed to the
                base class).
            models: Served-model registry, used to report the response's
                model name (passed to the base class).
            openai_serving_render: Renderer that preprocesses completion
                requests into engine prompts.
            request_logger: Optional request-input logger (passed to the base
                class).
            return_tokens_as_token_ids: Passed to the base class; presumably
                the server-wide fallback for rendering logprob tokens as
                ``token_id:<id>`` strings.
            enable_prompt_tokens_details: Add cached-prompt-token counts to
                the reported usage.
            enable_force_include_usage: Passed to ``should_include_usage``;
                presumably forces usage reporting in streaming responses.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
        )

        self.openai_serving_render = openai_serving_render
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        model_config = self.model_config
        # Default sampling parameters from the model config; they are fed
        # into every request's sampling/beam-search params.
        self.default_sampling_params = model_config.get_diff_sampling_param()

        # Max-token override handed to get_max_tokens: for the "auto"/"vllm"
        # generation config it is "max_new_tokens" from
        # override_generation_config, otherwise the default "max_tokens".
        if model_config.generation_config in ("auto", "vllm"):
            override_config = getattr(model_config, "override_generation_config", {})
            self.override_max_tokens = override_config.get("max_new_tokens")
        else:
            self.override_max_tokens = self.default_sampling_params.get("max_tokens")

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[ProcessorInputs] | ErrorResponse:
        """Validate *request* and turn it into engine-ready prompts.

        The actual preprocessing is delegated to ``OpenAIServingRender``;
        this wrapper adds the checks that need the engine: the requested
        model (including LoRA adapters) must be served, and the engine must
        not be in an errored state.

        Returns:
            The list of engine prompts on success, or an ``ErrorResponse``
            on failure.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # Surface a dead engine now: streaming requests send a success
        # status before generation starts, so later is too late.
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        return await self.openai_serving_render.render_completion(request)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        Every rendered prompt is submitted as its own engine request with id
        ``"<request_id>-<prompt index>"``. The per-prompt result streams are
        merged into a single streamed or non-streamed response.

        Args:
            request: The parsed completion request.
            raw_request: The originating FastAPI request, if any. It supplies
                the base request id, the data-parallel rank header and the
                trace headers, and receives ``request_metadata`` on its
                ``state``.

        Returns:
            An async generator of server-sent-event lines when
            ``request.stream`` is true, otherwise a ``CompletionResponse``.
            An ``ErrorResponse`` is returned when streaming is combined with
            beam search, when model validation or rendering fails, or when the
            client disconnects while a non-streaming response is collected.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)
        """
        if request.stream and request.use_beam_search:
            return self.create_error_response(
                "Streaming is not currently supported with beam search"
            )

        result = await self.render_completion_request(request)
        if isinstance(result, ErrorResponse):
            return result
        engine_prompts = result

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        lora_request = self._maybe_get_adapters(request)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Trace headers depend only on the incoming HTTP request, so fetch
        # them once instead of once per prompt.
        trace_headers = (
            None
            if raw_request is None
            else await self._get_trace_headers(raw_request.headers)
        )

        # Schedule one engine request per prompt and collect the generators.
        max_model_len = self.model_config.max_model_len
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        for i, engine_prompt in enumerate(engine_prompts):
            max_tokens = get_max_tokens(
                max_model_len,
                request.max_tokens,
                self._extract_prompt_len(engine_prompt),
                self.default_sampling_params,
                self.override_max_tokens,
            )

            sampling_params: SamplingParams | BeamSearchParams
            if request.use_beam_search:
                sampling_params = request.to_beam_search_params(
                    max_tokens, self.default_sampling_params
                )
            else:
                sampling_params = request.to_sampling_params(
                    max_tokens,
                    self.default_sampling_params,
                )

            request_id_item = f"{request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            if isinstance(sampling_params, BeamSearchParams):
                # Use the per-prompt id, like the logged id above and the
                # regular generate path below, so beam-search requests for
                # different prompts of one batch do not share an id.
                generator = self.beam_search(
                    prompt=engine_prompt,
                    request_id=request_id_item,
                    params=sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                )
            else:
                generator = self.engine_client.generate(
                    engine_prompt,
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                    data_parallel_rank=data_parallel_rank,
                )

            generators.append(generator)

        # Yields (prompt index, RequestOutput) pairs across all prompts.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)
        tokenizer = self.renderer.tokenizer

        # Streaming response
        if request.stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    final_res.prompt = self._extract_prompt_text(engine_prompts[i])

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        # Streaming requests returned above (and stream + beam search is
        # rejected up front), so this is always a plain, non-streamed response.
        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[ProcessorInputs],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Turn merged engine outputs into OpenAI completion SSE lines.

        ``result_generator`` yields ``(prompt_idx, RequestOutput)`` pairs for
        all prompts of the request. Each non-empty output delta becomes one
        ``data: <CompletionStreamResponse JSON>`` line whose choice index is
        ``output.index + prompt_idx * n``. With ``echo``, the prompt text,
        token IDs and prompt logprobs are prepended to the first chunk of
        each choice.

        Depending on ``stream_options`` / ``enable_force_include_usage``,
        chunks may carry running per-choice usage, and a final usage-only
        chunk aggregated over all prompts and choices may be emitted. A
        generation error or unexpected exception is reported as one error
        event. The stream always ends with ``data: [DONE]``.

        Args:
            request: The originating completion request.
            engine_prompts: Rendered prompts, used to recover the prompt text
                when an engine output does not carry it.
            result_generator: Merged ``(prompt_idx, RequestOutput)`` stream.
            request_id: ID reported in every chunk.
            created_time: Creation timestamp reported in every chunk.
            model_name: Model name reported in every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer used to decode logprob tokens; ``None`` when
                tokenizer initialization is skipped.
            request_metadata: Receives the aggregate usage for middleware.
        """
        num_choices = 1 if request.n is None else request.n
        num_slots = num_choices * num_prompts
        # Per-choice state, indexed by output.index + prompt_idx * num_choices.
        previous_text_lens = [0] * num_slots
        previous_num_tokens = [0] * num_slots
        has_echoed = [False] * num_slots
        # Per-prompt state.
        num_prompt_tokens = [0] * num_prompts
        # Cached prompt tokens, recorded from each prompt's first output.
        num_cached_tokens: list[int | None] = [None] * num_prompts

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if num_cached_tokens[prompt_idx] is None:
                    num_cached_tokens[prompt_idx] = res.num_cached_tokens or 0

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = self._extract_prompt_text(engine_prompt)

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        # NOTE(review): with return_token_ids the echoed prompt
                        # text is dropped; confirm this is intended.
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks.
                            # This check runs before the prompt token IDs are
                            # marked as returned, because a skipped chunk
                            # returns nothing.
                            continue

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually sent in this chunk (which
                    # includes the echoed prompt on the first chunk) so the
                    # next chunk's text offsets continue from the right place.
                    previous_text_lens[i] += len(delta_text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            # Cached tokens are summed over all prompts, matching how
            # prompt_tokens is aggregated above.
            total_cached_tokens = sum(n or 0 for n in num_cached_tokens)
            if self.enable_prompt_tokens_details and total_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=total_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming ``CompletionResponse`` from finished outputs.

        Choices are numbered consecutively across all prompts in batch order.
        With ``echo``, each choice's text, token IDs and logprobs are prefixed
        with the prompt's. Usage (prompt, completion and, when enabled, cached
        prompt tokens) is summed over all prompts. It is also stored on
        ``request_metadata.final_usage_info``.

        Args:
            final_res_batch: One finished ``RequestOutput`` per prompt.
            request: The originating completion request.
            request_id: ID reported in the response.
            created_time: Creation timestamp reported in the response.
            model_name: Model name reported in the response.
            tokenizer: Tokenizer used to decode logprob tokens, or ``None``.
            request_metadata: Receives the aggregate usage for middleware.

        Returns:
            The assembled ``CompletionResponse``. Its ``kv_transfer_params``
            are taken from the first result, if any.

        Raises:
            Whatever ``_raise_if_error`` raises for an output whose finish
            reason marks a generation error.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        num_cached_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            # Clamped so -inf logprobs stay JSON-serializable.
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    # NOTE(review): with return_token_ids the echoed prompt
                    # text is dropped; confirm this is intended.
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choices.append(
                    CompletionResponseChoice(
                        index=len(choices),
                        text=output_text,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason,
                        prompt_logprobs=prompt_logprobs,
                        prompt_token_ids=(
                            prompt_token_ids if request.return_token_ids else None
                        ),
                        token_ids=(
                            as_list(output.token_ids)
                            if request.return_token_ids
                            else None
                        ),
                    )
                )

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)
            num_cached_tokens += final_res.num_cached_tokens or 0

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Cached tokens are summed over all prompts, matching prompt_tokens.
        if self.enable_prompt_tokens_details and num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token IDs to report, in order.
            top_logprobs: For each position, a mapping from token ID to
                ``Logprob``, or ``None`` where no logprobs are available.
                Indexed in parallel with ``token_ids``, so it must be at
                least as long.
            num_output_top_logprobs: How many alternatives were requested.
                The first ``num_output_top_logprobs + 1`` entries of each
                mapping (in its iteration order) are reported.
            tokenizer: Used to decode tokens that have no logprob entry;
                ``None`` when tokenizer initialization is skipped.
            initial_text_offset: Character offset of the first token.
            return_as_token_id: Render tokens as ``token_id:<id>`` strings;
                ``None`` falls back to ``self.return_tokens_as_token_ids``.

        Returns:
            ``CompletionLogProbs`` whose lists run parallel to ``token_ids``.
            Log probabilities are clamped to ``-9999.0`` so ``-inf`` stays
            JSON-serializable.

        Raises:
            VLLMValidationError: A position without logprobs must be decoded
                but no tokenizer is available.
        """
        as_token_ids = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        text_offsets: list[int] = []
        token_logprobs: list[float | None] = []
        tokens: list[str] = []
        step_tops: list[dict[str, float] | None] = []

        # Each token starts where the previous token's text ended.
        offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            step_logprobs = top_logprobs[position]

            if step_logprobs is not None:
                sampled = step_logprobs[token_id]
                token = self._get_decoded_token(
                    sampled,
                    token_id,
                    tokenizer,
                    return_as_token_id=as_token_ids,
                )
                tokens.append(token)
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                token_logprobs.append(max(sampled.logprob, -9999.0))

                # Keep the first num_output_top_logprobs + 1 entries, as
                # defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                top_entries: dict[str, float] = {}
                for rank, (top_id, top_logprob) in enumerate(step_logprobs.items()):
                    if rank > num_output_top_logprobs:
                        break
                    key = self._get_decoded_token(
                        top_logprob,
                        top_id,
                        tokenizer,
                        return_as_token_id=as_token_ids,
                    )
                    top_entries[key] = max(top_logprob.logprob, -9999.0)
                step_tops.append(top_entries)
            else:
                if as_token_ids:
                    token = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise VLLMValidationError(
                        "Unable to get tokenizer because "
                        "`skip_tokenizer_init=True`",
                        parameter="skip_tokenizer_init",
                        value=True,
                    )
                else:
                    token = tokenizer.decode(token_id)
                tokens.append(token)
                token_logprobs.append(None)
                step_tops.append(None)

            text_offsets.append(offset)
            offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=step_tops,
        )