"vllm/vscode:/vscode.git/clone" did not exist on "68be0f853ed0cb131468e1f9062b05d8d7a4ab34"
serving_completion.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
16
17
18
19
20
21
22
23
24
25
26
from vllm.entrypoints.openai.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
27
28
29
30
31
from vllm.entrypoints.openai.serving_engine import (
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
32
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
33
from vllm.entrypoints.renderer import RenderConfig
34
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
35
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
36
from vllm.logger import init_logger
37
from vllm.logprobs import Logprob
38
from vllm.outputs import RequestOutput
39
from vllm.sampling_params import BeamSearchParams, SamplingParams
40
from vllm.tokenizers import TokenizerLike
41
42
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
43
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
44
45
46
47
48

# Module-level logger, named after this module, used by the completion handler.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Handler for the OpenAI-compatible Completions API.

    Turns a ``CompletionRequest`` into one engine request per rendered prompt.
    It merges the resulting output streams and renders them either as a single
    ``CompletionResponse`` or as server-sent-event ``data:`` lines carrying
    ``CompletionStreamResponse`` chunks.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialize the completion handler.

        Args:
            engine_client: Client used to submit generation requests.
            models: Served-model registry; used to check the requested model
                and to resolve the reported model name.
            request_logger: Optional request logger, forwarded to the base
                class.
            return_tokens_as_token_ids: Forwarded to the base class. It is the
                fallback for rendering logprob tokens as ``token_id:<id>``
                when a request does not say otherwise.
            enable_prompt_tokens_details: When true, usage reports
                ``prompt_tokens_details`` with the number of cached prompt
                tokens (if non-zero).
            enable_force_include_usage: Passed to ``should_include_usage`` to
                decide whether streaming responses report usage.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        # set up logits processors
        self.logits_processors = self.model_config.logits_processors

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        # Normalize to a dict once, instead of patching ``None`` -> ``{}`` on
        # every request. An empty dict is falsy, so the log below is unchanged.
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param() or {}
        )
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info(
                "Using default completion sampling params from %s: %s",
                source,
                self.default_sampling_params,
            )

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns:
            An async generator of SSE ``data:`` lines when ``request.stream``
            is set. With beam search this is a single-event stream built from
            the full response. Otherwise a ``CompletionResponse``. An
            ``ErrorResponse`` is returned for rejected requests and for
            errors caught during preprocessing or generation.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()

            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Trace headers depend only on the HTTP request, so look them up
            # once instead of once per prompt (still inside this try).
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text, prompt_token_ids, prompt_embeds = (
                    self._get_prompt_components(engine_prompt)
                )

                if prompt_token_ids is not None:
                    input_length = len(prompt_token_ids)
                elif prompt_embeds is not None:
                    input_length = len(prompt_embeds)
                else:
                    raise NotImplementedError

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    # Use the per-prompt ID (the one logged above), so IDs
                    # derived from it stay distinct across prompts.
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id_item,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokensPrompt | EmbedsPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Render merged engine outputs as an SSE stream of completion chunks.

        Each (prompt, sample) pair gets the choice index
        ``output.index + prompt_idx * n``. With ``echo``, the first chunk of
        each choice carries the prompt text and tokens. A ``GenerationError``
        or other ``Exception`` is reported in-band as an error event. The
        stream always ends with a ``data: [DONE]`` event. On success, the
        aggregate usage is also stored on ``request_metadata.final_usage_info``.
        """
        num_choices = 1 if request.n is None else request.n
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        # Cached prompt tokens per prompt, taken from the first output seen
        # for that prompt (None = not seen yet). They are summed at the end,
        # so the reported detail covers every prompt, like prompt_tokens does.
        num_cached_tokens: list[int | None] = [None] * num_prompts

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if num_cached_tokens[prompt_idx] is None:
                    num_cached_tokens[prompt_idx] = res.num_cached_tokens or 0

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted in this chunk. On
                    # an echo chunk that includes the prompt, so later logprob
                    # text offsets continue after it. Otherwise it equals
                    # output.text.
                    previous_text_lens[i] += len(delta_text)
                    # Completion-token count only; echoed prompt tokens are
                    # tracked separately in num_prompt_tokens.
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            total_cached_tokens = sum(n or 0 for n in num_cached_tokens)
            if self.enable_prompt_tokens_details and total_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=total_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming ``CompletionResponse`` from finished outputs.

        Choices are numbered sequentially across all prompts and their
        samples. Usage is summed over every prompt and is also stored on
        ``request_metadata.final_usage_info``. ``kv_transfer_params`` is
        taken from the first prompt's output. Errors signalled through a
        choice's finish reason are raised via ``_raise_if_error``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        num_cached_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Report the clamped prompt logprobs, i.e. the same values
                    # used for the echoed logprobs above.
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)
            # Summed per prompt, to match prompt_tokens (also summed).
            num_cached_tokens += final_res.num_cached_tokens or 0

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        if self.enable_prompt_tokens_details and num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to report, one entry per position.
            top_logprobs: Per-position mapping of token id -> ``Logprob``,
                aligned with ``token_ids``. A ``None`` entry yields a null
                logprob and null top-logprobs for that position.
            num_output_top_logprobs: Requested top-k. The first
                ``num_output_top_logprobs + 1`` entries (in dict order) of
                each position's mapping are reported.
            tokenizer: Used to decode tokens that have no ``Logprob`` entry.
                It may be ``None`` only when tokens are returned as ids.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not ``None``.

        Raises:
            ValueError: a token must be decoded but no tokenizer is available.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        last_token_len = 0

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise ValueError(
                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
                    )
                else:
                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses (applied to the top logprobs below as well).
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append(
                    {
                        self._get_decoded_token(
                            top_logprob,
                            top_token_id,
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(top_logprob.logprob, -9999.0)
                        for rank, (top_token_id, top_logprob) in enumerate(
                            step_top_logprobs.items()
                        )
                        if rank <= num_output_top_logprobs
                    }
                )

            # Offsets accumulate the decoded length of each preceding token.
            if not out_text_offset:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )

    def _build_render_config(
        self,
        request: CompletionRequest,
        max_input_length: int | None = None,
    ) -> RenderConfig:
        """Build the prompt-rendering config for a completion request.

        ``max_length`` is set to ``max_model_len`` minus the requested
        ``max_tokens`` (0 when unset). Detokenization is requested only when
        echoing without ``return_token_ids``.

        Args:
            request: The completion request being rendered.
            max_input_length: Accepted for interface compatibility; it is
                not used by this implementation.
        """
        max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
        return RenderConfig(
            max_length=max_input_tokens_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
            needs_detokenization=bool(request.echo and not request.return_token_ids),
        )