serving_engine.py 54.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
from fastapi import Request
16
17
18
from openai.types.responses import (
    ToolChoiceFunction,
)
19
20
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
21

22
import vllm.envs as envs
23
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
30
31
32
33
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
34
35
36
37
38
39
from vllm.entrypoints.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
40
from vllm.entrypoints.logger import RequestLogger
41
from vllm.entrypoints.openai.protocol import (
42
    ChatCompletionNamedToolChoiceParam,
43
44
45
46
47
48
49
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
50
    FunctionCall,
51
    FunctionDefinition,
52
53
    ResponseInputOutputItem,
    ResponsesRequest,
54
55
56
57
58
59
60
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
61
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
62
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
84
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
85
86
87
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
88
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
89
from vllm.entrypoints.utils import _validate_truncation_size
90
from vllm.inputs.data import PromptType, TokensPrompt
91
92
93
94
95
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
96
from vllm.logger import init_logger
97
from vllm.logprobs import Logprob, PromptLogprobs
98
from vllm.lora.request import LoRARequest
99
from vllm.multimodal import MultiModalDataDict
100
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
101
from vllm.pooling_params import PoolingParams
102
from vllm.reasoning import ReasoningParser, ReasoningParserManager
103
from vllm.sampling_params import BeamSearchParams, SamplingParams
104
105
106
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.deepseekv32 import DeepseekV32Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
107
108
109
110
111
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
112
from vllm.utils import random_uuid
113
from vllm.utils.async_utils import (
114
    AsyncMicrobatchTokenizer,
115
    collect_from_async_generator,
116
    make_async,
117
118
    merge_async_iterators,
)
119
from vllm.utils.collection_utils import is_list_of
120
from vllm.v1.engine import EngineCoreRequest
121

122
123
124
125
126
127
128
129
130

class GenerationError(Exception):
    """Signal that generation ended with an internal failure.

    Raised when a request's finish_reason indicates an internal server error.
    The instance carries ``status_code`` (HTTP 500) so callers can turn it
    into an error response.
    """

    def __init__(self, message: str = "Internal server error"):
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


131
132
logger = init_logger(__name__)

# ---------------------------------------------------------------------------
# Request/response type families handled by the serving layer.
# NOTE: member order inside each union is kept as-is on purpose; it may matter
# wherever these unions are validated.
# ---------------------------------------------------------------------------

# Completion-style (non-chat) request types.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Chat-style request types.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)

# Audio speech-to-text request types.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every (non-error) response type produced by the serving layer.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)

# Request type carried by a ServeContext; bounded by AnyRequest.
RequestT = TypeVar("RequestT", bound=AnyRequest)


175
176
@dataclass(kw_only=True)
class RequestProcessingMixin:
    """Request-side state shared by serving contexts.

    Holds the prompts prepared for the engine; the list starts out empty.
    """

    # Prompts submitted to the engine, one per sub-request. The annotation
    # also allows ``None``, which the pipeline checks for explicitly.
    engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
183
184


185
186
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """Response-side state shared by serving contexts.

    Tracks the merged result stream and the outputs collected from it.
    """

    # Async stream of ``(prompt_index, output)`` pairs; ``None`` until the
    # generators have been scheduled.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Outputs drained from ``result_generator``, ordered by prompt index.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # NOTE(review): pydantic-style config on a stdlib dataclass. Being
    # unannotated, this is a plain class attribute rather than a dataclass
    # field; confirm whether anything still reads it.
    model_config = ConfigDict(arbitrary_types_allowed=True)


202
203
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """Per-request state threaded through the serving pipeline."""

    # --- Fields shared by every request type --------------------------------
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Whole-second timestamp taken when the context is constructed.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # --- Fields used by most request types ----------------------------------
    tokenizer: TokenizerLike | None = None
214
215


216
217
218
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """Serving context specialised for classification requests."""
219
220


221
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """Serving context specialised for embedding requests."""

    # Optional chat template; defaults to ``None``.
    chat_template: str | None = None
    # Content-format option for chat templates; required (no default).
    chat_template_content_format: ChatTemplateContentFormatOption


227
class OpenAIServing:
228
229
230
231
    # Class-level tag for request IDs; presumably each serving subclass
    # overrides it with a short string such as "embd" or "classify".
    # NOTE(review): the default value below is the descriptive text itself
    # (written like a docstring), so a subclass that does not override this
    # attribute would get that multi-line text wherever the prefix is used.
    # Confirm every subclass overrides it, or give it a short default.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
232

233
234
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store shared serving dependencies and derive model-level settings.

        Args:
            engine_client: Client used to submit requests to the engine.
            models: Served-model manager; also supplies the input processor,
                the IO processor and the model config.
            request_logger: Optional logger for incoming requests.
            return_tokens_as_token_ids: Stored for use by subclasses.
            log_error_stack: When true, ``create_error_response`` prints a
                traceback (or the current stack) for every error it builds.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Single-worker executor; the Mistral chat-template helper is wrapped
        # with ``make_async`` on it (presumably to keep that blocking call off
        # the event loop).
        executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = executor
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=executor
        )

        # Cache of async tokenizer wrappers, keyed by tokenizer instance.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        served = self.models
        self.input_processor = served.input_processor
        self.io_processor = served.io_processor
        self.model_config = served.model_config
        self.max_model_len = self.model_config.max_model_len

263
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Get the tool parser based on the name.

        Args:
            tool_parser_name: Name the parser is registered under in
                ``ToolParserManager``.
            enable_auto_tools: Whether "auto" tool choice is enabled.

        Returns:
            The registered parser, or ``None`` when auto tools are disabled or
            no parser name was given.

        Raises:
            TypeError: If ``tool_parser_name`` is not registered.
        """
        if not enable_auto_tools or tool_parser_name is None:
            return None
        logger.info('"auto" tool choice has been enabled.')

        # Advisory warning only. It is kept outside the ``try`` so that an
        # unexpected failure here is not misreported as an unregistered parser.
        if tool_parser_name == "pythonic" and self.model_config.model.startswith(
            "meta-llama/Llama-3.2"
        ):
            logger.warning(
                "Llama3.2 models may struggle to emit valid pythonic tool calls"
            )

        try:
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Get the reasoning parser based on the name.

        Returns:
            The registered parser, or ``None`` when ``reasoning_parser_name``
            is empty.

        Raises:
            TypeError: If the name is not registered (the lookup raises or
                yields ``None``).
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check rather than ``assert``: asserts are stripped under -O,
        # which would let a ``None`` parser slip through.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

303
    async def reset_mm_cache(self) -> None:
        """Clear the local multimodal cache, then reset the engine's cache."""
        processor = self.input_processor
        processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

307
308
309
310
311
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by asking the engine for one token per beam per step.

        Each step does the following:

        - Submits one single-token request per live beam, asking for
          ``2 * beam_width`` logprobs.
        - Scores every returned candidate by cumulative logprob.
        - Moves EOS candidates to the completed set, unless
          ``params.ignore_eos`` is set.
        - Keeps the ``beam_width`` best candidates as the next beams.

        Args:
            prompt: A plain string or a prompt dict (``prompt`` /
                ``prompt_token_ids`` / ``multi_modal_data`` keys are read).
            request_id: Base id; per-step, per-beam engine ids derive from it.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter applied to every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            One finished ``RequestOutput``. This holds the best ``beam_width``
            beams, or an empty output with ``finish_reason="error"`` if any
            engine step finished with an error.

        Raises:
            ValueError: If the tokenizer was not initialized.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init=True`"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain-string prompt is not tokenized here, so the
            # beams start from an empty token list. Confirm callers always pass
            # token ids.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Twice as many candidates as beams are requested per beam
        # (presumably so enough non-EOS candidates survive EOS removal).
        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # The previous step produced no candidates (no logprobs were
                # returned), so there is nothing left to expand.
                break

            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            # One single-token request per live beam, all run concurrently.
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"
            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Flattened candidates across all beams. For each candidate we
            # record its token id, its cumulative logprob, and the index of the
            # beam that produced it. Tracking the parent index explicitly keeps
            # the candidate->beam mapping correct even when a beam returns more
            # or fewer than ``logprobs_num`` entries.
            all_beams_token_id: list[int] = []
            all_beams_logprob: list[float] = []
            all_beams_parent: list[int] = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        all_beams_token_id.append(token_id)
                        all_beams_logprob.append(
                            current_beam.cum_logprob + logprob_obj.logprob
                        )
                        all_beams_parent.append(i)

            token_ids_arr = np.array(all_beams_token_id)
            logprobs_arr = np.array(all_beams_logprob)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Positions of EOS candidates among all generated candidates
                eos_idx = np.where(token_ids_arr == eos_token_id)[0]
                for idx in eos_idx:
                    parent = all_beams_parent[idx]
                    current_beam = all_beams[parent]
                    result = output[parent]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=(
                                current_beam.tokens + [eos_token_id]
                                if include_stop_str_in_output
                                else current_beam.tokens
                            ),
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(logprobs_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # EOS candidates are now in ``completed``; score them -inf so
                # they rank last in the selection below.
                logprobs_arr[eos_idx] = -np.inf

            # Keep the ``beam_width`` highest-scoring candidates.
            num_candidates = len(logprobs_arr)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(logprobs_arr), beam_width)[
                    :beam_width
                ]
            else:
                # ``argpartition`` requires kth < len; with this few candidates
                # every one of them is kept.
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                parent = all_beams_parent[idx]
                current_beam = all_beams[parent]
                result = output[parent]
                token_id = int(token_ids_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(logprobs_arr[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
543

544
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """Build a ``CompletionRenderer`` around ``tokenizer``.

        The renderer is handed this instance's async tokenizer pool, so
        tokenizer wrappers are shared rather than rebuilt per renderer.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer
554

555
556
557
558
559
560
561
562
563
564
565
566
567
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the ``RenderConfig`` the renderer should use for ``request``.

        The config controls how prompts are prepared (e.g. tokenization and
        length handling). The base class provides no implementation; each
        endpoint overrides it with logic suited to its request type.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

568
569
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Return the ``AsyncMicrobatchTokenizer`` wrapping ``tokenizer``.

        Wrappers are cached in ``self._async_tokenizer_pool``. The first call
        for a tokenizer creates and stores one; later calls reuse it.
        """
        pool = self._async_tokenizer_pool
        cached = pool.get(tokenizer)
        if cached is not None:
            return cached
        fresh = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = fresh
        return fresh
578

579
580
581
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Prepare ``ctx`` before generation; does nothing by default.

        Subclasses (classification, embedding, ...) override this hook to
        populate the context. Returning an ``ErrorResponse`` reports a failure.
        """
        result: ErrorResponse | None = None
        return result

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Build the final response from ``ctx``.

        The base class has no response type of its own, so it answers with an
        "unimplemented endpoint" error. Subclasses override this.
        """
        message = "unimplemented endpoint"
        return self.create_error_response(message)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the request pipeline and return the first response it yields.

        The pipeline's async generator is closed explicitly once a response
        has been taken. Without this it would stay suspended until garbage
        collection finalizes it.

        Returns:
            The first response yielded by ``_pipeline``, or an error response
            if the pipeline yields nothing.
        """
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)
        try:
            async for response in generation:
                return response
        finally:
            # Deterministic cleanup; a no-op if the generator already finished.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        The stages run in this order:

        1. Model check.
        2. Request validation.
        3. Preprocessing.
        4. Generator scheduling.
        5. Batch collection.
        6. Response building.

        The first stage that reports an ``ErrorResponse`` yields it and the
        pipeline stops. Later stages never run after a failure, even if the
        generator is iterated further. Otherwise the built response is
        yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

635
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject requests whose ``truncate_prompt_tokens`` exceeds max_model_len.

        Requests lacking that attribute, or with it set to ``None``, pass.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        exceeds_limit = limit is not None and limit > self.max_model_len
        if not exceeds_limit:
            return None
        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size."
        )

649
650
651
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive ``PoolingParams`` from the request's ``to_pooling_params()``.

        Returns an error response when the request type has no such method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()
        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

660
661
662
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        For each prompt in ``ctx.engine_prompts``, logs the inputs and submits
        an ``engine_client.encode`` call under sub-request id
        ``"{request_id}-{index}"``. The resulting streams are merged into
        ``ctx.result_generator``. Any exception raised along the way is
        converted into an error response.
        """
        try:
            headers = None
            if ctx.raw_request is not None:
                headers = await self._get_trace_headers(ctx.raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            streams: list[
                AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
            ] = []
            for index, prompt in enumerate(prompts):
                sub_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    sub_request_id,
                    prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        prompt,
                        pooling_params,
                        sub_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
        return None

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains ``ctx.result_generator``, placing each output at its prompt
        index, and stores the ordered outputs in ``ctx.final_res_batch``. An
        error response is returned in any of these cases:

        - ``engine_prompts`` or ``result_generator`` is missing.
        - Some prompt received no output.
        - Draining the stream raises.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            slots: list[RequestOutput | PoolingRequestOutput | None]
            slots = [None] * len(prompts)

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response("Result generator not available")

            async for index, output in stream:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
        except Exception as e:
            return self.create_error_response(str(e))
        return None

743
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` from a message, error type and status code.

        When ``log_error_stack`` is enabled, this also prints diagnostic
        output. If an exception is currently being handled, its traceback is
        printed; otherwise the current call stack is printed.
        """
        if self.log_error_stack:
            active_exc_type = sys.exc_info()[0]
            if active_exc_type is None:
                traceback.print_stack()
            else:
                traceback.print_exc()

        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)
758

759
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return ``create_error_response(...)`` serialized as a JSON string."""
        error = self.create_error_response(
            message=message, err_type=err_type, status_code=status_code
        )
        return json.dumps(error.model_dump())

772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when ``finish_reason`` is "error".

        Any other finish reason (including ``None``) is a no-op.
        """
        if finish_reason != "error":
            return
        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Map a ``GenerationError`` to an "InternalServerError" response."""
        return self.create_error_response(
            message=str(e),
            status_code=e.status_code,
            err_type="InternalServerError",
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Map a ``GenerationError`` to a JSON "InternalServerError" payload."""
        return self.create_streaming_error_response(
            message=str(e),
            status_code=e.status_code,
            err_type="InternalServerError",
        )

801
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Verify that ``request.model`` names a servable model or LoRA adapter.

        Returns ``None`` in any of these cases:

        - The model is supported.
        - The model is an already-known LoRA.
        - Runtime LoRA updating is enabled and the name resolves to a
          ``LoRARequest``.

        Otherwise returns an error. If the resolver produced a 400 error,
        that error is returned; if not, a 404 "does not exist" error is.
        """
        if (
            self._is_model_supported(request.model)
            or request.model in self.models.lora_requests
        ):
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model:
            load_result = await self.models.resolve_lora(request.model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolver_error = load_result

        if resolver_error:
            return resolver_error
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
829

830
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Pick the default multimodal LoRA implied by the request's content.

        A loaded adapter matches when its ``lora_name`` equals one of the
        content-type prefixes returned by ``_get_message_types`` (e.g. an
        ``audio_url`` part yields ``audio``). This is a best-effort heuristic.

        Returns:
            The matched adapter when exactly one adapter matches, else ``None``.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        requested_types = self._get_message_types(request)
        candidates = {
            adapter
            for adapter in self.models.lora_requests.values()
            if adapter.lora_name in requested_types
        }
        # Ambiguous (several) or absent matches fall back to no adapter.
        return candidates.pop() if len(candidates) == 1 else None

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Resolve which LoRA adapter (if any) should serve ``request``.

        Resolution order:
        1. ``request.model`` names a loaded LoRA adapter -> that adapter.
        2. ``supports_default_mm_loras`` is set and exactly one default
           multimodal LoRA matches the request content -> that adapter.
        3. ``request.model`` is a supported base model (or empty) -> ``None``.

        Raises:
            ValueError: If none of the above applies. Callers that ran
                ``_check_model`` first should never reach this.
        """
        lora_requests = self.models.lora_requests
        if request.model in lora_requests:
            return lora_requests[request.model]

        if (
            supports_default_mm_loras
            and (mm_lora := self._get_active_default_mm_loras(request)) is not None
        ):
            return mm_lora

        if not self._is_model_supported(request.model):
            raise ValueError(f"The model `{request.model}` does not exist.")
        return None

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-type prefixes used in the request's messages.

        For every chat message (a dict) whose ``content`` is a list, each
        content part that is a dict with a string ``"type"`` contributes the
        text before its first ``_`` (e.g. ``"audio_url"`` -> ``"audio"``).
        The result is matched against default per-modality LoRA names.

        Content parts that are not dicts (such as plain strings) or whose
        ``type`` is not a string are skipped. For a string part,
        ``"type" in part`` is a substring test, and indexing the string
        with ``"type"`` would raise ``TypeError``.

        Returns:
            The set of prefixes. It is empty when the request has no
            ``messages`` attribute or ``messages`` is ``None``/``str``/``bytes``.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: TokenizerLike,
        add_special_tokens: bool,
    ) -> TokensPrompt:
        """Tokenize a text prompt and validate it against the context length.

        The prompt is lower-cased first when the model's ``encoder_config``
        sets ``do_lower_case``. ``request.truncate_prompt_tokens`` (if the
        request has it) controls truncation:

        - ``None``: no truncation.
        - negative: truncate to ``self.max_model_len`` tokens.
        - otherwise: truncate to that many tokens.

        Returns:
            The ``TokensPrompt`` produced by ``_validate_input``, holding the
            (possibly lower-cased) text and its token ids.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        tokenize_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
        if truncate_prompt_tokens is not None:
            tokenize_kwargs["truncation"] = True
            # Negative means "cap at the model's max length".
            tokenize_kwargs["max_length"] = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )

        encoded = await async_tokenizer(prompt, **tokenize_kwargs)
        return self._validate_input(request, encoded.input_ids, prompt)

    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: TokenizerLike | None,
    ) -> TokensPrompt:
        """Truncate pre-tokenized input, recover its text, and validate it.

        ``request.truncate_prompt_tokens`` (if present) keeps only the *last*
        N ids, where N is ``self.max_model_len`` for negative values. Text is
        decoded from the kept ids when a tokenizer exists, else it is ``""``.

        Returns:
            The ``TokensPrompt`` produced by ``_validate_input``.
        """
        truncate = getattr(request, "truncate_prompt_tokens", None)
        if truncate is None:
            input_ids = prompt_ids
        else:
            keep_last = self.max_model_len if truncate < 0 else truncate
            # NOTE(review): keep_last == 0 slices as [-0:], which keeps every
            # id rather than none; confirm 0 is rejected upstream.
            input_ids = prompt_ids[-keep_last:]

        if tokenizer is not None:
            input_text = await self._get_async_tokenizer(tokenizer).decode(input_ids)
        else:
            input_text = ""

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check a tokenized prompt against the model's context length.

        - Embedding, score, rerank and classification requests generate no
          tokens, so the input may fill the whole context
          (``len(input_ids) <= max_model_len``).
        - Tokenize/detokenize requests are not length-checked at all.
        - All other requests must leave at least one slot for generation
          (``len(input_ids) < max_model_len``). When ``max_tokens`` is set
          (``max_completion_tokens`` first for chat), the input plus
          ``max_tokens`` must also fit.

        Returns:
            ``TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)``.

        Raises:
            ValueError: If a limit above is exceeded.
        """
        num_tokens = len(input_ids)
        context_len = self.max_model_len

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        non_generating_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, non_generating_types):
            if num_tokens > context_len:
                operation_by_type: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }
                operation = operation_by_type.get(
                    type(request), "embedding generation"
                )
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{context_len} tokens. However, you requested "
                    f"{num_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Tokenize/detokenize only convert between text and ids; no context
        # length validation applies.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # Chat prefers max_completion_tokens; max_tokens is the fallback.
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Completion-like inputs may be at most context length - 1 tokens.
        if num_tokens >= context_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{context_len} tokens. However, your request has "
                f"{num_tokens} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and num_tokens + max_tokens > context_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{context_len} tokens and your request has "
                f"{num_tokens} input tokens ({max_tokens} > {context_len}"
                f" - {num_tokens})."
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_input: str | list[int],
        add_special_tokens: bool = True,
    ) -> TokensPrompt:
        """Tokenize and validate one prompt (text or token ids).

        Thin wrapper around ``_tokenize_prompt_inputs_async`` for a single
        input; returns its first (and only) result.

        Raises:
            ValueError: If the underlying generator yields nothing.
        """
        no_result = object()
        first = await anext(
            self._tokenize_prompt_inputs_async(
                request,
                tokenizer,
                [prompt_input],
                add_special_tokens=add_special_tokens,
            ),
            no_result,
        )
        if first is no_result:
            raise ValueError("No results yielded from tokenization")
        return first

    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TokensPrompt, None]:
        """Tokenize and validate each prompt in ``prompt_inputs``, in order.

        Text prompts go through ``_normalize_prompt_text_to_input``. Anything
        else is treated as a list of token ids and goes through
        ``_normalize_prompt_tokens_to_input``.

        Yields:
            One validated ``TokensPrompt`` per input.
        """
        for item in prompt_inputs:
            if isinstance(item, str):
                normalized = await self._normalize_prompt_text_to_input(
                    request,
                    prompt=item,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            else:
                normalized = await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=item,
                    tokenizer=tokenizer,
                )
            yield normalized

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as supplied when ``request_chat_template`` is set or
        ``chat_template_kwargs["chat_template"]`` is not ``None``.

        Returns:
            ``None`` if the request is acceptable, otherwise an error response.
        """
        if trust_request_chat_template:
            return None

        template_in_kwargs = bool(chat_template_kwargs) and (
            chat_template_kwargs.get("chat_template") is not None
        )
        if request_chat_template is None and not template_in_kwargs:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a single engine prompt.

        Steps:
        1. Resolve the content format and parse ``messages`` into a
           conversation plus a future for multimodal data and mm uuids.
        2. Apply the chat template. Which path is used depends on the
           tokenizer: Mistral, DeepSeek-V3.2, or HF. With no tokenizer, a
           ``"placeholder"`` string is used.
        3. If a ``tool_parser`` is given and the request has a ``tool_choice``
           other than ``"none"``, let the parser adjust the request.
        4. Tokenize the rendered prompt (unless it is already token ids) and
           attach multimodal data, mm uuids, ``mm_processor_kwargs`` and
           ``cache_salt`` when present.

        Returns:
            ``(conversation, [engine_prompt])``.

        Raises:
            NotImplementedError: If tool parsing applies to a request that is
                neither a ``ChatCompletionRequest`` nor a ``ResponsesRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            content_format=resolved_content_format,
        )

        # Caller-provided chat_template_kwargs override these defaults.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        elif isinstance(tokenizer, DeepseekV32Tokenizer):
            request_prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                messages=messages,
                model_config=model_config,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none". If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # Previously the message was a 2-tuple because of stray trailing
            # commas; it is now a single readable string.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string when the tokenizer is not initialised"
            )
            # NOTE(review): [1] appears to be a placeholder token id for the
            # tokenizer-less path; confirm with the engine's expectations.
            prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if "prompt" in prompt_inputs:
            engine_prompt["prompt"] = prompt_inputs["prompt"]

        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [engine_prompt]

    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Turn a prompt and its parameters into an engine-core request.

        A fresh ``tokenization_kwargs`` dict is passed to
        ``_validate_truncation_size`` along with ``params.truncate_prompt_tokens``;
        presumably that helper fills it with truncation settings (confirm). The
        same dict is then passed to ``self.input_processor.process_inputs``.

        Returns:
            ``(engine_request, tokenization_kwargs)``.
        """
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
        )

        return (
            self.input_processor.process_inputs(
                request_id,
                engine_prompt,
                params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            ),
            tokenization_kwargs,
        )

    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompt(s) for the next Responses API turn.

        ``messages`` is converted with ``construct_input_messages`` and then
        run through ``_preprocess_chat``. The conversation it returns is
        discarded.

        Returns:
            The list of engine prompts from ``_preprocess_chat``.
        """
        chat_messages = construct_input_messages(request_input=messages)

        _conversation, rendered = await self._preprocess_chat(
            request,
            tokenizer,
            chat_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return rendered

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Run generation turns, executing built-in tools between them.

        Each turn is sent to the engine as a sub-request with id
        ``f"{request_id}_{turn}"``. Every engine output is appended to
        ``context``, and ``context`` is yielded. After a turn:

        - If ``context.need_builtin_tool_call()`` is false, the loop ends.
        - Otherwise the tool is called, its output appended to ``context``,
          and the next prompt is rendered:
          - Harmony contexts use ``context.render_for_completion()``.
          - ``ParsableContext`` uses ``_render_next_turn`` over the parser's
            response messages.
          - Any other context type reuses the previous ``engine_prompt``.

        Before each follow-up turn, ``sampling_params.max_tokens`` is
        overwritten in place with the remaining context budget, so the
        caller's ``SamplingParams`` object is mutated.

        Args:
            **kwargs: Forwarded to ``engine_client.generate``. A
                ``trace_headers`` entry is also passed to ``_process_inputs``.
        """
        prompt_text, _, _ = self._get_prompt_components(engine_prompt)
        base_priority = priority
        trace_headers = kwargs.get("trace_headers")

        turn = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{turn}"
            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            outputs = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )
            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                engine_prompt = TokensPrompt(
                    prompt_token_ids=context.render_for_completion()
                )
            elif isinstance(context, ParsableContext):
                next_prompts = await self._render_next_turn(
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = next_prompts[0]
                prompt_text, _, _ = self._get_prompt_components(engine_prompt)

            # Whatever context remains after the new prompt is the budget.
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
            # OPTIMIZATION: follow-up turns use one less than the original
            # priority; presumably lower values are scheduled sooner - confirm.
            priority = base_priority - 1
            turn += 1

    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
        """Split ``prompt`` into its (text, token ids, embeddings) parts.

        Delegates to the module-level ``get_prompt_components`` helper.
        """
        components = get_prompt_components(prompt)
        return components

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward a request's prompt and parameters to the request logger.

        This is a no-op when ``self.request_logger`` is ``None``.
        """
        request_logger = self.request_logger
        if request_logger is not None:
            text, token_ids, embeds = self._get_prompt_components(inputs)
            request_logger.log_inputs(
                request_id,
                text,
                token_ids,
                embeds,
                params=params,
                lora_request=lora_request,
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the request's trace headers when engine tracing is enabled.

        When tracing is disabled but the request still carries trace headers,
        a warning is logged and ``None`` is returned.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Choose the base request id.

        Priority: the ``X-Request-Id`` header, then ``default``, then a fresh
        ``random_uuid()``.
        """
        header_id = (
            None if raw_request is None else raw_request.headers.get("X-Request-Id")
        )
        if header_id is not None:
            return header_id
        if default is None:
            return random_uuid()
        return default

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the ``X-data-parallel-rank`` header as an int.

        Returns ``None`` when there is no request, the header is absent, or
        its value is not a valid integer.
        """
        header_value = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if header_value is None:
            return None

        try:
            return int(header_value)
        except ValueError:
            return None

    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content``.

        Behavior by ``request.tool_choice``:

        - Named function (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole content becomes
          the arguments of one call, and the content is cleared.
        - ``"required"``: content is parsed as a JSON list of
          ``FunctionDefinition``. Each entry becomes a call, and the content
          is cleared.
        - ``"auto"`` or ``None``, with ``enable_auto_tools`` and a parser
          class: the parser extracts calls. If none are found, returns
          ``(None, content)``.
        - Otherwise: ``([], content)``.

        Raises:
            ValueError: If auto parsing is needed but ``tokenizer`` is None.
            RuntimeError: Re-raised (after logging) from parser construction.
        """
        tool_choice = request.tool_choice
        calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            calls.append(FunctionCall(name=tool_choice.name, arguments=content))
            return calls, None

        if tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            return calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            calls.extend(
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            )
            return calls, None

        auto_parsing = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_parsing:
            return calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        info = parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if info is None or not info.tools_called:
            # No tool calls.
            return None, content

        calls.extend(
            FunctionCall(
                name=found.function.name,
                arguments=found.function.arguments,
            )
            for found in info.tool_calls
        )
        remaining = info.content
        # Whitespace-only leftover text is treated as no content.
        if remaining and remaining.strip() == "":
            remaining = None
        return calls, remaining

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the display string for ``token_id`` in logprob output.

        Order of preference:
        1. ``"token_id:<id>"`` when ``return_as_token_id`` is set.
        2. The already-decoded ``logprob.decoded_token``.
        3. ``tokenizer.decode(token_id)``.

        Raises:
            ValueError: If decoding is needed but no tokenizer is available.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached = logprob.decoded_token
        if cached is not None:
            return cached

        if tokenizer is not None:
            return tokenizer.decode(token_id)
        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )

    def _is_model_supported(self, model_name: str | None) -> bool:
1525
1526
        if not model_name:
            return True
1527
        return self.models.is_base_model(model_name)
1528

1529
1530

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0`` in place.

    Positions whose entry is ``None`` are skipped. The same (mutated) object
    is returned. ``None`` input is returned unchanged.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    entries = (
        entry
        for position in prompt_logprobs
        if position is not None
        for entry in position.values()
    )
    for entry in entries:
        if entry.logprob == negative_infinity:
            entry.logprob = -9999.0
    return prompt_logprobs