serving_completion.py 29.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
16
17
18
19
20
21
22
23
24
25
26
from vllm.entrypoints.openai.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
27
28
29
30
31
from vllm.entrypoints.openai.serving_engine import (
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
32
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
33
from vllm.entrypoints.renderer import RenderConfig
34
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
35
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
36
from vllm.logger import init_logger
37
from vllm.logprobs import Logprob
38
from vllm.outputs import RequestOutput
39
from vllm.sampling_params import BeamSearchParams, SamplingParams
40
from vllm.tokenizers import TokenizerLike
41
42
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
43
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
44
45
46
47
48

# Module-level logger; used below for the default-sampling-params info message
# and for logging preprocessing / streaming errors.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
49
50
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Set up the OpenAI-compatible completion handler.

        Args:
            engine_client: Client used to query and submit work to the engine.
            models: Served-model registry; used to resolve the model name
                reported in responses.
            request_logger: Optional request logger, forwarded to the base class.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_prompt_tokens_details: When true, usage reports include
                ``prompt_tokens_details.cached_tokens`` if any were cached.
            enable_force_include_usage: Passed to ``should_include_usage`` when
                deciding whether streamed responses carry usage information.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        config = self.model_config

        # Logits processors configured for the model; per-request sampling
        # params are validated against these.
        self.logits_processors = config.logits_processors
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = config.get_diff_sampling_param()
        self.enable_force_include_usage = enable_force_include_usage

        # Nothing to announce when the model supplies no overrides.
        if not self.default_sampling_params:
            return

        generation_config = config.generation_config
        logger.info(
            "Using default completion sampling params from %s: %s",
            "model" if generation_config == "auto" else generation_config,
            self.default_sampling_params,
        )
82

83
84
85
    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Returns:
            An async generator of SSE ``data:`` lines when ``request.stream`` is
            set (for beam search, a single-event generator wrapping the full
            response), otherwise a ``CompletionResponse``. An ``ErrorResponse``
            is returned for unsupported or invalid requests.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()

            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            # All preprocessing failures are reported the same way.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Neither of these depends on the individual prompt, so compute
            # them once instead of once per prompt. They stay inside the try so
            # a ValueError is still turned into an error response.
            if self.default_sampling_params is None:
                self.default_sampling_params = {}
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text, prompt_token_ids, prompt_embeds = (
                    self._get_prompt_components(engine_prompt)
                )

                input_length = None
                if prompt_token_ids is not None:
                    input_length = len(prompt_token_ids)
                elif prompt_embeds is not None:
                    input_length = len(prompt_embeds)
                else:
                    raise NotImplementedError

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
326
327
328
329

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokensPrompt | EmbedsPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion chunks as server-sent-event ``data:`` lines.

        Every output of every engine result becomes one
        ``CompletionStreamResponse`` chunk. Choice indices are flattened
        across prompts as ``output.index + prompt_idx * num_choices``. With
        ``echo``, the prompt is prepended to the first chunk of each choice.
        After all results are consumed, a usage-only chunk is sent if usage
        was requested, and the aggregate usage is stored on
        ``request_metadata``. Errors are emitted as an error chunk. The stream
        always ends with ``data: [DONE]``.

        Args:
            request: The originating completion request.
            engine_prompts: Rendered prompts, indexed by ``prompt_idx``; used to
                recover prompt text when the engine result has none.
            result_generator: Merged ``(prompt_idx, RequestOutput)`` results.
            request_id: Response id reported in every chunk.
            created_time: Unix timestamp reported in every chunk.
            model_name: Model name reported in every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer used to decode logprob tokens (may be None).
            request_metadata: Receives ``final_usage_info``.
        """
        num_choices = 1 if request.n is None else request.n
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        # Cached prompt tokens per prompt, taken from the first result that
        # arrives for that prompt. They are summed at the end, just like
        # prompt_tokens, so multi-prompt requests report every prompt.
        num_cached_tokens = [0] * num_prompts
        cached_tokens_recorded = [False] * num_prompts

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if not cached_tokens_recorded[prompt_idx]:
                    num_cached_tokens[prompt_idx] = res.num_cached_tokens or 0
                    cached_tokens_recorded[prompt_idx] = True

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted for this choice,
                    # including any echoed prompt. Otherwise the next chunk's
                    # logprob text_offset values would be short by the prompt
                    # length.
                    previous_text_lens[i] += len(delta_text)
                    # Only generated tokens count toward completion usage.
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            total_cached_tokens = sum(num_cached_tokens)
            if self.enable_prompt_tokens_details and total_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=total_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build a non-streaming ``CompletionResponse`` from final engine results.

        Each output of each ``RequestOutput`` becomes one choice, numbered
        sequentially across the batch. With ``echo``, the prompt text, token ids
        and (when ``logprobs`` is requested) prompt logprobs are prepended to the
        generated ones. Usage sums prompt, generated and cached tokens over the
        whole batch and is also stored on ``request_metadata``.
        ``kv_transfer_params`` is taken from the first result.

        Raises:
            Whatever ``self._raise_if_error`` raises for an errored
            ``finish_reason``.
        """
        choices: list[CompletionResponseChoice] = []

        num_prompt_tokens = 0
        num_generated_tokens = 0
        # Summed per prompt, like num_prompt_tokens, so multi-prompt requests
        # report cached tokens for every prompt rather than only the last one.
        num_cached_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)
            num_cached_tokens += final_res.num_cached_tokens or 0

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        if self.enable_prompt_tokens_details and num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )
        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )
633
634
635
636

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Tokens to describe, in order.
            top_logprobs: Per-position candidate logprobs keyed by token id;
                ``None`` at a position means no logprob data for that token.
            num_output_top_logprobs: Requested top-k. Up to k + 1 candidates are
                kept per position.
            tokenizer: Used to decode tokens that have no logprob data; may be
                None only when tokens are returned as ids.
            initial_text_offset: ``text_offset`` of the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not None.

        Raises:
            ValueError: A token lacks logprob data, must be decoded as text,
                and no tokenizer is available.
        """
        as_token_id = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        text_offsets: list[int] = []
        token_logprobs: list[float | None] = []
        tokens: list[str] = []
        per_step_top: list[dict[str, float] | None] = []

        # Running character offset: each token starts where the previous one's
        # decoded string ended.
        offset = initial_text_offset
        for pos, token_id in enumerate(token_ids):
            candidates = top_logprobs[pos]
            if candidates is not None:
                sampled = candidates[token_id]
                token = self._get_decoded_token(
                    sampled,
                    token_id,
                    tokenizer,
                    return_as_token_id=as_token_id,
                )
                tokens.append(token)
                # -inf is clamped to -9999.0, the JSON-serializable float
                # OpenAI uses.
                token_logprobs.append(max(sampled.logprob, -9999.0))

                # Keep the top num_output_top_logprobs + 1 entries, as defined
                # by the OpenAI API (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                per_step_top.append(
                    {
                        self._get_decoded_token(
                            candidate,
                            candidate_id,
                            tokenizer,
                            return_as_token_id=as_token_id,
                        ): max(candidate.logprob, -9999.0)
                        for rank, (candidate_id, candidate) in enumerate(
                            candidates.items()
                        )
                        if rank <= num_output_top_logprobs
                    }
                )
            else:
                if as_token_id:
                    token = f"token_id:{token_id}"
                elif tokenizer is None:
                    raise ValueError(
                        "Unable to get tokenizer because `skip_tokenizer_init=True`"
                    )
                else:
                    token = tokenizer.decode(token_id)

                tokens.append(token)
                token_logprobs.append(None)
                per_step_top.append(None)

            text_offsets.append(offset)
            offset += len(token)

        return CompletionLogProbs(
            text_offset=text_offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=per_step_top,
        )
717
718
719
720

    def _build_render_config(
        self,
        request: CompletionRequest,
        max_input_length: int | None = None,
    ) -> RenderConfig:
        """Build the prompt-rendering configuration for a completion request.

        The prompt length limit is ``self.max_model_len`` minus
        ``request.max_tokens`` (a falsy ``max_tokens`` reserves nothing).
        Detokenization is requested only when ``echo`` is set and
        ``return_token_ids`` is not.

        ``max_input_length`` is accepted for interface compatibility but is not
        used by this implementation.
        """
        reserved_output_tokens = request.max_tokens or 0
        echo_as_text = bool(request.echo and not request.return_token_ids)
        return RenderConfig(
            max_length=self.max_model_len - reserved_output_tokens,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
            needs_detokenization=echo_as_text,
        )