serving_completion.py 28.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
7
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
8
from typing import cast
9

10
import jinja2
11
from fastapi import Request
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.logger import RequestLogger
15
16
17
18
19
20
21
22
23
24
25
26
27
from vllm.entrypoints.openai.protocol import (
    CompletionLogProbs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
28
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
29
from vllm.entrypoints.renderer import RenderConfig
30
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
31
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
32
from vllm.logger import init_logger
33
from vllm.logprobs import Logprob
34
from vllm.outputs import RequestOutput
35
from vllm.sampling_params import BeamSearchParams, SamplingParams
36
from vllm.transformers_utils.tokenizer import AnyTokenizer
37
from vllm.utils import as_list
38
from vllm.utils.asyncio import merge_async_iterators
39
40
41
42
43

# Module-level logger, obtained from vLLM's `init_logger` for this module name.
logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
44
45
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Set up the completion-serving handler.

        The engine client, model registry, request logger,
        ``return_tokens_as_token_ids`` and ``log_error_stack`` are forwarded
        unchanged to the base-class initializer. The two ``enable_*`` flags
        are stored on the instance, and the default sampling parameters are
        read from ``self.model_config.get_diff_sampling_param()``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.enable_force_include_usage = enable_force_include_usage

        # Nothing to announce when no default sampling params are configured.
        if not self.default_sampling_params:
            return

        # An "auto" generation config is reported as coming from the model.
        params_origin = self.model_config.generation_config
        if params_origin == "auto":
            params_origin = "model"
        logger.info(
            "Using default completion sampling params from %s: %s",
            params_origin,
            self.default_sampling_params,
        )
73

74
75
76
    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Request | None = None,
    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        The request is validated, its prompt(s) rendered into engine prompts,
        and one generator is scheduled per prompt. The merged results are
        then either streamed (server-sent-event strings) or collected into a
        single ``CompletionResponse``. Validation and preprocessing failures
        are returned as ``ErrorResponse`` objects rather than raised.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Raises:
            The engine client's ``dead_error`` when the engine has errored.
            NotImplementedError: if a rendered prompt carries neither token
                ids nor embeddings.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        uses_prompt_embeds = request.prompt_embeds is not None
        if request.echo and uses_prompt_embeds:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and uses_prompt_embeds:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        base_request_id = self._base_request_id(raw_request, request.request_id)
        request_id = f"cmpl-{base_request_id}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            # No tokenizer is fetched when tokenizer init is skipped.
            tokenizer = (
                None
                if self.model_config.skip_tokenizer_init
                else await self.engine_client.get_tokenizer()
            )
            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
            # All preprocessing failures are logged and reported the same way.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for prompt_idx, engine_prompt in enumerate(engine_prompts):
                prompt_text, prompt_token_ids, prompt_embeds = (
                    self._get_prompt_components(engine_prompt)
                )

                # The input length comes from whichever representation exists.
                if prompt_token_ids is not None:
                    input_length = len(prompt_token_ids)
                elif prompt_embeds is not None:
                    input_length = len(prompt_embeds)
                else:
                    raise NotImplementedError

                if self.default_sampling_params is None:
                    self.default_sampling_params = {}

                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                sampling_params: SamplingParams | BeamSearchParams
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params
                    )
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )

                # Each prompt gets its own id derived from the request id.
                sub_request_id = f"{request_id}-{prompt_idx}"

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                if raw_request is None:
                    trace_headers = None
                else:
                    trace_headers = await self._get_trace_headers(
                        raw_request.headers
                    )

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from above)
                # but pre-commit in CI fails without it.
                engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                    )
                else:
                    engine_request, tokenization_kwargs = await self._process_inputs(
                        sub_request_id,
                        engine_prompt,
                        sampling_params,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                    generator = self.engine_client.generate(
                        engine_request,
                        sampling_params,
                        sub_request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                        prompt_text=prompt_text,
                        tokenization_kwargs=tokenization_kwargs,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Noting that best_of is only supported in V0. In addition,
        # we do not stream the results when use beam search.
        best_of_permits_stream = (
            request.best_of is None or request.n == request.best_of
        )
        stream = (
            request.stream and best_of_permits_stream and not request.use_beam_search
        )

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
            )

        # Non-streaming response
        final_res_batch: list[RequestOutput | None] = [None] * num_prompts
        try:
            # Keep only the latest output seen for each prompt index.
            async for prompt_idx, res in result_generator:
                final_res_batch[prompt_idx] = res

            for prompt_idx, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is not None:
                    continue
                engine_prompt = engine_prompts[prompt_idx]
                if is_embeds_prompt(engine_prompt):
                    final_res.prompt = None
                else:
                    final_res.prompt = engine_prompt.get("prompt")

            final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if not request.stream:
            return response

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.model_dump_json()

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return fake_stream_generator()
311
312
313
314

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        engine_prompts: list[TokensPrompt | EmbedsPrompt],
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Stream completion results as server-sent-event strings.

        Each engine output is converted into one ``CompletionStreamResponse``
        chunk per choice and yielded as ``"data: <json>\\n\\n"``. After the
        last output an optional usage-only chunk is emitted, the aggregate
        usage is stored on ``request_metadata.final_usage_info``, and the
        stream always ends with ``"data: [DONE]\\n\\n"``.

        Args:
            request: The completion request being served.
            engine_prompts: Rendered prompts, indexed by prompt index; used to
                recover the prompt text when an output does not carry it.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: Id placed on every chunk.
            created_time: Creation timestamp placed on every chunk.
            model_name: Model name placed on every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer handed to the logprobs formatter.
            request_metadata: Receives the final aggregate usage.

        Yields:
            SSE-formatted strings. Any exception raised while streaming is
            logged and reported to the client as an error chunk.
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice state is laid out flat: choice j of prompt p lives at
        # index ``j + p * num_choices``.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        include_usage, include_continuous_usage = should_include_usage(
            stream_options, self.enable_force_include_usage
        )

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                # The cached-token count is taken from the first output only.
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                prompt_text = res.prompt
                if prompt_text is None:
                    engine_prompt = engine_prompts[prompt_idx]
                    prompt_text = (
                        None
                        if is_embeds_prompt(engine_prompt)
                        else engine_prompt.get("prompt")
                    )

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    # Only populated when request.return_token_ids is True.
                    # Returning prompt token IDs shares the same logic
                    # with the echo implementation.
                    prompt_token_ids_to_return: list[int] | None = None

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        if request.return_token_ids:
                            prompt_text = ""
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids,
                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
                                *(output.logprobs or []),
                            ]
                        # Match the non-streaming response: prompt token IDs
                        # are only exposed when the client asked for them.
                        if request.return_token_ids:
                            prompt_token_ids_to_return = prompt_token_ids
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        # has_echoed[i] is reused here to indicate whether
                        # we have already returned the prompt token IDs.
                        if request.return_token_ids and not has_echoed[i]:
                            prompt_token_ids_to_return = prompt_token_ids
                            has_echoed[i] = True

                        if (
                            not delta_text
                            and not delta_token_ids
                            and not previous_num_tokens[i]
                        ):
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, "Did not output logprobs"
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    # Advance by the text actually emitted in this chunk so
                    # that, after an echoed prompt, later text offsets keep
                    # counting from the end of the echoed text.
                    previous_text_lens[i] += len(delta_text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                                prompt_token_ids=prompt_token_ids_to_return,
                                token_ids=(
                                    as_list(output.token_ids)
                                    if request.return_token_ids
                                    else None
                                ),
                            )
                        ],
                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens,
            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens
                )

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # The client already received a success status, so the failure
            # is reported in-band; log it so the server keeps a traceback.
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Assemble the non-streaming ``CompletionResponse``.

        Every output of every request result becomes one choice, numbered in
        encounter order. With ``request.echo`` the prompt text, token ids and
        logprobs are prepended to each choice (or returned alone when
        ``request.max_tokens == 0``). Token usage is summed over all results
        and also stored on ``request_metadata.final_usage_info``.

        Args:
            final_res_batch: Final engine output for each prompt.
            request: The completion request being answered.
            request_id: Id placed on the response.
            created_time: Creation timestamp placed on the response.
            model_name: Model name placed on the response.
            tokenizer: Tokenizer handed to the logprobs formatter.
            request_metadata: Receives the aggregate usage.

        Returns:
            The populated ``CompletionResponse``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: GenericSequence[dict[int, Logprob] | None] | None

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    # When token ids are returned the prompt text is not
                    # echoed; an empty string stands in for it.
                    if request.return_token_ids:
                        prompt_text = ""
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Only the prompt is returned.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    # Use the clamped value computed above rather than the
                    # raw attribute, so the serialized response never depends
                    # on whether the clamp helper mutates in place.
                    prompt_logprobs=prompt_logprobs,
                    prompt_token_ids=(
                        prompt_token_ids if request.return_token_ids else None
                    ),
                    token_ids=(
                        as_list(output.token_ids) if request.return_token_ids else None
                    ),
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        # Cached-token details are reported from the last result only.
        last_final_res = final_res_batch[-1] if final_res_batch else None
        if (
            self.enable_prompt_tokens_details
            and last_final_res
            and last_final_res.num_cached_tokens
        ):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=last_final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        # KV-transfer parameters are taken from the first result, if any.
        kv_transfer_params = (
            final_res_batch[0].kv_transfer_params if final_res_batch else None
        )

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params,
        )
611
612
613
614

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
        return_as_token_id: bool | None = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: For each position, a mapping from candidate token
                id to its ``Logprob``, or ``None`` when no logprobs exist for
                that position. Indexed in parallel with ``token_ids``.
            num_output_top_logprobs: Number of top alternatives requested; up
                to ``num_output_top_logprobs + 1`` entries are emitted per
                position.
            tokenizer: Used to decode tokens that have no logprob entry.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: Overrides ``self.return_tokens_as_token_ids``
                when not ``None``.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of text offsets,
            token logprobs, token strings and top-logprob mappings.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[float | None] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[dict[str, float] | None] = []

        should_return_as_token_id = (
            return_as_token_id
            if return_as_token_id is not None
            else self.return_tokens_as_token_ids
        )

        # Offset of the next token: the initial offset plus the lengths of
        # all token strings emitted so far.
        next_text_offset = initial_text_offset

        for position, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[position]
            if step_top_logprobs is None:
                # Only decode when the decoded text is actually needed; the
                # token-id form does not require the tokenizer at all.
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append(
                    {
                        self._get_decoded_token(
                            candidate,
                            candidate_id,
                            tokenizer,
                            return_as_token_id=should_return_as_token_id,
                        ): max(candidate.logprob, -9999.0)
                        for rank, (candidate_id, candidate) in enumerate(
                            step_top_logprobs.items()
                        )
                        if rank <= num_output_top_logprobs
                    }
                )

            out_text_offset.append(next_text_offset)
            next_text_offset += len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )
689
690
691
692

    def _build_render_config(
        self,
        request: CompletionRequest,
        max_input_length: int | None = None,
    ) -> RenderConfig:
        """Build the ``RenderConfig`` used to render this request's prompts.

        The maximum prompt length is the model's maximum length minus the
        requested ``max_tokens`` (treated as 0 when unset). Detokenization is
        requested only when the prompt is echoed as text, i.e. ``echo`` is
        set and ``return_token_ids`` is not.

        NOTE(review): ``max_input_length`` is accepted but not read by this
        method.
        """
        reserved_for_output = request.max_tokens or 0
        max_prompt_len = self.max_model_len - reserved_for_output
        echo_as_text = bool(request.echo and not request.return_token_ids)
        return RenderConfig(
            max_length=max_prompt_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
            cache_salt=request.cache_salt,
            needs_detokenization=echo_as_text,
        )