serving_completion.py 29.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
16
17
18
19
20
21
22
23
24
25
from vllm.entrypoints.openai.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
26
    VLLMValidationError,
27
)
28
29
30
31
32
from vllm.entrypoints.openai.serving_engine import (
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
33
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
34
from vllm.entrypoints.renderer import RenderConfig
35
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
36
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
37
from vllm.logger import init_logger
38
from vllm.logprobs import Logprob
39
from vllm.outputs import RequestOutput
40
from vllm.sampling_params import BeamSearchParams, SamplingParams
41
from vllm.tokenizers import TokenizerLike
42
43
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
44
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
45
46
47
48
49

logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
50
51
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Set up the completion-serving handler.

        Args:
            engine_client: Engine client forwarded to the base class.
            models: Served-model registry forwarded to the base class.
            request_logger: Optional request logger forwarded to the base
                class.
            return_tokens_as_token_ids: Forwarded to the base class.
            enable_prompt_tokens_details: Stored on the instance; gates the
                ``prompt_tokens_details`` field of the usage info.
            enable_force_include_usage: Stored on the instance; passed to
                ``should_include_usage`` when streaming.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        model_config = self.model_config

        # Logits processors configured on the model; used to validate the
        # sampling parameters of each request.
        self.logits_processors = model_config.logits_processors

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage

        sampling_defaults = model_config.get_diff_sampling_param()
        self.default_sampling_params = sampling_defaults
        if sampling_defaults:
            # Report where the non-empty defaults came from; the "auto"
            # generation config is reported as "model".
            origin = model_config.generation_config
            if origin == "auto":
                origin = "model"
            logger.info(
                "Using default completion sampling params from %s: %s",
                origin,
                sampling_defaults,
            )
83

84
85
86
    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying HTTP request, if any. When given, the
                request metadata is attached to ``raw_request.state`` and its
                headers are used for trace headers and the data-parallel rank.

        Returns:
            An async generator of server-sent-event strings when
            ``request.stream`` is set, a ``CompletionResponse`` otherwise, or
            an ``ErrorResponse`` when validation or generation fails.

        Raises:
            Exception: ``self.engine_client.dead_error`` if the engine has
                errored.
            NotImplementedError: If a rendered prompt has neither token IDs
                nor prompt embeddings.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            # All preprocessing failures are reported to the client the same
            # way: log the stack and return the message as an error response.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # The default sampling params and the trace headers are the same
            # for every prompt of this request, so resolve them once here
            # rather than on each loop iteration.
            if self.default_sampling_params is None:
                self.default_sampling_params = {}

            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text, prompt_token_ids, prompt_embeds = (
                    self._get_prompt_components(engine_prompt)
                )

                input_length = None
                if prompt_token_ids is not None:
                    input_length = len(prompt_token_ids)
                elif prompt_embeds is not None:
                    input_length = len(prompt_embeds)
                else:
                    raise NotImplementedError

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                # Each prompt of a batched request gets its own engine-level
                # request ID derived from the API-level one.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        # Merge the per-prompt generators; each merged item is a
        # (prompt index, RequestOutput) pair.
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            # Keep only the latest output seen for each prompt index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
325
326
327
328

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokensPrompt | EmbedsPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion chunks as server-sent-event strings.

        Args:
            request: The completion request being served.
            engine_prompts: The rendered prompts, indexed by prompt index;
                used to recover the prompt text when an output lacks it.
            result_generator: Yields ``(prompt index, RequestOutput)`` pairs.
            request_id: ID placed in every emitted chunk.
            created_time: Creation timestamp placed in every emitted chunk.
            model_name: Model name placed in every emitted chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer used to decode tokens for logprobs, if any.
            request_metadata: Receives the aggregated usage info once the
                stream has been fully consumed.

        Yields:
            ``"data: <json>\\n\\n"`` strings, one per chunk, optionally a final
            usage chunk, and always a closing ``"data: [DONE]\\n\\n"``. Errors
            raised while streaming are emitted as a data event instead of
            propagating.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, indexed by ``output.index + prompt_idx * num_choices``.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # Only the very first output's cached-token count is recorded.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance the text offset by the text actually emitted in
                    # this chunk. On the echo chunk that includes the prompt
                    # text, so the logprob ``text_offset`` values of later
                    # chunks continue after it instead of restarting inside
                    # the echoed prompt.
                    previous_text_lens[i] += len(delta_text)
                    # Token accounting only counts generated tokens.
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            # Aggregate usage over every prompt and every choice.
            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            # Stream boundary: the HTTP status has already been sent, so the
            # failure is logged and reported in-band as a data event.
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build the non-streaming completion response.

        Args:
            final_res_batch: The final engine output for each prompt.
            request: The completion request being served.
            request_id: ID placed in the response.
            created_time: Creation timestamp placed in the response.
            model_name: Model name placed in the response.
            tokenizer: Tokenizer used to decode tokens for logprobs, if any.
            request_metadata: Receives the aggregated usage info.

        Returns:
            A ``CompletionResponse`` with one choice per output of every
            prompt, indexed in iteration order, plus aggregated usage.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    # With return_token_ids the prompt is returned as token
                    # IDs, so no prompt text is prepended to the output text.
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Only the prompt is returned.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Use the value returned by clamp_prompt_logprobs so the
                    # response does not depend on the helper clamping in
                    # place (presumably it does -- TODO confirm).
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            # Prompt tokens are counted once per prompt, not once per choice.
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Cached-token details come from the last prompt's output.
        last_final_res = final_res_batch[-1] if final_res_batch else None
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        # KV-transfer parameters are taken from the first prompt's output.
        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )
631
632
633
634

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Assemble the OpenAI Completion API logprobs payload.

        Args:
            token_ids: Token IDs to report, in order.
            top_logprobs: Per-position candidate logprobs keyed by token ID;
                an entry may be ``None`` when no logprobs exist for that
                position.
            num_output_top_logprobs: Number of top alternatives requested; up
                to one more than this many entries are reported per position.
            tokenizer: Used to decode token IDs when needed.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides the instance-level
                ``return_tokens_as_token_ids`` setting when not ``None``.

        Returns:
            A ``CompletionLogProbs`` holding parallel lists of text offsets,
            token logprobs, token strings and top-logprob mappings.

        Raises:
            VLLMValidationError: If a token without logprobs must be decoded
                but no tokenizer is available.
        """
        use_token_id_repr = (
            self.return_tokens_as_token_ids
            if return_as_token_id is None
            else return_as_token_id
        )

        offsets: list[int] = []
        token_logprobs: list[float | None] = []
        tokens: list[str] = []
        top_entries: list[dict[str, float] | None] = []

        # Offset of the next token: starts at the caller-provided offset and
        # grows by the length of each emitted token string.
        next_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            candidates = top_logprobs[position]

            if candidates is not None:
                sampled = candidates[token_id]
                token = self._get_decoded_token(
                    sampled,
                    token_id,
                    tokenizer,
                    return_as_token_id=use_token_id_repr,
                )
                tokens.append(token)
                # Convert float("-inf") to the JSON-serializable float that
                # OpenAI uses.
                token_logprobs.append(max(sampled.logprob, -9999.0))

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                ranked: dict[str, float] = {}
                for rank, (candidate_id, candidate) in enumerate(candidates.items()):
                    if rank > num_output_top_logprobs:
                        break
                    decoded = self._get_decoded_token(
                        candidate,
                        candidate_id,
                        tokenizer,
                        return_as_token_id=use_token_id_repr,
                    )
                    ranked[decoded] = max(candidate.logprob, -9999.0)
                top_entries.append(ranked)
            else:
                # No logprobs for this position: report the token alone.
                if use_token_id_repr:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )
                    token = tokenizer.decode(token_id)

                tokens.append(token)
                token_logprobs.append(None)
                top_entries.append(None)

            offsets.append(next_offset)
            next_offset += len(token)

        return CompletionLogProbs(
            text_offset=offsets,
            token_logprobs=token_logprobs,
            tokens=tokens,
            top_logprobs=top_entries,
        )
718
719
720
721

    def _build_render_config(
        self,
        request: CompletionRequest,
        max_input_length: int | None = None,
    ) -> RenderConfig:
        """Derive the prompt-rendering configuration for a request.

        The prompt may occupy whatever remains of the model's context length
        after reserving ``request.max_tokens`` for generation.

        Args:
            request: The completion request being served.
            max_input_length: Accepted but not used by this implementation.

        Returns:
            The ``RenderConfig`` to pass to the renderer.

        Raises:
            VLLMValidationError: If ``request.max_tokens`` exceeds the
                model's maximum context length.
        """
        requested_tokens = request.max_tokens
        context_limit = self.max_model_len

        # Validate max_tokens before using it
        too_large = requested_tokens is not None and requested_tokens > context_limit
        if too_large:
            raise VLLMValidationError(
                f"'max_tokens' ({request.max_tokens}) cannot be greater than "
                f"the model's maximum context length ({self.max_model_len}).",
                parameter="max_tokens",
                value=request.max_tokens,
            )

        # A missing (or zero) max_tokens reserves nothing for generation.
        reserved_for_output = requested_tokens or 0
        # Prompt text is only needed back when echoing as text rather than
        # as token IDs.
        wants_prompt_text = bool(request.echo and not request.return_token_ids)

        return RenderConfig(
            max_length=context_limit - reserved_for_output,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
            needs_detokenization=wants_prompt_text,
        )