# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from itertools import islice
from typing import cast

import jinja2
from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.completion.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
    GenerationError,
    OpenAIServing,
    clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters


# Module-level logger, created via init_logger with this module's __name__.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Initialise the completion handler on top of the shared serving base.

        Args:
            engine_client: Forwarded to the base-class initialiser.
            models: Forwarded to the base-class initialiser.
            request_logger: Forwarded to the base-class initialiser.
            return_tokens_as_token_ids: Forwarded to the base-class
                initialiser.
            enable_prompt_tokens_details: Stored on the instance; when true
                (and cached tokens were reported) the usage info built by
                this class carries ``prompt_tokens_details``.
            enable_force_include_usage: Stored on the instance; handed to
                ``should_include_usage`` when streaming.
            log_error_stack: Forwarded to the base-class initialiser.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        # set up logits processors
        self.logits_processors = self.model_config.logits_processors

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.enable_force_include_usage = enable_force_include_usage
        if self.default_sampling_params:
            # Report where the non-empty defaults came from; the "auto"
            # generation config is logged as "model".
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info(
                "Using default completion sampling params from %s: %s",
                source,
                self.default_sampling_params,
            )

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying HTTP request, if any. When present,
                the request metadata is attached to ``raw_request.state`` and
                its headers are used for trace headers.

        Returns:
            An async generator of server-sent-event strings when
            ``request.stream`` is set, a ``CompletionResponse`` otherwise, or
            an ``ErrorResponse`` when validation, preprocessing or generation
            fails.

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            # All preprocessing failures are reported the same way.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        # Extract data_parallel_rank from header (router can inject it)
        data_parallel_rank = self._get_data_parallel_rank(raw_request)

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            # Normalise once up front; the value does not depend on the prompt.
            if self.default_sampling_params is None:
                self.default_sampling_params = {}

            for i, engine_prompt in enumerate(engine_prompts):
                prompt_text, prompt_token_ids, prompt_embeds = (
                    self._get_prompt_components(engine_prompt)
                )

                if prompt_token_ids is not None:
                    input_length = len(prompt_token_ids)
                elif prompt_embeds is not None:
                    input_length = len(prompt_embeds)
                else:
                    raise NotImplementedError

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )
                    validate_logits_processors_parameters(
                        self.logits_processors,
                        sampling_params,
                    )

                # One sub-request per prompt, suffixed with the prompt index.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                trace_headers = (
                    None
                    if raw_request is None
                    else await self._get_trace_headers(raw_request.headers)
                )

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        request_id_item,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        data_parallel_rank=data_parallel_rank,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                        data_parallel_rank=data_parallel_rank,
                    )

                generators.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # We do not stream the results when using beam search.
        stream = request.stream and not request.use_beam_search

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            # Each item carries its prompt index; keep the latest per prompt.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except GenerationError as e:
            return self._convert_generation_error_to_response(e)
        except ValueError as e:
            return self.create_error_response(e)

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokensPrompt | EmbedsPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Yield the completion as server-sent-event ``data:`` strings.

        One chunk is emitted per output of each item produced by
        ``result_generator``. When usage reporting is enabled a final chunk
        with empty ``choices`` carries the aggregate usage. The stream always
        ends with ``data: [DONE]``. Errors raised while streaming are turned
        into a ``data:`` event instead of propagating.

        Args:
            request: The completion request being served.
            engine_prompts: The rendered prompts, indexed by prompt index;
                used to recover the prompt text when the result lacks it.
            result_generator: Async iterator of ``(prompt_idx, result)``.
            request_id: Identifier placed in every chunk.
            created_time: Timestamp placed in every chunk.
            model_name: Model name placed in every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer passed on when building logprobs, or None.
            request_metadata: Receives the final aggregate usage info.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state, flattened as prompt_idx * num_choices + index.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # Cached-token count is taken from the first result only.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Useful when request.return_token_ids is True
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if not has_echoed[i] and request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    self._raise_if_error(finish_reason, request_id)

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        # Per-chunk usage covers this choice only.
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except GenerationError as e:
            yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
        except Exception as e:
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: TokenizerLike | None,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final per-prompt results.

        Every output of every result becomes one choice, indexed in the order
        it is appended. Usage is summed over all prompts and outputs and is
        also stored on ``request_metadata.final_usage_info``.

        Args:
            final_res_batch: The final result for each prompt.
            request: The completion request being served.
            request_id: Identifier placed in the response.
            created_time: Timestamp placed in the response.
            model_name: Model name placed in the response.
            tokenizer: Tokenizer passed on when building logprobs, or None.
            request_metadata: Receives the aggregate usage info.

        Returns:
            The assembled ``CompletionResponse``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                self._raise_if_error(output.finish_reason, request_id)

                assert request.max_tokens is not None
                if request.echo:
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Echo with nothing generated: return the prompt only.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Cached-token details come from the last result in the batch.
        last_final_res = final_res_batch[-1] if final_res_batch else None
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        # KV-transfer parameters are read from the first result, if any.
        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: TokenizerLike | None,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, one entry per position.
            top_logprobs: Per-position mapping of token id to ``Logprob``,
                or None for a position without logprobs. Indexed in step
                with ``token_ids``.
            num_output_top_logprobs: Requested number of top logprobs; up to
                ``num_output_top_logprobs + 1`` entries are kept per
                position.
            tokenizer: Used to decode tokens; may be None.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not None.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of text offsets,
            token logprobs, tokens and top logprobs.

        Raises:
            VLLMValidationError: If a position has no logprobs, tokens are
                not returned as ids, and ``tokenizer`` is None.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        last_token_len = 0

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )

        # makes sure to add the top num_output_top_logprobs + 1
        # logprobs, as defined in the openai API
        # (cf. https://github.com/openai/openai-openapi/blob/
        # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
        # Clamped at zero because islice rejects a negative stop.
        max_top_entries = max(num_output_top_logprobs + 1, 0)

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    if tokenizer is None:
                        raise VLLMValidationError(
                            "Unable to get tokenizer because "
                            "`skip_tokenizer_init=True`",
                            parameter="skip_tokenizer_init",
                            value=True,
                        )

                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                out_top_logprobs.append(
                    {
                        self._get_decoded_token(
                            top_logprob,
                            top_token_id,
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(top_logprob.logprob, -9999.0)
                        for top_token_id, top_logprob in islice(
                            step_top_logprobs.items(), max_top_entries
                        )
                    }
                )

            # Each offset is the previous offset plus the previous token's
            # string length; the first one starts at initial_text_offset.
            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )

    def _build_render_config(
        self,
        request: CompletionRequest,
        max_input_length: int | None = None,
    ) -> RenderConfig:
        """Build the prompt-rendering configuration for a completion request.

        The maximum prompt length is the model's maximum length minus the
        requested ``max_tokens`` (treated as 0 when unset).

        Args:
            request: The completion request being served.
            max_input_length: NOTE(review): accepted but not read by this
                method - confirm whether callers expect it to take effect.

        Returns:
            A ``RenderConfig`` carrying the length limit and the request's
            truncation, special-token, cache-salt and detokenization options.

        Raises:
            VLLMValidationError: If ``request.max_tokens`` exceeds the
                model's maximum length.
        """
        # Validate max_tokens before using it
        if request.max_tokens is not None and request.max_tokens > self.max_model_len:
            raise VLLMValidationError(
                f"'max_tokens' ({request.max_tokens}) cannot be greater than "
                f"the model's maximum context length ({self.max_model_len}).",
                parameter="max_tokens",
                value=request.max_tokens,
            )

        max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
        return RenderConfig(
            max_length=max_input_tokens_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
            # Prompt text is only needed when echoing it back as text.
            needs_detokenization=bool(request.echo and not request.return_token_ids),
        )