serving_engine.py 49.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
9
from concurrent.futures import ThreadPoolExecutor
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
12

13
import torch
14
from fastapi import Request
15
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
16
from starlette.datastructures import Headers
17
18
from typing_extensions import TypeIs

19
20
21
22
23
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

24
25
26
27
from openai.types.responses import (
    ToolChoiceFunction,
)

28
import vllm.envs as envs
29
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
30
from vllm.engine.protocol import EngineClient
31
32
33
34
35
36
37
38
39
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
40
from vllm.entrypoints.context import ConversationContext
41
from vllm.entrypoints.logger import RequestLogger
42
from vllm.entrypoints.openai.protocol import (
43
    ChatCompletionNamedToolChoiceParam,
44
45
46
47
48
49
50
51
52
53
54
55
56
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
57
58
    FunctionCall,
    FunctionDefinition,
59
60
61
62
63
64
65
66
67
68
69
70
71
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
    ResponsesRequest,
    ScoreRequest,
    ScoreResponse,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
72
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
73
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
74
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
75
from vllm.entrypoints.utils import _validate_truncation_size
76
from vllm.inputs.data import PromptType
77
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
78
79
80
81
82
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
83
from vllm.logger import init_logger
84
from vllm.logprobs import Logprob, PromptLogprobs
85
from vllm.lora.request import LoRARequest
86
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
87
88
89
    MultiModalDataDict,
    MultiModalUUIDDict,
)
90
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
91
from vllm.pooling_params import PoolingParams
92
from vllm.reasoning import ReasoningParser, ReasoningParserManager
93
from vllm.sampling_params import BeamSearchParams, SamplingParams
94
95
96
97
98
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
99
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
100
from vllm.utils import random_uuid
101
from vllm.utils.async_utils import (
102
    AsyncMicrobatchTokenizer,
103
    collect_from_async_generator,
104
    make_async,
105
106
    merge_async_iterators,
)
107
from vllm.utils.collection_utils import is_list_of
108
from vllm.v1.engine import EngineCoreRequest
109
110
111

logger = init_logger(__name__)

# Request types that carry a completion-style prompt.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Chat-style request types.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
)

# Audio transcription / translation requests.
SpeechToTextRequest: TypeAlias = (
    TranscriptionRequest
    | TranslationRequest
)

# Union of every request type handled by the serving classes in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
)

# Union of the non-error response types an endpoint may build.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
)
144

145
146
147

class TextTokensPrompt(TypedDict):
    prompt: str
148
    prompt_token_ids: list[int]
149
150


151
152
153
154
class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor


155
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
156
157
158


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Tell whether *prompt* is a token-ID dict (not an embeddings dict).

    A match is a ``dict`` with a ``prompt_token_ids`` key and no
    ``prompt_embeds`` key; strings and token lists never match.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds
164
165
166


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Tell whether *prompt* is an embeddings dict (not a token-ID dict).

    A match is a ``dict`` with a ``prompt_embeds`` key and no
    ``prompt_token_ids`` key; strings and token lists never match.
    """
    if not isinstance(prompt, dict):
        return False
    has_embeds = "prompt_embeds" in prompt
    has_token_ids = "prompt_token_ids" in prompt
    return has_embeds and not has_token_ids
172

173

174
175
176
177
178
# Request type a `ServeContext` is parametrized with.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Per-prompt request-side representation (see `RequestPrompt`).
    # `default_factory` gives every instance its own list explicitly, matching
    # `ResponseGenerationMixin.final_res_batch`, instead of relying on Pydantic
    # copying a shared `[]` literal default.
    request_prompts: Sequence[RequestPrompt] | None = Field(default_factory=list)
    # Prompts handed to the engine; `_prepare_generators` submits one request
    # per entry and treats None as "not available".
    engine_prompts: list[EngineTokensPrompt] | None = Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """Mixin holding the engine output stream and the results drained from it.

    ``result_generator`` is populated when requests are scheduled, and
    ``final_res_batch`` receives the collected outputs.
    """

    # Stream of (prompt index, output) pairs; None until requests are scheduled.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Outputs collected from `result_generator`, ordered by prompt index.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
        default_factory=list
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)


205
class ServeContext(
    RequestProcessingMixin,
    ResponseGenerationMixin,
    BaseModel,
    Generic[RequestT],
):
    """Per-request state passed through the `OpenAIServing` pipeline stages.

    Combines the prompt-processing and response-generation mixins with the
    request itself and the metadata shared by every stage.
    """

    # --- Shared across all requests ---
    request: RequestT
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Whole seconds since the epoch, captured when the context is created.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # --- Shared across most requests ---
    tokenizer: AnyTokenizer | None = None

    # `protected_namespaces=()` silences Pydantic v2's warning about fields
    # (such as `model_name`) that clash with the protected "model_" namespace.
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# ServeContext specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, plus chat-template settings."""

    # Optional chat template; defaults to None.
    chat_template: str | None = None
    # Required: the content format option applied to chat-template messages.
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the models so Pydantic can resolve the forward reference to
# MultiModalDataDict inside TokensPrompt.
for _model_cls in (
    RequestProcessingMixin,
    ServeContext,
    ClassificationServeContext,
    EmbeddingServeContext,
):
    _model_cls.model_rebuild()
del _model_cls

245

246
class OpenAIServing:
247
248
249
250
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
251

252
253
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Set up the state shared by every OpenAI-compatible endpoint.

        Args:
            engine_client: Client used to submit generate/encode requests.
            models: Served-model registry; its ``processor``, ``io_processor``
                and ``model_config`` are cached on ``self``.
            request_logger: Optional request logger, stored as-is.
            return_tokens_as_token_ids: Flag stored as-is on the instance.
            log_error_stack: When true, ``create_error_response`` prints a
                traceback (or the current stack) for every error it builds.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Single-worker executor; `make_async` wraps the Mistral chat-template
        # helper so it runs on this executor.
        tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_executor
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=tokenizer_executor
        )

        # AsyncMicrobatchTokenizer cache keyed by tokenizer (filled lazily).
        self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        self.processor = models.processor
        self.io_processor = models.io_processor
        model_config = models.model_config
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[AnyTokenizer], ToolParser] | None:
        """Look up the tool-parser class used for automatic tool choice.

        Returns None unless ``enable_auto_tools`` is set and a
        ``tool_parser_name`` is given. Otherwise logs that auto tool choice is
        enabled and returns the parser registered under that name.

        Raises:
            TypeError: if any exception occurs during the lookup (for example,
                the name is not registered).
        """
        if tool_parser_name is None or not enable_auto_tools:
            return None

        logger.info(
            '"auto" tool choice has been enabled please note that while'
            " the parallel_tool_calls client option is preset for "
            "compatibility reasons, it will be ignored."
        )

        try:
            if tool_parser_name == "pythonic":
                # Warn about a known weak spot of Llama 3.2 checkpoints.
                if self.model_config.model.startswith("meta-llama/Llama-3.2"):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic tool calls"
                    )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[AnyTokenizer], ReasoningParser] | None:
        """Get the reasoning parser based on the name.

        Returns:
            The parser registered as ``reasoning_parser_name``, or None when
            the name is empty.

        Raises:
            TypeError: if the lookup raises or returns None.
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a None parser slip through.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

    async def reset_mm_cache(self) -> None:
        """Clear multimodal caches: the local processor's first, then the engine's."""
        local_processor = self.processor
        local_processor.clear_mm_cache()
        engine = self.engine_client
        await engine.reset_mm_cache()

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly issuing one-token generate requests.

        Each step submits every live beam to the engine with ``max_tokens=1``
        and ``logprobs=2 * beam_width``. Each beam is expanded with every
        returned candidate token, and the ``beam_width`` best beams (ranked by
        ``create_sort_beams_key_function``) are kept. Unless ``ignore_eos`` is
        set, a candidate equal to the EOS token finishes its beam with
        ``finish_reason="stop"``.

        Yields:
            One finished ``RequestOutput`` containing the best ``beam_width``
            sequences.

        Raises:
            ValueError: if the processor has no tokenizer.
            NotImplementedError: for explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        processor = self.processor
        tokenizer = processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init` is True"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a bare string prompt is not tokenized here, so the
            # search starts from an empty token list — confirm callers always
            # pass a tokenized prompt.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            # With no live beams left (every candidate finished, or no
            # logprobs came back), `zip(*[])` below yields nothing and the
            # two-name unpacking would raise ValueError.
            if not all_beams:
                break

            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        EngineTokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                        )
                    )
                )
                tasks.append(task)

            try:
                gathered = await asyncio.gather(*tasks)
            except BaseException:
                # gather() does not cancel the remaining awaitables when one
                # of them raises; cancel them instead of leaving them running
                # after this search has failed.
                for task in tasks:
                    task.cancel()
                raise
            output = [x[0] for x in gathered]

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        if token_id == eos_token_id and not ignore_eos:
                            completed.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id]
                                    if include_stop_str_in_output
                                    else current_beam.tokens,
                                    logprobs=current_beam.logprobs + [logprobs],
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    finish_reason="stop",
                                    stop_reason=eos_token_id,
                                )
                            )
                        else:
                            new_beams.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )
                            )

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            # `beam.tokens` can be empty (empty prompt and no generated token),
            # so check it before indexing the last element.
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
        """Build a `CompletionRenderer` around *tokenizer*.

        The renderer is handed this instance's ``_async_tokenizer_pool`` (the
        same dict `_get_async_tokenizer` fills).
        """
        shared_pool = self._async_tokenizer_pool
        config = self.model_config
        return CompletionRenderer(
            model_config=config,
            tokenizer=tokenizer,
            async_tokenizer_pool=shared_pool,
        )

    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the `RenderConfig` the renderer should use for *request*.

        Abstract hook: endpoints override it to describe how their prompts are
        tokenized and length-limited.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Return the `AsyncMicrobatchTokenizer` for *tokenizer*, creating it once.

        Instances are memoized in ``self._async_tokenizer_pool``, keyed by the
        tokenizer object.
        """
        pool = self._async_tokenizer_pool
        if tokenizer not in pool:
            pool[tokenizer] = AsyncMicrobatchTokenizer(tokenizer)
        return pool[tokenizer]

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook that prepares *ctx* before requests are scheduled.

        The base implementation leaves *ctx* untouched and reports no error.
        Endpoint subclasses (classification, embedding, ...) override it,
        e.g. to fill in the prompts that `_prepare_generators` consumes.
        """
        no_error: ErrorResponse | None = None
        return no_error

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn the finished *ctx* into the endpoint's response object.

        Subclasses override this. The base class reports the endpoint as
        unimplemented through `create_error_response`.
        """
        unimplemented_message = "unimplemented endpoint"
        return self.create_error_response(unimplemented_message)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for *ctx* and return its first response.

        Only the first item yielded by `_pipeline` is used. The generator is
        then closed explicitly with ``aclose()``, so it is finalized here
        rather than left suspended until garbage collection.

        Returns:
            The first yielded response, or an error response if the pipeline
            yielded nothing.
        """
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Returning from inside `async for` leaves the generator
            # suspended at its `yield`; close it deterministically.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        The stages run in order:
        1. `_check_model`
        2. `_validate_request`
        3. `_preprocess`
        4. `_prepare_generators`
        5. `_collect_batch`
        6. `_build_response`

        The first stage that reports an `ErrorResponse` has it yielded, and the
        pipeline stops there. Exactly one response is yielded per run.
        """
        # Each error branch returns: continuing after a failed stage would run
        # preprocessing/scheduling on an invalid request and yield more than
        # one response to any caller that keeps iterating.
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject requests whose ``truncate_prompt_tokens`` exceeds the model limit.

        Returns an error response when the request's ``truncate_prompt_tokens``
        (if present and not None) is greater than ``self.max_model_len``;
        otherwise None.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        exceeds_model_len = limit is not None and limit > self.max_model_len
        if not exceeds_model_len:
            return None
        return self.create_error_response(
            "truncate_prompt_tokens value is greater than max_model_len."
            " Please, select a smaller truncation size."
        )

    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive `PoolingParams` from the request held in *ctx*.

        Returns the result of the request's own ``to_pooling_params()``, or an
        error response when the request type has no such method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()
        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        For each entry of ``ctx.engine_prompts``:
        - log its inputs;
        - submit one ``engine_client.encode`` call with request id
          ``f"{ctx.request_id}-{i}"``.

        The streams are merged into ``ctx.result_generator``.

        Returns:
            None on success. Otherwise an error response: from pooling-param
            creation, for missing engine prompts, or built from any exception
            raised while scheduling.
        """
        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            streams: list[
                AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
            ] = []
            for index, engine_prompt in enumerate(engine_prompts):
                item_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains ``ctx.result_generator`` (pairs of prompt index and output) into
        ``ctx.final_res_batch``, ordered by prompt index.

        Returns:
            None on success, or an error response in any of these cases:
            - ``ctx.engine_prompts`` or ``ctx.result_generator`` is missing;
            - some prompt produced no output;
            - draining the generator raised.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity-based filter: `None in final_res_batch` falls back to
            # `==`, invoking each result object's `__eq__`. One pass both
            # detects missing slots and builds the final list.
            collected = [res for res in final_res_batch if res is not None]
            if len(collected) != num_prompts:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = collected
            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` carrying *message*, *err_type* and the status code.

        When ``log_error_stack`` is enabled, it also prints diagnostics:
        - the active exception's traceback, if called while one is being
          handled;
        - otherwise, the current call stack.
        """
        if self.log_error_stack:
            handling_exception = sys.exc_info()[0] is not None
            if handling_exception:
                traceback.print_exc()
            else:
                traceback.print_stack()
        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)

    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return an error response serialized as a JSON string.

        The response is built with `create_error_response`, so its stack
        logging applies, and its ``model_dump()`` is JSON-encoded.
        """
        error = self.create_error_response(
            message=message, err_type=err_type, status_code=status_code
        )
        payload = error.model_dump()
        return json.dumps(payload)

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that ``request.model`` names something this server can serve.

        Accepted names:
        - a supported base model;
        - an already-registered LoRA adapter;
        - when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is set, an adapter that
          ``models.resolve_lora`` resolves to a ``LoRARequest``.

        Returns:
            None when the model is usable. Otherwise, the resolver's own error
            if it carries HTTP 400, else a 404 ``NotFoundError`` response.
        """
        model_name = request.model
        if self._is_model_supported(model_name) or (
            model_name in self.models.lora_requests
        ):
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                if (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                ):
                    resolver_error = load_result

        return resolver_error or self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras.

        A registered LoRA is a candidate when its ``lora_name`` equals one of
        the content-type prefixes found by `_get_message_types` (e.g.
        ``audio_url`` -> ``audio``). This is a best-effort match, and a
        candidate is returned only when it is the single one.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        candidates = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Only apply a default modality LoRA when the match is unambiguous.
        if len(candidates) != 1:
            return None
        (only_match,) = candidates
        return only_match

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter (if any) to apply for *request*.

        Candidates are checked in order:
        1. The adapter registered under ``request.model``.
        2. If ``supports_default_mm_loras``, a single default multimodal LoRA
           matched from the message content.
        3. No adapter (None), when ``request.model`` is a supported base model.

        Raises:
            ValueError: if ``request.model`` matches none of the above.
        """
        model_name = request.model
        lora_requests = self.models.lora_requests
        if model_name in lora_requests:
            return lora_requests[model_name]

        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(model_name):
            # Unreachable when _check_model has already validated the request.
            raise ValueError(f"The model `{model_name}` does not exist.")
        return None

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-part type prefixes used in the request's messages.

        For every dict message whose ``content`` is a list, each part carrying
        a ``"type"`` contributes the text before its first ``_`` (so
        ``image_url`` -> ``image``). Requests without ``messages`` yield an
        empty set. The result is used to match default per-modality LoRAs.
        """
        prefixes: set[str] = set()
        if not hasattr(request, "messages"):
            return prefixes

        for message in request.messages:
            if not isinstance(message, dict):
                continue
            parts = message["content"] if "content" in message else None
            if not isinstance(parts, list):
                continue
            prefixes.update(
                part["type"].split("_")[0] for part in parts if "type" in part
            )
        return prefixes

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate it for *request*.

        Steps:
        1. Lower-case the prompt when the encoder config sets
           ``do_lower_case``.
        2. Tokenize, honouring the request's ``truncate_prompt_tokens``
           (None means no truncation; a negative value caps at
           ``max_model_len``).
        3. Pass the token IDs and the (possibly lower-cased) text to
           `_validate_input`.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(
                prompt, add_special_tokens=add_special_tokens
            )
        else:
            # Negative means we cap at the model's max length.
            max_length = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_length,
            )

        return self._validate_input(request, encoded.input_ids, prompt)

    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: AnyTokenizer | None,
    ) -> TextTokensPrompt:
        """Left-truncate pre-tokenized input and validate it.

        ``request.truncate_prompt_tokens`` controls truncation:
          * ``None``: keep every token;
          * negative: keep the last ``max_model_len`` tokens;
          * ``k >= 0``: keep the last ``k`` tokens (``k == 0`` keeps none).

        The prompt text in the result is the decoded token ids, or ``""``
        when no tokenizer is available.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len :]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` is the *whole* list, not the last 0 items.
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check that the tokenized input fits the model's context window.

        * Embedding, score, rerank and classification requests generate no
          tokens, so their input may use the full ``max_model_len``.
        * Tokenize/detokenize requests are not length-checked.
        * All other requests must have fewer than ``max_model_len`` input
          tokens, and ``input tokens + max_tokens`` must not exceed
          ``max_model_len``.

        Raises:
            ValueError: if one of the limits above is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest, ScoreRequest and
        # RerankRequest don't have max_tokens.
        if isinstance(
            request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreRequest,
                RerankRequest,
                ClassificationRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > self.max_model_len:
                # Every non-embedding type accepted above needs its own label;
                # previously RerankRequest fell through to "embedding generation".
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    RerankRequest: "rerank",
                    ClassificationRequest: "classification",
                }
                operation = operations.get(type(request), "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: str | list[int],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """Tokenize and validate one prompt (text or token ids).

        Convenience wrapper that feeds a one-element list to
        ``_tokenize_prompt_inputs_async`` and returns its first result.

        Raises:
            ValueError: if the batch generator yields nothing.
        """
        batch = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for first_prompt in batch:
            return first_prompt
        raise ValueError("No results yielded from tokenization")

    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """Yield one validated ``TextTokensPrompt`` per entry of ``prompt_inputs``.

        Strings go through text tokenization (honoring ``add_special_tokens``).
        Any other entry is treated as a list of token ids.
        """
        for item in prompt_inputs:
            if not isinstance(item, str):
                normalized = await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=item,
                    tokenizer=tokenizer,
                )
            else:
                normalized = await self._normalize_prompt_text_to_input(
                    request,
                    prompt=item,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            yield normalized

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Refuse request-supplied chat templates unless they are trusted.

        A template is request-supplied if it is passed directly or through
        ``chat_template_kwargs["chat_template"]``. Returns an error response
        in that case when ``trust_request_chat_template`` is False, and
        ``None`` otherwise.
        """
        if trust_request_chat_template:
            return None

        kwargs_template = (chat_template_kwargs or {}).get("chat_template")
        if request_chat_template is None and kwargs_template is None:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat ``messages`` into prompts ready for the engine.

        Steps:
          1. Parse the messages; multimodal data arrives through a future.
          2. Apply the chat template. The Mistral path renders from the raw
             ``messages`` and the HF path from the parsed conversation. With
             no tokenizer, the prompt is the string ``"placeholder"``.
          3. If a tool parser is set and ``tool_choice`` is not ``"none"``,
             let the parser adjust the request.
          4. Tokenize/validate the prompt. Then attach multimodal data,
             multimodal UUIDs, ``mm_processor_kwargs`` and ``cache_salt``
             when present.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: if tool parsing applies to a request that
                is neither a Chat Completions nor a Responses request.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        # Caller-supplied chat_template_kwargs override these defaults.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none". (If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool call
        # hallucinated by the LLM.)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # A single string message (the original had a stray trailing
            # comma that turned it into a tuple).
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string when the tokenizer is not initialised"
            )
            # Placeholder token ids: nothing can be tokenized without a
            # tokenizer.
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Build an ``EngineCoreRequest`` for AsyncLLM using the Processor.

        A fresh kwargs dict is passed to ``_validate_truncation_size``
        together with ``params.truncate_prompt_tokens``; presumably that
        helper validates the value and fills in the dict (confirm). The
        dict is then forwarded as ``tokenization_kwargs`` and returned
        alongside the request.
        """
        kwargs_for_tokenizer: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, kwargs_for_tokenizer
        )

        core_request = self.processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=kwargs_for_tokenizer,
            trace_headers=trace_headers,
            priority=priority,
        )
        return core_request, kwargs_for_tokenizer

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate, running built-in tool calls until the model is done.

        Each turn submits the current prompt to the engine and yields
        ``context`` after every output is appended to it. If the context
        then asks for a built-in tool call, the tool runs and its output is
        appended. The next prompt is rendered from the context and the loop
        continues.

        Note: ``sampling_params.max_tokens`` is mutated in place for
        follow-up turns.
        """
        orig_priority = priority
        while True:
            # Recompute every turn so the prompt text handed to the engine
            # matches the prompt actually submitted. Follow-up turns are
            # token-id prompts; previously the first turn's text was reused.
            prompt_text, _, _ = self._get_prompt_components(request_prompt)

            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            trace_headers = kwargs.get("trace_headers")
            engine_request, tokenization_kwargs = await self._process_inputs(
                request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn by rendering the conversation
            # so far into prompt token ids.
            prompt_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
            request_prompt = prompt_token_ids
            # The next turn may use whatever context length remains.
            sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)

            # OPTIMIZATION: follow-up turns use orig_priority - 1; presumably
            # lower values are scheduled earlier (confirm with the scheduler).
            priority = orig_priority - 1

    def _get_prompt_components(
        self,
        prompt: RequestPrompt | PromptType,
    ) -> PromptComponents:
        """Split ``prompt`` into (text, token ids, embeddings) components.

        A bare list is taken to be token ids. Every other prompt shape is
        delegated to ``get_prompt_components``.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]
        return PromptComponents(token_ids=prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt | PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Send the request's prompt components to the request logger.

        Does nothing when no request logger is configured.
        """
        request_logger = self.request_logger
        if request_logger is not None:
            text, token_ids, embeds = self._get_prompt_components(inputs)
            request_logger.log_inputs(
                request_id,
                text,
                token_ids,
                embeds,
                params=params,
                lora_request=lora_request,
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to propagate, or ``None``.

        When tracing is enabled in the engine, the trace headers are
        extracted from ``headers``. Otherwise ``None`` is returned, with a
        warning logged if the request carried trace headers anyway.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Return the ``X-Request-Id`` header value, if the request has one.

        Otherwise returns ``default``, or a fresh ``random_uuid()`` when
        ``default`` is falsy. The fallback is computed in either case.
        """
        fallback = default if default else random_uuid()
        if raw_request is not None:
            return raw_request.headers.get("X-Request-Id", fallback)
        return fallback

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the data parallel rank from the ``X-data-parallel-rank`` header.

        Returns ``None`` when there is no request, no such header, or the
        header value is not an integer.
        """
        rank_header = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if rank_header is None:
            return None

        try:
            return int(rank_header)
        except ValueError:
            return None

    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: AnyTokenizer,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content``.

        Cases, checked in order:
          * Named tool choice (``ToolChoiceFunction`` or
            ``ChatCompletionNamedToolChoiceParam``): the whole content becomes
            the arguments of a call to that function.
          * ``tool_choice == "required"``: the content is parsed as a JSON
            list of ``FunctionDefinition`` objects.
          * ``tool_choice`` is ``"auto"`` or ``None``, auto tools are enabled
            and a parser class is given: the parser extracts the calls and
            the remaining content.

        Returns:
            ``(function_calls, content)``. ``content`` is ``None`` after a
            named or required call consumed it. ``function_calls`` is ``None``
            only when the auto parser found no tool calls. Otherwise it is a
            list, which is empty when none of the cases above applied.

        Raises:
            RuntimeError: re-raised (after logging) if the tool parser cannot
                be constructed.
        """
        function_calls: list[FunctionCall] = []
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            assert content is not None
            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            function_calls.extend(
                FunctionCall(
                    name=tool_call.name,
                    arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                )
                for tool_call in tool_calls
            )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError:
                logger.exception("Error in tool parser creation.")
                # Bare `raise` re-raises the active exception unchanged.
                raise
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
            else:
                # No tool calls.
                return None, content

        return function_calls, content

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the string shown for ``token_id`` in logprob output.

        With ``return_as_token_id`` the result is ``"token_id:<id>"``.
        Otherwise ``logprob.decoded_token`` is used when set, and the
        tokenizer decodes the id as a fallback.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached = logprob.decoded_token
        return cached if cached is not None else tokenizer.decode(token_id)

    def _is_model_supported(self, model_name: str | None) -> bool:
1399
1400
        if not model_name:
            return True
1401
        return self.models.is_base_model(model_name)
1402

1403
1404

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    Positions whose entry is ``None`` are skipped. The same (mutated) object
    is returned, and ``None`` passes through unchanged. Presumably this keeps
    the values JSON-serializable, since ``-inf`` is not valid JSON (confirm).
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is None:
            continue
        for entry in position_logprobs.values():
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0
    return prompt_logprobs