serving_engine.py 53.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
import torch
16
from fastapi import Request
17
from pydantic import ConfigDict, TypeAdapter
18
from starlette.datastructures import Headers
19
20
from typing_extensions import TypeIs

21
22
23
24
25
26
27
28
29
30
from vllm.entrypoints.context import (
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
    FunctionCall,
    ResponseInputOutputItem,
    ResponsesRequest,
)
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
52
from vllm.transformers_utils.tokenizer import AnyTokenizer
53

54
55
56
57
58
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

59
60
61
62
from openai.types.responses import (
    ToolChoiceFunction,
)

63
import vllm.envs as envs
64
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
65
from vllm.engine.protocol import EngineClient
66
67
68
69
70
71
72
73
74
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
75
from vllm.entrypoints.context import ConversationContext
76
from vllm.entrypoints.logger import RequestLogger
77
from vllm.entrypoints.openai.protocol import (
78
    ChatCompletionNamedToolChoiceParam,
79
80
81
82
83
84
85
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
86
    FunctionDefinition,
87
88
89
90
91
92
93
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
94
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
95
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
96
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
97
98
99
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
100
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
101
from vllm.entrypoints.utils import _validate_truncation_size
102
from vllm.inputs.data import PromptType
103
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
104
105
106
107
108
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
109
from vllm.logger import init_logger
110
from vllm.logprobs import Logprob, PromptLogprobs
111
from vllm.lora.request import LoRARequest
112
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
113
114
115
    MultiModalDataDict,
    MultiModalUUIDDict,
)
116
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
117
from vllm.pooling_params import PoolingParams
118
from vllm.reasoning import ReasoningParser, ReasoningParserManager
119
from vllm.sampling_params import BeamSearchParams, SamplingParams
120
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
121
122
123
124
125
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
126
from vllm.utils import random_uuid
127
from vllm.utils.async_utils import (
128
    AsyncMicrobatchTokenizer,
129
    collect_from_async_generator,
130
    make_async,
131
132
    merge_async_iterators,
)
133
from vllm.utils.collection_utils import is_list_of
134
from vllm.v1.engine import EngineCoreRequest
135
136
137

# Module-level logger, created through vLLM's `init_logger` with this
# module's name.
logger = init_logger(__name__)

138
139
140
141
142
# --- Request / response unions used in the signatures below ---------------

# Union of the completion-style request models.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Union of the chat-style request models.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)

# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request model accepted by the serving layer; `RequestT` is bound
# to this union.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response model an endpoint may return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)
175

176
177
178

class TextTokensPrompt(TypedDict):
    """Prompt dict that carries a text field and a token-id field."""

    # Prompt text.
    prompt: str
    # Token ids stored next to the text.
    prompt_token_ids: list[int]
180
181


182
183
184
185
class EmbedsPrompt(TypedDict):
    """Prompt dict that carries only a tensor of prompt embeddings."""

    # Embeddings tensor; its shape is not constrained here.
    prompt_embeds: torch.Tensor


186
# A prompt as handled before engine submission: raw token ids, plain text,
# a text+tokens dict, or an embeddings dict.
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
187
188
189


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """
    Tell whether `prompt` is a dict with token ids and without embeddings.

    Only the presence of the `prompt_token_ids` key and the absence of the
    `prompt_embeds` key are checked; values are not inspected.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds
195
196
197


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """
    Tell whether `prompt` is a dict with embeddings and without token ids.

    Only the presence of the `prompt_embeds` key and the absence of the
    `prompt_token_ids` key are checked; values are not inspected.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_token_ids
203

204

205
206
207
# Type variable for the request model a `ServeContext` is parameterized
# with; bound to the `AnyRequest` union.
RequestT = TypeVar("RequestT", bound=AnyRequest)


208
209
@dataclass(kw_only=True)
class RequestProcessingMixin:
210
    """
211
    Mixin for request processing,
212
213
    handling prompt preparation and engine input.
    """
214

215
216
    request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
    engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
217
218


219
220
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """
    Response-generation mixin.

    Holds the asynchronous stream of engine results and the list of final
    batch results.
    """

    # Async stream of `(int, output)` tuples; `None` until one is assigned.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None

    # Final outputs; a fresh empty list per instance.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # Has no annotation, so it is a plain class attribute rather than a
    # dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)


236
237
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """
    Per-request state passed between the stages of the serving pipeline.

    Combines the prompt fields of `RequestProcessingMixin` and the result
    fields of `ResponseGenerationMixin` with the request itself. All fields
    are keyword-only.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Whole seconds since the epoch, taken when the context is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Shared across most requests
    tokenizer: TokenizerLike | None = None
248
249


250
251
252
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """`ServeContext` specialized for classification requests; adds no fields."""
253
254


255
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` specialized for embedding requests, plus chat-template settings."""

    # Optional chat template; defaults to None.
    chat_template: str | None = None
    # Required (keyword-only) content-format option for the chat template.
    chat_template_content_format: ChatTemplateContentFormatOption


261
class OpenAIServing:
262
263
264
265
    # NOTE: the value assigned here is the descriptive text itself, not a
    # short prefix - presumably each concrete endpoint class overrides it
    # (confirm against the subclasses).
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
266

267
268
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """
        Store the collaborators and settings shared by the serving endpoints.

        Args:
            engine_client: Client used to reach the engine.
            models: Model registry; the input processor, IO processor,
                renderer config and model config are read from it.
            request_logger: Optional request logger.
            return_tokens_as_token_ids: Stored unchanged on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback or the current stack for each error it builds.
        """
        super().__init__()

        # Constructor arguments, kept as given.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Single-worker executor used by the async wrapper around
        # `apply_mistral_chat_template`.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Cache of async tokenizers, keyed by the tokenizer they wrap.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}

        # Shortcuts to objects owned by the model registry.
        self.input_processor = models.input_processor
        self.io_processor = models.io_processor
        self.renderer_config = models.renderer_config
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len

298
    def _get_tool_parser(
299
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
300
    ) -> Callable[[TokenizerLike], ToolParser] | None:
301
302
303
304
        """Get the tool parser based on the name."""
        parser = None
        if not enable_auto_tools or tool_parser_name is None:
            return parser
305
        logger.info('"auto" tool choice has been enabled.')
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
326
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
327
328
329
330
331
332
333
334
335
336
337
        """Get the reasoning parser based on the name."""
        parser = None
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
            assert parser is not None
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        return parser

338
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's multimodal cache, then reset the engine's."""
        processor = self.input_processor
        processor.clear_mm_cache()

        engine = self.engine_client
        await engine.reset_mm_cache()

342
343
344
345
346
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Run beam search on top of single-token engine requests.

        For up to `params.max_tokens` steps, every live beam is sent to the
        engine as a one-token generation request asking for
        `2 * params.beam_width` logprobs. The returned candidates of all
        beams are pooled; candidates that are the EOS token complete a beam
        (unless `params.ignore_eos`), and the best `beam_width` remaining
        candidates become the beams of the next step.

        Args:
            prompt: A plain string, or a dict-like prompt from which
                `prompt`, `prompt_token_ids` and `multi_modal_data` are read.
            request_id: Base id; each engine request gets a derived id.
            params: Beam search settings.
            lora_request: LoRA adapter attached to every beam, if any.
            trace_headers: Forwarded to every engine request.

        Yields:
            Exactly one finished `RequestOutput` holding at most
            `beam_width` completions, ordered by the key from
            `create_sort_beams_key_function`, best first.

        Raises:
            ValueError: If the input processor has no tokenizer.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.input_processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init=True`"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain-string prompt is not tokenized here, so
            # the search starts from an empty token list - confirm that
            # callers always pass pre-tokenized prompts.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        # Never populated in this method; every beam carries `None` through.
        mm_processor_kwargs: dict[str, Any] | None = None

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            # One single-token request per live beam, all run concurrently.
            request_id_batch = f"{request_id}-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            EngineTokensPrompt(
                                prompt_token_ids=beam.tokens,
                                multi_modal_data=beam.multi_modal_data,
                                mm_processor_kwargs=beam.mm_processor_kwargs,
                            ),
                            beam_search_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                for i, beam in enumerate(all_beams)
            ]

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Pool the candidates of all beams. The owning beam of each
            # candidate is recorded explicitly: the number of logprob entries
            # per beam is not guaranteed to equal `logprobs_num` (a beam may
            # return no logprobs, or one extra entry for the sampled token),
            # so it cannot be derived from the candidate's position.
            candidate_token_ids: list[int] = []
            candidate_scores: list[float] = []
            candidate_owner: list[int] = []
            for beam_idx, result in enumerate(output):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                parent_score = all_beams[beam_idx].cum_logprob
                for token_id, logprob_obj in step_logprobs[0].items():
                    candidate_token_ids.append(token_id)
                    candidate_scores.append(parent_score + logprob_obj.logprob)
                    candidate_owner.append(beam_idx)

            all_beams_token_id = np.array(candidate_token_ids, dtype=np.int64)
            all_beams_logprob = np.array(candidate_scores, dtype=np.float64)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = candidate_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # Retire EOS candidates so they are not picked as live beams.
                all_beams_logprob[eos_idx] = -np.inf

            # Pick the `beam_width` best remaining candidates. argpartition
            # requires kth < size, so with too few candidates take them all.
            if all_beams_logprob.size > beam_width:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                topn_idx = np.arange(all_beams_logprob.size)
            # Never continue a beam with a retired (-inf) candidate.
            topn_idx = topn_idx[all_beams_logprob[topn_idx] > -np.inf]

            if topn_idx.size == 0:
                # Nothing left to extend; keep the current beams as they are.
                break

            new_beams = []
            for idx in topn_idx:
                owner = candidate_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        # Beams still alive compete with the completed ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            ends_with_eos = bool(beam.tokens) and beam.tokens[-1] == eos_token_id
            if ends_with_eos and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
555

556
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """
        Build a `CompletionRenderer` for `tokenizer`.

        The renderer receives this instance's model config and its shared
        async tokenizer pool.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer
566

567
568
569
570
571
572
573
574
575
576
577
578
579
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.
        """
        raise NotImplementedError

580
581
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
582
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
583
584
585
586
587
588
589
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
590

591
592
593
    async def _preprocess(
        self,
        ctx: ServeContext,
594
    ) -> ErrorResponse | None:
595
596
597
598
599
600
601
602
603
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
604
    ) -> AnyResponse | ErrorResponse:
605
606
607
608
609
610
611
612
613
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Run the pipeline for `ctx` and return the first response it yields.

        Returns:
            The first item yielded by `_pipeline`, or a
            "No response yielded from pipeline" error if it yields nothing.
        """
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Only the first response is consumed, so close the generator
            # explicitly instead of leaving a suspended async generator to be
            # finalized later by the garbage collector / event loop.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
626
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

647
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
648
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
649

650
651
652
653
        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
654
655
656
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
657
658
                " Please, select a smaller truncation size."
            )
659
660
        return None

661
662
663
    def _create_pooling_params(
        self,
        ctx: ServeContext,
664
    ) -> PoolingParams | ErrorResponse:
665
666
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
667
668
                "Request type does not support pooling parameters"
            )
669
670
671

        return ctx.request.to_pooling_params()

672
673
674
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Schedule the request and store the merged result stream on `ctx`.

        One `encode` call is issued per engine prompt, each with the id
        `"<request_id>-<index>"`; the resulting generators are merged into
        `ctx.result_generator`.

        Returns:
            `None` on success, otherwise an `ErrorResponse`. Any exception
            raised while scheduling is turned into an error response
            carrying the exception's message.
        """
        streams: list[AsyncGenerator[RequestOutput | PoolingRequestOutput, None]] = []

        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for index, engine_prompt in enumerate(ctx.engine_prompts):
                sub_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        sub_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return None

    async def _collect_batch(
        self,
        ctx: ServeContext,
727
    ) -> ErrorResponse | None:
728
729
730
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
731
                return self.create_error_response("Engine prompts not available")
732
733

            num_prompts = len(ctx.engine_prompts)
734
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
735
736
737
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
738
                return self.create_error_response("Result generator not available")
739
740
741
742
743
744

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
745
746
                    "Failed to generate results for all prompts"
                )
747

748
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]
749
750
751
752
753
754

            return None

        except Exception as e:
            return self.create_error_response(str(e))

755
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an `ErrorResponse` from a message, error type and HTTP status.

        When `self.log_error_stack` is set, the traceback of the exception
        currently being handled is printed; if no exception is being
        handled, the current call stack is printed instead.
        """
        if self.log_error_stack:
            handling_exception = sys.exc_info()[0] is not None
            if handling_exception:
                traceback.print_exc()
            else:
                traceback.print_stack()

        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)
770

771
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """
        Build the same error as `create_error_response` and return it as a
        JSON string (the `model_dump()` of the response, serialized with
        `json.dumps`).
        """
        error_response = self.create_error_response(
            message=message, err_type=err_type, status_code=status_code
        )
        payload = error_response.model_dump()
        return json.dumps(payload)

784
    async def _check_model(
785
786
        self,
        request: AnyRequest,
787
    ) -> ErrorResponse | None:
788
789
        error_response = None

790
        if self._is_model_supported(request.model):
791
            return None
792
        if request.model in self.models.lora_requests:
793
            return None
794
795
796
797
798
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
799
800
            if isinstance(load_result, LoRARequest):
                return None
801
802
803
804
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
805
806
807
                error_response = load_result

        return error_response or self.create_error_response(
808
809
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
810
811
            status_code=HTTPStatus.NOT_FOUND,
        )
812

813
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

836
    def _maybe_get_adapters(
837
838
839
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
840
    ) -> LoRARequest | None:
841
        if request.model in self.models.lora_requests:
842
            return self.models.lora_requests[request.model]
843
844
845
846
847
848

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
849
                return default_mm_lora
850
851

        if self._is_model_supported(request.model):
852
            return None
853

854
        # if _check_model has been called earlier, this will be unreachable
855
        raise ValueError(f"The model `{request.model}` does not exist.")
856

857
858
859
860
861
862
863
864
865
866
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

867
868
869
870
871
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
872
873
874
875
876
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
877
878
879
880
881
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

882
    async def _normalize_prompt_text_to_input(
883
884
885
        self,
        request: AnyRequest,
        prompt: str,
886
        tokenizer: TokenizerLike,
887
888
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
889
890
        async_tokenizer = self._get_async_tokenizer(tokenizer)

891
892
893
894
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
895
896
            prompt = prompt.lower()

897
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
898

899
        if truncate_prompt_tokens is None:
900
            encoded = await async_tokenizer(
901
902
                prompt, add_special_tokens=add_special_tokens
            )
903
904
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
905
906
907
908
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
909
910
                max_length=self.max_model_len,
            )
911
        else:
912
913
914
915
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
916
917
                max_length=truncate_prompt_tokens,
            )
918
919
920
921
922
923

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

924
    async def _normalize_prompt_tokens_to_input(
925
926
        self,
        request: AnyRequest,
927
        prompt_ids: list[int],
928
        tokenizer: TokenizerLike | None,
929
    ) -> TextTokensPrompt:
930
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
931

932
        if truncate_prompt_tokens is None:
933
            input_ids = prompt_ids
934
        elif truncate_prompt_tokens < 0:
935
            input_ids = prompt_ids[-self.max_model_len :]
936
937
938
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

939
940
941
942
943
        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
944

945
946
947
948
949
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against ``self.max_model_len``.

        * Embedding, score, rerank and classification requests may use the
          whole context length.
        * Tokenize and detokenize requests are not length-checked.
        * Every other request must leave at least one free position, and the
          prompt plus the requested ``max_tokens`` (or, for chat completions,
          ``max_completion_tokens``) must fit in the context length.

        Returns:
            A ``TextTokensPrompt`` built from ``input_text`` and ``input_ids``.

        Raises:
            ValueError: if one of the length limits above is exceeded.
        """

        def accepted() -> TextTokensPrompt:
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= self.max_model_len:
                return accepted()

            # Exact-type lookup; anything not listed is reported as
            # "embedding generation".
            operations: dict[type[AnyRequest], str] = {
                ScoreRequest: "score",
                ClassificationCompletionRequest: "classification",
                ClassificationChatRequest: "classification",
            }
            operation = operations.get(type(request), "embedding generation")
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input."
            )

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenizer_only_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_only_types):
            return accepted()

        # chat completion endpoint supports max_completion_tokens
        # TODO(#9845): remove max_tokens when field dropped from OpenAI API
        max_tokens = (
            (request.max_completion_tokens or request.max_tokens)
            if isinstance(request, ChatCompletionRequest)
            else getattr(request, "max_tokens", None)
        )

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return accepted()

1021
    async def _tokenize_prompt_input_async(
1022
1023
        self,
        request: AnyRequest,
1024
        tokenizer: TokenizerLike,
1025
        prompt_input: str | list[int],
1026
1027
1028
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
1029
        A simpler implementation that tokenizes a single prompt input.
1030
        """
1031
        async for result in self._tokenize_prompt_inputs_async(
1032
1033
            request,
            tokenizer,
1034
            [prompt_input],
1035
            add_special_tokens=add_special_tokens,
1036
1037
1038
        ):
            return result
        raise ValueError("No results yielded from tokenization")
1039

1040
    async def _tokenize_prompt_inputs_async(
1041
1042
        self,
        request: AnyRequest,
1043
        tokenizer: TokenizerLike,
1044
        prompt_inputs: Iterable[str | list[int]],
1045
        add_special_tokens: bool = True,
1046
    ) -> AsyncGenerator[TextTokensPrompt, None]:
1047
        """
1048
        A simpler implementation that tokenizes multiple prompt inputs.
1049
        """
1050
1051
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
1052
                yield await self._normalize_prompt_text_to_input(
1053
                    request,
1054
1055
                    prompt=prompt,
                    tokenizer=tokenizer,
1056
1057
1058
                    add_special_tokens=add_special_tokens,
                )
            else:
1059
                yield await self._normalize_prompt_tokens_to_input(
1060
                    request,
1061
1062
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
1063
1064
                )

1065
1066
    def _validate_chat_template(
        self,
1067
1068
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
1069
        trust_request_chat_template: bool,
1070
    ) -> ErrorResponse | None:
1071
        if not trust_request_chat_template and (
1072
1073
1074
1075
1076
1077
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
1078
1079
1080
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
1081
1082
                "Refused request with untrusted chat template."
            )
1083
1084
        return None

1085
1086
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a prompt and an engine-ready input.

        Steps, in order:
          1. resolve the content format and parse ``messages`` into a
             conversation plus a future for the multimodal data;
          2. apply the chat template (Mistral, DeepSeek-V3.2 or HF path,
             chosen by tokenizer type; a literal placeholder when there is no
             tokenizer). Entries in ``chat_template_kwargs`` override the
             explicit template arguments;
          3. let the tool parser adjust the request when tool parsing applies;
          4. tokenize the rendered prompt (if it is text) and attach
             multimodal data, uuids, processor kwargs and cache salt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: if tool parsing applies but the request is
                neither a Chat Completions nor a Responses API request.
        """
        renderer_config = self.renderer_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            renderer_config=renderer_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            renderer_config,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs win over the explicit arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        elif isinstance(tokenizer, DeepseekV32Tokenizer):
            request_prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                messages=messages,
                model_config=renderer_config.model_config,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                renderer_config=renderer_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # The message must be a single string: a parenthesised pair of
            # comma-separated literals would make it a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string when the tokenizer is not initialised"
            )
            # Without a tokenizer a fixed one-element id list stands in for
            # the real token ids.
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        # Not every request type declares cache_salt, hence the hasattr guard.
        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

1213
1214
1215
1216
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        Returns:
            ``(engine_request, tokenization_kwargs)``: the request produced by
            ``self.input_processor.process_inputs`` and the tokenization
            kwargs dict that was passed to it.
        """
        # Starts empty; handed to _validate_truncation_size together with the
        # model length and the requested truncation before being forwarded.
        truncation_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            truncation_kwargs,
        )

        processed = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=truncation_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return processed, truncation_kwargs

1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: AnyTokenizer,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the prompt for the next turn of a Responses conversation.

        ``messages`` is converted with ``construct_input_messages`` and run
        through ``_preprocess_chat``.

        Returns:
            ``(request_prompts, engine_prompts)`` from ``_preprocess_chat``;
            the parsed conversation it also returns is discarded.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        preprocessed = await self._preprocess_chat(
            request,
            tokenizer,
            next_turn_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        _conversation, request_prompts, engine_prompts = preprocessed
        return request_prompts, engine_prompts

1265
1266
1267
1268
1269
1270
1271
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate, executing built-in tool calls between turns.

        Each turn is submitted to the engine as its own sub-request with id
        ``{request_id}_{n}``. Every engine output is appended to ``context``,
        which is then yielded. When a turn ends and the context asks for a
        built-in tool call, the tool is called, its output appended, and the
        next prompt is rendered from the context; otherwise generation stops.

        ``sampling_params.max_tokens`` is overwritten in place before each
        follow-up turn with the context length left after the new prompt.
        Extra ``kwargs`` are forwarded to ``engine_client.generate``.
        """
        # Computed once from the initial prompt and reused for every turn.
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        # kwargs is not modified below, so this lookup is loop-invariant.
        trace_headers = kwargs.get("trace_headers")
        initial_priority = priority
        sub_request = 0

        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"

            self._log_inputs(
                sub_request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            outputs = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )
            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = EngineTokensPrompt(prompt_token_ids=next_token_ids)
                request_prompt = next_token_ids
            elif isinstance(context, ParsableContext):
                next_request_prompts, next_engine_prompts = await self._render_next_turn(
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = next_engine_prompts[0]
                request_prompt = next_request_prompts[0]

            # Update the sampling params.
            next_prompt_len = len(engine_prompt["prompt_token_ids"])
            sampling_params.max_tokens = self.max_model_len - next_prompt_len

            # OPTIMIZATION: every follow-up turn is submitted with the
            # original priority minus one.
            priority = initial_priority - 1
            sub_request += 1
1351

1352
1353
    def _get_prompt_components(
        self,
        prompt: RequestPrompt | PromptType,
    ) -> PromptComponents:
        """Split a prompt into its components.

        A ``list`` prompt is taken to be token ids and wrapped directly;
        any other prompt is delegated to ``get_prompt_components``.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]

        return PromptComponents(token_ids=prompt)
1360

1361
1362
1363
    def _log_inputs(
        self,
        request_id: str,
1364
1365
1366
        inputs: RequestPrompt | PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1367
1368
1369
    ) -> None:
        if self.request_logger is None:
            return
1370

1371
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1372
1373
1374
1375
1376

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1377
            prompt_embeds,
1378
1379
1380
            params=params,
            lora_request=lora_request,
        )
1381

1382
1383
1384
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers of a request when tracing is enabled.

        When the engine reports tracing as disabled, ``None`` is returned; if
        the request nevertheless carries trace headers,
        ``log_tracing_disabled_warning`` is called first.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1396
    @staticmethod
1397
    def _base_request_id(
1398
1399
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1400
        """Pulls the request id to use from a header, if provided"""
1401
1402
1403
1404
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1405

1406
        return random_uuid() if default is None else default
1407

1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1423
1424
1425
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated ``content`` into function calls per ``tool_choice``.

        * Named tool choice (Responses or Chat Completions flavour): the whole
          content becomes the arguments of that function.
        * ``"required"``: the content is validated as a JSON list of function
          definitions, each converted into a call.
        * ``"auto"`` / ``None`` with a parser class and ``enable_auto_tools``:
          the parser extracts the calls.

        Returns:
            ``(function_calls, remaining_content)``. ``function_calls`` is
            ``None`` when automatic parsing found no tool call, and an empty
            list when no branch applied. In both cases ``content`` is
            returned unchanged.
        """
        tool_choice = request.tool_choice
        function_calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(name=tool_choice.name, arguments=content)
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(
                name=tool_choice.function.name, arguments=content
            )
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            for definition in definitions:
                function_calls.append(
                    FunctionCall(
                        name=definition.name,
                        arguments=json.dumps(
                            definition.parameters, ensure_ascii=False
                        ),
                    )
                )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError:
                logger.exception("Error in tool parser creation.")
                raise
            extracted = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if extracted is None or not extracted.tools_called:
                # No tool calls.
                return None, content

            # extract_tool_calls() returns a list of tool calls.
            for parsed_call in extracted.tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=parsed_call.function.name,
                        arguments=parsed_call.function.arguments,
                    )
                )
            content = extracted.content
            # Whitespace-only leftovers are dropped; an empty string is kept.
            if content and content.strip() == "":
                content = None

        return function_calls, content

1494
    @staticmethod
1495
1496
1497
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1498
        tokenizer: TokenizerLike | None,
1499
1500
        return_as_token_id: bool = False,
    ) -> str:
1501
1502
1503
        if return_as_token_id:
            return f"token_id:{token_id}"

1504
1505
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1506
1507
1508
1509
1510
1511

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1512
        return tokenizer.decode(token_id)
1513

1514
    def _is_model_supported(self, model_name: str | None) -> bool:
1515
1516
        if not model_name:
            return True
1517
        return self.models.is_base_model(model_name)
1518

1519
1520

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are left alone. The same
    ``prompt_logprobs`` object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs