serving.py 55.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.engine.protocol import EngineClient
24
25
26
27
28
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
29
from vllm.entrypoints.logger import RequestLogger
30
from vllm.entrypoints.openai.chat_completion.protocol import (
31
    ChatCompletionNamedToolChoiceParam,
32
33
    ChatCompletionRequest,
    ChatCompletionResponse,
34
)
35
from vllm.entrypoints.openai.completion.protocol import (
36
37
    CompletionRequest,
    CompletionResponse,
38
39
)
from vllm.entrypoints.openai.engine.protocol import (
40
41
    ErrorInfo,
    ErrorResponse,
42
    FunctionCall,
43
    FunctionDefinition,
44
)
45
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
46
47
48
49
50
51
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
52
53
54
55
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
56
57
58
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
59
60
61
62
63
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
64
65
66
67
68
69
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
70
    EmbeddingBytesResponse,
71
72
73
74
75
76
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
77
78
    PoolingChatRequest,
    PoolingCompletionRequest,
79
80
81
82
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
83
84
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
85
86
    ScoreRequest,
    ScoreResponse,
87
    ScoreTextRequest,
88
)
89
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
90
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
91
92
93
94
95
96
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
97
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
98
from vllm.exceptions import VLLMValidationError
99
from vllm.inputs.data import PromptType, TokensPrompt
100
101
102
103
104
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
105
from vllm.logger import init_logger
106
from vllm.logprobs import Logprob, PromptLogprobs
107
from vllm.lora.request import LoRARequest
108
from vllm.multimodal import MultiModalDataDict
109
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
110
from vllm.pooling_params import PoolingParams
111
from vllm.reasoning import ReasoningParser, ReasoningParserManager
112
from vllm.renderers import RendererLike
113
from vllm.sampling_params import BeamSearchParams, SamplingParams
114
from vllm.tokenizers import TokenizerLike
115
from vllm.tool_parsers import ToolParser, ToolParserManager
116
117
118
119
120
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
121
from vllm.utils import random_uuid
122
from vllm.utils.async_utils import (
123
    AsyncMicrobatchTokenizer,
124
    collect_from_async_generator,
125
126
    merge_async_iterators,
)
127
from vllm.v1.engine import EngineCoreRequest
128

129
130
131
132
133
134
135
136
137

class GenerationError(Exception):
    """Signal that generation ended with an internal (HTTP 500) failure.

    The instance carries ``status_code`` so handlers can turn it into an
    error response without re-deriving the HTTP status.
    """

    def __init__(self, message: str = "Internal server error"):
        # Always a 500: this error type is not used for client mistakes.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


138
139
logger = init_logger(__name__)

# Request models this module groups under the "completion-like" alias.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Request models this module groups under the "chat-like" alias.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)

# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer; also the bound of
# ``RequestT`` below.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type an endpoint may return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)

# Type variable used to parameterize ``ServeContext`` by request type.
RequestT = TypeVar("RequestT", bound=AnyRequest)


185
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state carried through the serving pipeline.

    All fields are keyword-only. ``request``, ``model_name`` and
    ``request_id`` are required; the remaining fields have defaults.
    """

    # The parsed request object and, when available, the raw HTTP request.
    request: RequestT
    raw_request: Request | None = None

    # Identification of the request.
    model_name: str
    request_id: str
    # Creation time in whole seconds since the epoch, taken at construction.
    created_time: int = field(default_factory=lambda: int(time.time()))

    lora_request: LoRARequest | None = None
    engine_prompts: list[TokensPrompt] | None = None

    # Async stream of (prompt index, output) pairs; `_collect_batch` drains
    # it into `final_res_batch`.
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Un-annotated, so this is a plain class attribute rather than a
    # dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
201
202


203
class OpenAIServing:
204
205
206
207
    # Class-level request ID prefix.
    # NOTE(review): the default assigned here is the descriptive text itself
    # (a multi-line string), not a short prefix. Presumably every concrete
    # endpoint overrides it with a short value such as "embd" — confirm
    # before relying on this default.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
208

209
210
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the collaborators and settings shared by serving endpoints.

        Args:
            engine_client: Client used to talk to the engine.
            models: Model registry; also the source of the input processor,
                IO processor, renderer and model config kept on ``self``.
            request_logger: Optional request logger.
            return_tokens_as_token_ids: Stored as-is on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback or stack for every error response it builds.
        """
        super().__init__()

        # Collaborators handed in by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        # Behaviour flags.
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Cache of async tokenizers, keyed by the wrapped tokenizer.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}

        # Shortcuts to objects owned by the model registry.
        self.input_processor = models.input_processor
        self.io_processor = models.io_processor
        self.renderer = models.renderer

        model_config = models.model_config
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

236
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Look up the tool parser registered under ``tool_parser_name``.

        Returns ``None`` when auto tool choice is disabled or no parser name
        was given.

        Raises:
            TypeError: If the lookup (or the model-name check before it)
                fails for any reason; the original error is chained.
        """
        if tool_parser_name is None or not enable_auto_tools:
            return None

        logger.info('"auto" tool choice has been enabled.')

        try:
            pythonic_on_llama32 = (
                tool_parser_name == "pythonic"
                and self.model_config.model.startswith("meta-llama/Llama-3.2")
            )
            if pythonic_on_llama32:
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as lookup_error:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from lookup_error

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Look up the reasoning parser registered under the given name.

        Returns ``None`` for an empty name.

        Raises:
            TypeError: If the lookup fails or yields ``None``; the original
                error is chained.
        """
        if not reasoning_parser_name:
            return None

        try:
            resolved = ReasoningParserManager.get_reasoning_parser(
                reasoning_parser_name
            )
            # A ``None`` result is reported the same way as a failed lookup.
            assert resolved is not None
        except Exception as lookup_error:
            raise TypeError(
                f"{reasoning_parser_name=} has not been registered"
            ) from lookup_error
        else:
            return resolved

276
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's multimodal cache, then the engine's."""
        # Local cache first, engine-side cache second.
        self.input_processor.clear_mm_cache()
        reset_engine_cache = self.engine_client.reset_mm_cache
        await reset_engine_cache()

280
281
282
283
284
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by issuing one single-token request per beam per step.

        Each step sends every live beam to the engine with ``max_tokens=1``
        and ``logprobs=2 * beam_width``, pools the returned candidate tokens
        across beams and keeps the ``beam_width`` best continuations.
        Candidates equal to the EOS token are moved to the completed set
        unless ``params.ignore_eos`` is set.

        Args:
            prompt: A string prompt, or a mapping from which ``prompt``,
                ``prompt_token_ids`` and ``multi_modal_data`` are read.
            request_id: Base request ID; per-step and per-beam IDs are
                derived from it.
            params: Beam search settings (width, max tokens, EOS handling,
                temperature, length penalty, stop-string inclusion).
            lora_request: LoRA request attached to the initial beam and
                forwarded with each engine call.
            trace_headers: Forwarded unchanged to the engine.

        Yields:
            Exactly one ``RequestOutput``: an error output if any beam
            finishes with ``finish_reason == "error"``, otherwise the final
            output holding up to ``beam_width`` best beams.

        Raises:
            VLLMValidationError: If the input processor has no tokenizer.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.
        # NOTE(review): no workaround code is visible in this method any
        # more — confirm whether the comment above is stale.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Ask for twice the beam width so that enough non-EOS candidates
        # remain after EOS candidates are moved to `completed`.
        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            # One single-token generation request per live beam, run
            # concurrently.
            request_id_batch = f"{request_id}-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            TokensPrompt(
                                prompt_token_ids=beam.tokens,
                                multi_modal_data=beam.multi_modal_data,
                                mm_processor_kwargs=beam.mm_processor_kwargs,
                            ),
                            beam_search_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                for i, beam in enumerate(all_beams)
            ]

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Candidate tokens pooled over all beams, their cumulative
            # log-probabilities, and the index of the beam each candidate
            # came from. The explicit beam index avoids assuming that every
            # beam contributed exactly `logprobs_num` candidates (a beam may
            # return no logprobs, or a dict of a different size).
            candidate_token_ids: list[int] = []
            candidate_cum_logprobs: list[float] = []
            candidate_beam_idx: list[int] = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]
                step_output = result.outputs[0]

                # An engine-side error aborts the whole search.
                if step_output.finish_reason == "error":
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if step_output.logprobs is not None:
                    logprobs = step_output.logprobs[0]
                    candidate_token_ids.extend(logprobs.keys())
                    candidate_cum_logprobs.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    candidate_beam_idx.extend([i] * len(logprobs))

            token_id_arr = np.array(candidate_token_ids)
            cum_logprob_arr = np.array(candidate_cum_logprobs, dtype=float)

            if not ignore_eos:
                # Every EOS candidate finishes its beam.
                eos_idx = np.where(token_id_arr == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = candidate_beam_idx[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(cum_logprob_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # Exclude finished candidates from the top-k selection below.
                cum_logprob_arr[eos_idx] = -np.inf

            # Pick the `beam_width` best candidates. `argpartition` requires
            # kth < size, so fall back to "take everything" when there are
            # not more candidates than beams.
            num_candidates = cum_logprob_arr.size
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(cum_logprob_arr), beam_width)[
                    :beam_width
                ]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                beam_idx = candidate_beam_idx[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(token_id_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(cum_logprob_arr[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams
            if not all_beams:
                # Nothing left to extend; further steps would have no input.
                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
518

519
    def _get_completion_renderer(self) -> BaseRenderer:
        """
        Build a `CompletionRenderer` from this instance's model config and
        the renderer's tokenizer, sharing this instance's async tokenizer
        pool with it.
        """
        model_config = self.model_config
        renderer_tokenizer = self.renderer.tokenizer
        return CompletionRenderer(
            model_config=model_config,
            tokenizer=renderer_tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
529

530
531
532
533
534
535
536
537
538
539
540
541
542
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Return the `RenderConfig` that governs how prompts for `request`
        are prepared (for example tokenization and length handling).

        The base class has no implementation: each endpoint is expected to
        override this with logic suited to its request type.
        """
        raise NotImplementedError()

543
544
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return the `AsyncMicrobatchTokenizer` wrapping `tokenizer`, creating
        it and storing it in the pool on first use.
        """
        pool = self._async_tokenizer_pool
        if (cached := pool.get(tokenizer)) is not None:
            return cached

        created = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = created
        return created
553

554
555
556
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Preprocessing hook; the base implementation does nothing.

        Subclasses may override it to fill in `ctx` (classification,
        embedding, etc.) and may return an `ErrorResponse` to report
        a failure.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Response-building hook; the base implementation reports that the
        endpoint is unimplemented. Subclasses may override it to return
        the response object for their endpoint.
        """
        message = "unimplemented endpoint"
        return self.create_error_response(message)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Run the pipeline for `ctx` and return the first response it yields,
        or an error response if it yields nothing.
        """
        async for response in self._pipeline(ctx):
            # Only the first yielded response is used.
            break
        else:
            return self.create_error_response("No response yielded from pipeline")

        return response

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The
        first stage that reports an error yields that error and ends the
        pipeline, so later stages never run on a request that has already
        failed and a consumer that keeps iterating receives nothing more.

        Yields:
            Exactly one item: the first error encountered, or the response
            built by `_build_response`.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

607
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """
        Reject requests whose `truncate_prompt_tokens` exceeds the model's
        maximum length; requests without that attribute (or with `None`)
        pass.
        """
        requested_truncation = getattr(ctx.request, "truncate_prompt_tokens", None)

        within_limit = (
            requested_truncation is None
            or not requested_truncation > self.max_model_len
        )
        if within_limit:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is greater than max_model_len."
            " Please, select a smaller truncation size."
        )

621
622
623
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """
        Return the request's pooling parameters, or an error response when
        the request has no `to_pooling_params` attribute.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

632
633
634
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Submit one `encode` call per engine prompt and store the merged
        result stream on `ctx.result_generator`.

        Returns `None` on success. Returns an `ErrorResponse` when pooling
        parameters cannot be created, when `ctx.engine_prompts` is missing,
        or when any exception is raised along the way.
        """
        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            lora_request = ctx.lora_request
            priority = getattr(ctx.request, "priority", 0)

            streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
            for prompt_index, engine_prompt in enumerate(engine_prompts):
                # Each prompt gets its own request ID derived from the
                # context's ID.
                item_request_id = f"{ctx.request_id}-{prompt_index}"

                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=priority,
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
        except Exception as e:
            return self.create_error_response(e)

        return None
680
681
682
683

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Drain `ctx.result_generator` into `ctx.final_res_batch`, placing
        each result at the index it was yielded with.

        Returns `None` on success. Returns an `ErrorResponse` when the
        prompts or the generator are missing, when some prompt produced no
        result, or when any exception is raised while collecting.
        """
        try:
            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # One slot per prompt; a slot left as None means no result
            # arrived for that prompt.
            slots: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response("Result generator not available")

            async for prompt_index, output in stream:
                slots[prompt_index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
            return None
        except Exception as e:
            return self.create_error_response(e)
711

712
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message or an exception.

        When `message` is an exception, `err_type`, `status_code` and
        `param` are derived from the exception and the passed values are
        ignored:

        * `VLLMValidationError` -> 400, `param` taken from the exception
        * `NotImplementedError` -> 501
        * `ValueError`, `TypeError`, `RuntimeError`, `OverflowError` -> 400
        * any class named ``TemplateError`` or derived from one -> 400
        * anything else -> 500

        Args:
            message: Error text, or the exception to convert.
            err_type: Error type string used when `message` is a string.
            status_code: HTTP status used when `message` is a string.
            param: Offending parameter name used when `message` is a string.

        Returns:
            The error response, with its message passed through
            `sanitize_message`.
        """
        if isinstance(message, Exception):
            exc = message

            from vllm.exceptions import VLLMValidationError

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, NotImplementedError):
                # Must be tested before RuntimeError: NotImplementedError is
                # a RuntimeError subclass and would otherwise be reported as
                # a 400 instead of a 501.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif any(cls.__name__ == "TemplateError" for cls in type(exc).__mro__):
                # jinja2.TemplateError and its subclasses, matched by class
                # name along the MRO to avoid importing jinja2.
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            # Print the active exception's traceback if there is one,
            # otherwise the current call stack.
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
766

767
    def create_streaming_error_response(
768
        self,
769
        message: str | Exception,
770
771
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
772
        param: str | None = None,
773
    ) -> str:
774
        json_str = json.dumps(
775
            self.create_error_response(
776
777
778
779
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
780
781
            ).model_dump()
        )
782
783
        return json_str

784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Turn a `GenerationError` into an "InternalServerError" response
        carrying the error's own status code."""
        error_text = str(e)
        return self.create_error_response(
            error_text,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Turn a `GenerationError` into a JSON-encoded
        "InternalServerError" response for streaming endpoints."""
        error_text = str(e)
        return self.create_streaming_error_response(
            error_text,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

813
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that the model named in `request` can be served.

        Returns `None` when the model is supported, is a known LoRA adapter,
        or (with runtime LoRA updating enabled) resolves to a `LoRARequest`.
        Otherwise returns an error: the resolver's own response when it
        reported a 400, else a 404 "model does not exist" response.
        """
        requested_model = request.model

        if self._is_model_supported(requested_model):
            return None
        if requested_model in self.models.lora_requests:
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and requested_model:
            load_result = await self.models.resolve_lora(requested_model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                # Only a 400 from the resolver is passed through as-is.
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolver_error = load_result

        return resolver_error or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
842

843
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

866
    def _maybe_get_adapters(
867
868
869
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
870
    ) -> LoRARequest | None:
871
        if request.model in self.models.lora_requests:
872
            return self.models.lora_requests[request.model]
873
874
875
876
877
878

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
879
                return default_mm_lora
880
881

        if self._is_model_supported(request.model):
882
            return None
883

884
        # if _check_model has been called earlier, this will be unreachable
885
        raise ValueError(f"The model `{request.model}` does not exist.")
886

887
888
889
890
891
892
893
894
895
896
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

897
898
899
900
901
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
902
903
904
905
906
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
907
908
909
910
911
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

912
    async def _normalize_prompt_text_to_input(
913
914
915
        self,
        request: AnyRequest,
        prompt: str,
916
        tokenizer: TokenizerLike,
917
        add_special_tokens: bool,
918
    ) -> TokensPrompt:
919
920
        async_tokenizer = self._get_async_tokenizer(tokenizer)

921
922
923
924
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
925
926
            prompt = prompt.lower()

927
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
928

929
        if truncate_prompt_tokens is None:
930
            encoded = await async_tokenizer(
931
932
                prompt, add_special_tokens=add_special_tokens
            )
933
934
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
935
936
937
938
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
939
940
                max_length=self.max_model_len,
            )
941
        else:
942
943
944
945
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
946
947
                max_length=truncate_prompt_tokens,
            )
948
949
950
951
952
953

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

954
    async def _normalize_prompt_tokens_to_input(
955
956
        self,
        request: AnyRequest,
957
        prompt_ids: list[int],
958
        tokenizer: TokenizerLike | None,
959
    ) -> TokensPrompt:
960
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
961

962
        if truncate_prompt_tokens is None:
963
            input_ids = prompt_ids
964
        elif truncate_prompt_tokens < 0:
965
            input_ids = prompt_ids[-self.max_model_len :]
966
967
968
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

969
970
971
972
973
        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
974

975
976
977
978
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against the context window and wrap it.

        Three request families are distinguished by type:

        * embedding / score / rerank / classification requests may use up to
          ``self.max_model_len`` input tokens;
        * tokenize / detokenize requests are not length-checked at all;
        * everything else must have fewer than ``self.max_model_len`` input
          tokens, and ``max_tokens`` (or ``max_completion_tokens`` for chat
          completions), when set, must fit in the remaining context.

        Returns:
            ``TokensPrompt`` built from ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: if one of the limits above is exceeded.
        """
        token_num = len(input_ids)

        # None of these request types carry max_tokens.
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= self.max_model_len:
                return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

            operation_by_type = {
                ScoreDataRequest: "score",
                ScoreTextRequest: "score",
                ScoreQueriesDocumentsRequest: "score",
                ClassificationCompletionRequest: "classification",
                ClassificationChatRequest: "classification",
            }
            operation = operation_by_type.get(type(request), "embedding generation")
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                "Please reduce the length of the input.",
                parameter="input_tokens",
                value=token_num,
            )

        # Tokenize/detokenize requests have no max_tokens and do not require
        # model context length validation.
        tokenizer_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_request_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # chat completion endpoint supports max_completion_tokens
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
1060

1061
    async def _tokenize_prompt_input_async(
1062
1063
        self,
        request: AnyRequest,
1064
        tokenizer: TokenizerLike,
1065
        prompt_input: str | list[int],
1066
        add_special_tokens: bool = True,
1067
    ) -> TokensPrompt:
1068
        """
1069
        A simpler implementation that tokenizes a single prompt input.
1070
        """
1071
        async for result in self._tokenize_prompt_inputs_async(
1072
1073
            request,
            tokenizer,
1074
            [prompt_input],
1075
            add_special_tokens=add_special_tokens,
1076
1077
1078
        ):
            return result
        raise ValueError("No results yielded from tokenization")
1079

1080
    async def _tokenize_prompt_inputs_async(
1081
1082
        self,
        request: AnyRequest,
1083
        tokenizer: TokenizerLike,
1084
        prompt_inputs: Iterable[str | list[int]],
1085
        add_special_tokens: bool = True,
1086
    ) -> AsyncGenerator[TokensPrompt, None]:
1087
        """
1088
        A simpler implementation that tokenizes multiple prompt inputs.
1089
        """
1090
1091
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
1092
                yield await self._normalize_prompt_text_to_input(
1093
                    request,
1094
1095
                    prompt=prompt,
                    tokenizer=tokenizer,
1096
1097
1098
                    add_special_tokens=add_special_tokens,
                )
            else:
1099
                yield await self._normalize_prompt_tokens_to_input(
1100
                    request,
1101
1102
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
1103
1104
                )

1105
1106
    def _validate_chat_template(
        self,
1107
1108
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
1109
        trust_request_chat_template: bool,
1110
    ) -> ErrorResponse | None:
1111
        if not trust_request_chat_template and (
1112
1113
1114
1115
1116
1117
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
1118
1119
1120
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
1121
1122
                "Refused request with untrusted chat template."
            )
1123
1124
        return None

1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1137
1138
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        renderer: RendererLike,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a validated engine prompt.

        Steps, in order:

        1. Build the template kwargs from the explicit arguments, let
           ``chat_template_kwargs`` override them, then layer the result over
           ``default_chat_template_kwargs``.
        2. Render the messages with ``renderer.render_messages_async``.
        3. If the renderer returned no ``prompt_token_ids``, tokenize the
           rendered text here and copy the renderer's other keys back onto
           the prompt; otherwise only validate the returned ids.
        4. Attach ``mm_processor_kwargs`` / ``cache_salt`` from the request.
        5. When a tool parser is given and ``request.tool_choice`` is not
           ``"none"``, let the parser adjust the request.

        Returns:
            The rendered conversation and a one-element list holding the
            engine prompt.

        Raises:
            NotImplementedError: if tool parsing is requested for a request
                that is neither a chat completion nor a responses request.
        """
        # Explicit arguments form the base; request-level kwargs win over
        # them on key clashes.
        request_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        request_kwargs.update(chat_template_kwargs or {})
        render_kwargs = self._prepare_extra_chat_template_kwargs(
            request_kwargs,
            default_chat_template_kwargs,
        )

        # Use the async tokenizer in `OpenAIServing` if possible.
        # Later we can move it into the renderer so that we can return both
        # text and token IDs in the same prompt from `render_messages_async`
        # which is used for logging and `enable_response_messages`.
        from vllm.tokenizers.mistral import MistralTokenizer

        # "tokenize" has to be removed from the kwargs before they are
        # splatted into the call below, where it is passed explicitly.
        tokenize = render_kwargs.pop("tokenize", False) or isinstance(
            renderer.tokenizer, MistralTokenizer
        )
        conversation, engine_prompt = await renderer.render_messages_async(
            messages,
            chat_template_content_format=chat_template_content_format,
            tokenize=tokenize,
            **render_kwargs,
        )

        if "prompt_token_ids" in engine_prompt:
            self._validate_input(
                request=request,
                input_ids=engine_prompt["prompt_token_ids"],  # type: ignore
                input_text="",
            )
        else:
            rendered = engine_prompt
            engine_prompt = await self._tokenize_prompt_input_async(
                request,
                renderer.get_tokenizer(),
                rendered["prompt"],
                add_special_tokens=add_special_tokens,
            )
            # Fill in other keys like MM data
            engine_prompt.update(rendered)  # type: ignore

        engine_prompt = cast(TokensPrompt, engine_prompt)

        mm_processor_kwargs = request.mm_processor_kwargs
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
        cache_salt = getattr(request, "cache_salt", None)
        if cache_salt is not None:
            engine_prompt["cache_salt"] = cache_salt

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM)
        wants_tool_parsing = (
            tool_parser is not None
            and hasattr(request, "tool_choice")
            and request.tool_choice != "none"
        )
        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )

            parser = tool_parser(renderer.get_tokenizer())
            request = parser.adjust_request(request=request)  # type: ignore

        return conversation, [engine_prompt]
1225

1226
1227
1228
1229
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
        data_parallel_rank: int | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        A fresh ``tokenization_kwargs`` dict is handed to
        ``_validate_truncation_size`` together with ``self.max_model_len``
        and ``params.truncate_prompt_tokens``, and then passed on to
        ``self.input_processor.process_inputs``.

        Returns:
            The engine request and the tokenization kwargs dict.
        """
        tok_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        return (
            self.input_processor.process_inputs(
                request_id,
                engine_prompt,
                params,
                lora_request=lora_request,
                tokenization_kwargs=tok_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            ),
            tok_kwargs,
        )

1255
1256
1257
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        renderer: RendererLike,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for the next turn of a conversation.

        ``messages`` are converted with ``construct_input_messages`` and run
        through ``_preprocess_chat``; the rendered conversation is discarded
        and only the engine prompts are returned.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        _conversation, engine_prompts = await self._preprocess_chat(
            request,
            renderer,
            next_turn_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return engine_prompts
1279

1280
1281
1282
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Run generation, looping while the context asks for built-in tools.

        Async generator. Each turn is submitted to the engine as its own
        sub-request (``"{request_id}_{n}"``); every engine output is appended
        to ``context``, which is then yielded. When a turn finishes and
        ``context.need_builtin_tool_call()`` is true, the tool is called, its
        output is appended to the context, the next prompt is rendered and
        another turn starts. Otherwise the generator ends.

        Between turns ``sampling_params.max_tokens`` is overwritten with the
        context length left after the new prompt, and the priority for
        follow-up turns is set to the initial priority minus one.
        """
        prompt_text, _, _ = self._get_prompt_components(engine_prompt)
        # kwargs is never modified here, so this lookup is the same each turn.
        trace_headers = kwargs.get("trace_headers")

        initial_priority = priority
        turn = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            turn_request_id = f"{request_id}_{turn}"
            self._log_inputs(
                turn_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                turn_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            outputs = self.engine_client.generate(
                engine_request,
                sampling_params,
                turn_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )
            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and update the context with the result.
            tool_result = await context.call_tool()
            context.append_tool_output(tool_result)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=next_token_ids)
            elif isinstance(context, ParsableContext):
                next_prompts = await self._render_next_turn(
                    context.request,
                    context.renderer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = next_prompts[0]
                prompt_text, _, _ = self._get_prompt_components(engine_prompt)

            # Update the sampling params: the new prompt leaves this much
            # room in the context window.
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
            # OPTIMIZATION
            # NOTE(review): presumably lets follow-up turns be scheduled ahead
            # of new requests; confirm the engine's priority ordering.
            priority = initial_priority - 1
            turn += 1
1365

1366
1367
    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
        """Return the result of ``get_prompt_components`` for ``prompt``."""
        components = get_prompt_components(prompt)
        return components
1368

1369
1370
1371
    def _log_inputs(
        self,
        request_id: str,
1372
        inputs: PromptType,
1373
1374
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1375
1376
1377
    ) -> None:
        if self.request_logger is None:
            return
1378

1379
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1380
1381
1382
1383
1384

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1385
            prompt_embeds,
1386
1387
1388
            params=params,
            lora_request=lora_request,
        )
1389

1390
1391
1392
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when the engine has tracing enabled.

        When tracing is disabled, ``None`` is returned; if the request
        nevertheless carries trace headers, ``log_tracing_disabled_warning``
        is called first.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1404
    @staticmethod
1405
    def _base_request_id(
1406
1407
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1408
        """Pulls the request id to use from a header, if provided"""
1409
1410
1411
1412
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1413

1414
        return random_uuid() if default is None else default
1415

1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1431
1432
1433
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated ``content`` into function calls per ``tool_choice``.

        * Named tool choice (either request flavour): the whole content is
          the arguments of the one forced call; returned content is ``None``.
        * ``"required"``: content is parsed as a JSON list of function
          definitions, one call each; returned content is ``None``.
        * ``"auto"`` / unset, with auto tools enabled and a parser class
          given: the parser extracts the calls. If it finds none,
          ``(None, content)`` is returned; otherwise the calls plus the
          parser's remaining content.
        * Anything else: an empty list and the unchanged content.

        Raises:
            ValueError: if automatic parsing is needed but ``tokenizer`` is
                ``None``.
        """
        tool_choice = request.tool_choice
        function_calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=tool_choice.name, arguments=content)
            )
            # Content is cleared since the tool is called.
            return function_calls, None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            # Content is cleared since the tool is called.
            return function_calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            for definition in definitions:
                function_calls.append(
                    FunctionCall(
                        name=definition.name,
                        arguments=json.dumps(
                            definition.parameters, ensure_ascii=False
                        ),
                    )
                )
            # Content is cleared since the tool is called.
            return function_calls, None

        auto_parse = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_parse:
            return function_calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        for tool_call in tool_call_info.tool_calls:
            function_calls.append(
                FunctionCall(
                    id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments,
                )
            )

        content = tool_call_info.content
        if content and content.strip() == "":
            content = None
        return function_calls, content

1508
    @staticmethod
1509
1510
1511
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1512
        tokenizer: TokenizerLike | None,
1513
1514
        return_as_token_id: bool = False,
    ) -> str:
1515
1516
1517
        if return_as_token_id:
            return f"token_id:{token_id}"

1518
1519
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1520
1521
1522
1523
1524
1525

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1526
        return tokenizer.decode(token_id)
1527

1528
    def _is_model_supported(self, model_name: str | None) -> bool:
1529
1530
        if not model_name:
            return True
1531
        return self.models.is_base_model(model_name)
1532

1533
1534

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are passed over untouched. The
    argument itself is returned, with the affected ``logprob`` attributes
    modified.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position in prompt_logprobs:
        if position is None:
            continue
        for entry in position.values():
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0
    return prompt_logprobs