serving_engine.py 54.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
from fastapi import Request
16
17
18
from openai.types.responses import (
    ToolChoiceFunction,
)
19
20
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
21

22
import vllm.envs as envs
23
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
30
31
32
33
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
34
35
36
37
38
39
from vllm.entrypoints.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
40
from vllm.entrypoints.logger import RequestLogger
41
from vllm.entrypoints.openai.protocol import (
42
    ChatCompletionNamedToolChoiceParam,
43
44
45
46
47
48
49
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
50
51
    FunctionCall,
    FunctionDefinition,
52
    ResponseInputOutputItem,
53
54
55
56
57
58
59
60
    ResponsesRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
61
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
83
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
84
85
86
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
87
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
88
from vllm.entrypoints.utils import _validate_truncation_size
89
from vllm.inputs.data import PromptType, TokensPrompt
90
91
92
93
94
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
95
from vllm.logger import init_logger
96
from vllm.logprobs import Logprob, PromptLogprobs
97
from vllm.lora.request import LoRARequest
98
from vllm.multimodal import MultiModalDataDict
99
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
100
from vllm.pooling_params import PoolingParams
101
from vllm.reasoning import ReasoningParser, ReasoningParserManager
102
from vllm.sampling_params import BeamSearchParams, SamplingParams
103
from vllm.tokenizers import TokenizerLike
104
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
105
from vllm.tokenizers.mistral import MistralTokenizer
106
from vllm.tool_parsers import ToolParser, ToolParserManager
107
108
109
110
111
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
112
from vllm.utils import random_uuid
113
from vllm.utils.async_utils import (
114
    AsyncMicrobatchTokenizer,
115
    collect_from_async_generator,
116
    make_async,
117
118
    merge_async_iterators,
)
119
from vllm.utils.collection_utils import is_list_of
120
from vllm.v1.engine import EngineCoreRequest
121
from vllm.transformers_utils.tokenizers import CPM9GTokenizer
122

123
124
125
126
127
128
129
130
131

class GenerationError(Exception):
    """Raised when a finish_reason indicates an internal server error (500).

    Every instance carries ``status_code`` set to
    ``HTTPStatus.INTERNAL_SERVER_ERROR``.
    """

    def __init__(self, message: str = "Internal server error") -> None:
        # The status is fixed for this exception type; only the message varies.
        self.status_code: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


132
133
# Module-level logger, created via `init_logger` with this module's name.
logger = init_logger(__name__)

134
135
136
137
138
# Union of the request types this module groups as "completion-like".
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)
143

144
# Union of the request types this module groups as "chat-like".
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)
# Speech-to-text requests: either a transcription or a translation request.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
# Union of every request type accepted by the serving layer in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Union of every non-error response type produced by the serving layer.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)
171

172

173
174
175
# Type variable for the request carried by a serve context; its upper bound
# is `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


176
177
@dataclass(kw_only=True)
class RequestProcessingMixin:
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Prompts to hand to the engine. Defaults to a fresh empty list per
    # instance; the annotation also allows None.
    engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
184
185


186
187
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Async generator of (prompt index, output) pairs; None until one is set.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Collected outputs; defaults to a fresh empty list per instance.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # Unannotated, so this is a plain class attribute and not a dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)


203
204
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """
    Per-request state passed between the stages of the serving pipeline.

    All fields are keyword-only. `created_time` defaults to the current Unix
    time (whole seconds) at construction.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Shared across most requests
    tokenizer: TokenizerLike | None = None
215
216


217
218
219
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """`ServeContext` bound to `ClassificationRequest`; declares no extra fields."""
220
221


222
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` bound to `EmbeddingRequest`, plus chat-template settings."""

    # Optional chat template; None when not supplied.
    chat_template: str | None = None
    # Required (no default); legal after a defaulted field because the
    # dataclass is keyword-only.
    chat_template_content_format: ChatTemplateContentFormatOption


228
class OpenAIServing:
229
230
231
232
    # NOTE: the triple-quoted text below is the attribute's actual runtime
    # value, not a docstring. Presumably subclasses override it with a short
    # prefix such as "embd" or "classify" — TODO confirm against subclasses.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
233

234
235
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """
        Store the collaborators and cache frequently used model settings.

        Args:
            engine_client: Client used to talk to the engine.
            models: Model registry; also the source of the input processor,
                IO processor and model config cached on this instance.
            request_logger: Optional request logger (may be None).
            return_tokens_as_token_ids: Stored as-is on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback or the current stack.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Single worker thread used to run the Mistral chat-template call
        # off the event loop.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Cache of async tokenizer wrappers, keyed by tokenizer object.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}

        self.input_processor = self.models.input_processor
        self.io_processor = self.models.io_processor
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.tokenizer_mode = self.model_config.tokenizer_mode
        if self.tokenizer_mode == "cpm":
            # NOTE: `self.tokenizer` is assigned only in "cpm" mode; in every
            # other mode this constructor leaves the attribute unset.
            self.tokenizer = CPM9GTokenizer(
                self.model_config.model, trust_remote_code=True
            )
266

267
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """
        Get the tool parser based on the name.

        Returns None when auto tool choice is disabled or no parser name is
        given; otherwise returns what `ToolParserManager.get_tool_parser`
        yields for the name.

        Raises:
            TypeError: If the lookup (or the model-name check before it)
                raises any exception; the original error is chained.
        """
        if not enable_auto_tools or tool_parser_name is None:
            return None
        logger.info('"auto" tool choice has been enabled.')

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """
        Get the reasoning parser based on the name.

        Returns None when the name is empty/falsy.

        Raises:
            TypeError: If the lookup raises or returns None.
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check instead of `assert`, which is stripped under `-O`
        # and would let a None parser slip through.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

307
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's multimodal cache, then the engine's."""
        self.input_processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

311
312
313
314
315
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Run beam search by asking the engine for one token per beam per step.

        Each step submits every live beam as its own request with
        ``max_tokens=1`` and ``2 * beam_width`` logprobs, expands the beams
        with the returned candidates and keeps the ``beam_width`` best ones.
        After at most ``params.max_tokens`` steps a single finished
        `RequestOutput` holding the best beams is yielded. If any per-beam
        result has ``finish_reason == "error"``, one `RequestOutput` with
        ``finish_reason="error"`` is yielded instead and the search stops.

        Raises:
            ValueError: If the input processor has no tokenizer.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init=True`"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index of the beam (and engine result) each candidate came from.
            # A result may contribute a number of candidates other than
            # `logprobs_num` (none at all when its logprobs are None), so the
            # owner is recorded explicitly rather than derived as
            # `idx // logprobs_num`.
            all_beams_owner: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    all_beams_owner.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = all_beams_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
                :beam_width
            ]

            for idx in topn_idx:
                owner = all_beams_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        # Beams still alive after the last step compete with the finished ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
547

548
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """
        Get a Renderer instance with the provided tokenizer.
        Uses shared async tokenizer pool for efficiency.

        A new `CompletionRenderer` is built on every call; only the async
        tokenizer pool it receives is shared with this instance.
        """
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
558

559
560
561
562
563
564
565
566
567
568
569
570
571
    def _build_render_config(self, request: Any) -> RenderConfig:
        """
        Produce the `RenderConfig` an endpoint wants for `request`.

        The renderer uses it to decide how prompts are prepared (e.g.,
        tokenization and length handling). This base version always raises;
        each endpoint is expected to supply its own implementation.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

572
573
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
        given tokenizer.

        The cache is `self._async_tokenizer_pool`, keyed by the tokenizer
        object itself, so the tokenizer must be hashable.
        """
        pool = self._async_tokenizer_pool
        async_tokenizer = pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            pool[tokenizer] = async_tokenizer
        return async_tokenizer
582

583
584
585
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).

        This base version does nothing and returns None.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.

        This base version returns an "unimplemented endpoint" error response.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Run the pipeline for `ctx` and return the first response it yields.

        Returns an error response if the pipeline yields nothing.
        """
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)

        # Only the first yielded item is consumed; the generator is not
        # resumed afterwards.
        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """
        Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The
        first stage that reports an error yields that error and ends the
        generator, so later stages never run on a failed context.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

639
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """
        Check request-level limits before any preprocessing.

        Returns an error response when the request has a
        `truncate_prompt_tokens` value greater than `self.max_model_len`;
        returns None otherwise (including when the attribute is absent or
        None).
        """
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

653
654
655
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """
        Build pooling parameters from the request.

        Returns the result of `ctx.request.to_pooling_params()`, or an error
        response when the request has no such method.
        """
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
                "Request type does not support pooling parameters"
            )

        return ctx.request.to_pooling_params()

664
665
666
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Schedule the request and get the result generator.

        One `engine_client.encode` generator is created per engine prompt;
        they are merged and stored on `ctx.result_generator`.

        Returns None on success, or an error response when pooling params
        cannot be built, `ctx.engine_prompts` is None, or any exception is
        raised along the way (its message becomes the error message).
        """
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # Same for every prompt of this request, so look it up once.
            priority = getattr(ctx.request, "priority", 0)

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own id: "<request id>-<prompt index>".
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Collect batch results from the result generator.

        Each `(index, result)` pair from `ctx.result_generator` is stored at
        its index; a later pair with the same index overwrites the earlier
        one. On success the results are written to `ctx.final_res_batch` and
        None is returned.

        Returns an error response when `ctx.engine_prompts` or
        `ctx.result_generator` is None, when any prompt ends up without a
        result, or when an exception is raised (its message becomes the
        error message).
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity check: `None in list` would fall back to `==` on the
            # stored result objects.
            if any(res is None for res in final_res_batch):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [res for res in final_res_batch if res is not None]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

747
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an `ErrorResponse` from a message, error type and HTTP status.

        When `self.log_error_stack` is set, the active exception's traceback
        is printed if one is being handled; otherwise the current call stack
        is printed.
        """
        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(
            error=ErrorInfo(message=message, type=err_type, code=status_code.value)
        )
762

763
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """
        Build an error response and return it serialized as a JSON string.

        Takes the same arguments as `create_error_response`; the result is
        that response's `model_dump()` passed through `json.dumps`.
        """
        json_str = json.dumps(
            self.create_error_response(
                message=message, err_type=err_type, status_code=status_code
            ).model_dump()
        )
        return json_str

776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Turn a `GenerationError` into an "InternalServerError" response.

        The message is `str(e)` and the status code is `e.status_code`.
        """
        message = str(e)
        return self.create_error_response(
            message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Turn a `GenerationError` into a streaming "InternalServerError" string.

        The message is `str(e)` and the status code is `e.status_code`.
        """
        message = str(e)
        return self.create_streaming_error_response(
            message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

805
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """
        Check that `request.model` names a model this server can serve.

        Returns None when the model is supported, is a known LoRA adapter,
        or (with `VLLM_ALLOW_RUNTIME_LORA_UPDATING` enabled) resolves to a
        `LoRARequest` at runtime. Otherwise returns an error response: the
        resolver's own error if it is a 400, else a 404 "does not exist".
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            # Only a bad-request error from the resolver is passed through;
            # any other outcome falls back to the 404 below.
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
833

834
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

857
    def _maybe_get_adapters(
858
859
860
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
861
    ) -> LoRARequest | None:
862
        if request.model in self.models.lora_requests:
863
            return self.models.lora_requests[request.model]
864
865
866
867
868
869

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
870
                return default_mm_lora
871

872
        if self._is_model_supported(request.model):
873
            return None
874

875
        # if _check_model has been called earlier, this will be unreachable
876
        raise ValueError(f"The model `{request.model}` does not exist.")
877

878
879
880
881
882
883
884
885
886
887
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

888
889
890
891
892
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
893
894
895
896
897
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
898
899
900
901
902
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

903
    async def _normalize_prompt_text_to_input(
904
905
906
        self,
        request: AnyRequest,
        prompt: str,
907
        tokenizer: TokenizerLike,
908
        add_special_tokens: bool,
909
    ) -> TokensPrompt:
910
911
        async_tokenizer = self._get_async_tokenizer(tokenizer)

912
913
914
915
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
916
917
            prompt = prompt.lower()

918
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
919

920
        if truncate_prompt_tokens is None:
921
            encoded = await async_tokenizer(
922
923
                prompt, add_special_tokens=add_special_tokens
            )
924
925
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
926
927
928
929
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
930
931
                max_length=self.max_model_len,
            )
932
        else:
933
934
935
936
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
937
938
                max_length=truncate_prompt_tokens,
            )
939

940
941
942
943
        if self.tokenizer_mode == "cpm":
            input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(prompt)
        else:
            input_ids = encoded.input_ids
944
945
946
947
948

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

949
    async def _normalize_prompt_tokens_to_input(
950
951
        self,
        request: AnyRequest,
952
        prompt_ids: list[int],
953
        tokenizer: TokenizerLike | None,
954
    ) -> TokensPrompt:
955
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
956

957
        if truncate_prompt_tokens is None:
958
            input_ids = prompt_ids
959
        elif truncate_prompt_tokens < 0:
960
            input_ids = prompt_ids[-self.max_model_len :]
961
962
963
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

964
965
966
        if tokenizer is None:
            input_text = ""
        else:
967
968
            async_tokenizer = self._get_async_tokenizer(tokenizer) 
            input_text = await async_tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else await self.tokenizer.decode_all(input_ids)
969

970
971
972
973
974
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length for `request` and wrap it as a prompt.

        Three request families are handled:

        * embedding / score / rerank / classification requests may use up
          to `self.max_model_len` input tokens;
        * tokenize / detokenize requests are not length-checked;
        * all other requests must have fewer than `self.max_model_len`
          input tokens, and input plus requested output tokens must fit
          within `self.max_model_len`.

        Raises:
            ValueError: if the applicable length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > self.max_model_len:
                # The operation name is chosen by exact request type; any
                # type not listed here is reported as embedding generation.
                request_type = type(request)
                if request_type is ScoreRequest:
                    operation = "score"
                elif request_type in (
                    ClassificationCompletionRequest,
                    ClassificationChatRequest,
                ):
                    operation = "classification"
                else:
                    operation = "embedding generation"
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenization_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenization_request_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        # TODO(#9845): remove max_tokens when field dropped from OpenAI API
        max_tokens = (
            (request.max_completion_tokens or request.max_tokens)
            if isinstance(request, ChatCompletionRequest)
            else getattr(request, "max_tokens", None)
        )

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        output_overflows = (
            max_tokens is not None
            and token_num + max_tokens > self.max_model_len
        )
        if output_overflows:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
1045

1046
    async def _tokenize_prompt_input_async(
1047
1048
        self,
        request: AnyRequest,
1049
        tokenizer: TokenizerLike,
1050
        prompt_input: str | list[int],
1051
        add_special_tokens: bool = True,
1052
    ) -> TokensPrompt:
1053
        """
1054
        A simpler implementation that tokenizes a single prompt input.
1055
        """
1056
        async for result in self._tokenize_prompt_inputs_async(
1057
1058
            request,
            tokenizer,
1059
            [prompt_input],
1060
            add_special_tokens=add_special_tokens,
1061
1062
1063
        ):
            return result
        raise ValueError("No results yielded from tokenization")
1064

1065
    async def _tokenize_prompt_inputs_async(
1066
1067
        self,
        request: AnyRequest,
1068
        tokenizer: TokenizerLike,
1069
        prompt_inputs: Iterable[str | list[int]],
1070
        add_special_tokens: bool = True,
1071
    ) -> AsyncGenerator[TokensPrompt, None]:
1072
        """
1073
        A simpler implementation that tokenizes multiple prompt inputs.
1074
        """
1075
1076
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
1077
                yield await self._normalize_prompt_text_to_input(
1078
                    request,
1079
1080
                    prompt=prompt,
                    tokenizer=tokenizer,
1081
1082
1083
                    add_special_tokens=add_special_tokens,
                )
            else:
1084
                yield await self._normalize_prompt_tokens_to_input(
1085
                    request,
1086
1087
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
1088
1089
                )

1090
1091
    def _validate_chat_template(
        self,
1092
1093
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
1094
        trust_request_chat_template: bool,
1095
    ) -> ErrorResponse | None:
1096
        if not trust_request_chat_template and (
1097
1098
1099
1100
1101
1102
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
1103
1104
1105
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
1106
1107
                "Refused request with untrusted chat template."
            )
1108
1109
        return None

1110
1111
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a single engine prompt.

        Steps, in order:

        1. Resolve the content format and parse `messages` into a
           conversation plus a future for the multimodal data.
        2. Apply the chat template using the tokenizer-specific path
           (Mistral, DeepSeek V3.2, or the HF fallback). Without a
           tokenizer, a placeholder prompt is used.
        3. Let the tool parser adjust the request when tool parsing applies.
        4. Tokenize the rendered prompt (or decode it, when the template
           already produced token ids).
        5. Attach multimodal data/uuids, `mm_processor_kwargs` and
           `cache_salt` when present.

        Entries in `chat_template_kwargs` override the keyword arguments
        assembled from the explicit parameters.

        Returns:
            The parsed conversation and a one-element list holding the
            engine prompt.

        Raises:
            NotImplementedError: if tool parsing applies to a request that
                is neither a Chat Completions nor a Responses request.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        elif isinstance(tokenizer, DeepseekV32Tokenizer):
            request_prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                messages=messages,
                model_config=model_config,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none". (If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # The message is one string: the previous comma-separated form
            # built a tuple, which rendered as a tuple repr on failure.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised"
            )
            prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if "prompt" in prompt_inputs:
            engine_prompt["prompt"] = prompt_inputs["prompt"]

        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        # Read defensively, like `cache_salt` below, so a request type
        # without this field does not raise AttributeError.
        mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None)
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [engine_prompt]
1232

1233
1234
1235
1236
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        Returns the engine request built by `self.input_processor` together
        with the tokenization kwargs dict that was handed to both the
        truncation validator and the processor.
        """
        # Passed to the truncation validator, which presumably fills it in
        # place -- TODO confirm against `_validate_truncation_size`.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        processor = self.input_processor
        engine_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )

        return engine_request, tokenization_kwargs

1260
1261
1262
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for the next turn of a conversation.

        Converts `messages` with `construct_input_messages` and runs the
        result through `_preprocess_chat`; only the engine prompts are
        returned, the parsed conversation is discarded.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        _conversation, next_turn_prompts = await self._preprocess_chat(
            request,
            tokenizer,
            next_turn_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return next_turn_prompts
1284

1285
    async def _generate_with_builtin_tools(
1286
        self,
1287
        request_id: str,
1288
        engine_prompt: TokensPrompt,
1289
1290
        sampling_params: SamplingParams,
        context: ConversationContext,
1291
        lora_request: LoRARequest | None = None,
1292
1293
1294
        priority: int = 0,
        **kwargs,
    ):
1295
1296
        prompt_text, _, _ = self._get_prompt_components(engine_prompt)

1297
        orig_priority = priority
1298
        sub_request = 0
1299
        while True:
1300
1301
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"
1302
            self._log_inputs(
1303
                sub_request_id,
1304
                engine_prompt,
1305
1306
1307
                params=sampling_params,
                lora_request=lora_request,
            )
1308
            trace_headers = kwargs.get("trace_headers")
1309
            engine_request, tokenization_kwargs = await self._process_inputs(
1310
                sub_request_id,
1311
1312
                engine_prompt,
                sampling_params,
1313
1314
1315
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
1316
            )
1317
1318
1319
1320

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
1321
                sub_request_id,
1322
1323
                lora_request=lora_request,
                priority=priority,
1324
1325
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
1326
1327
                **kwargs,
            )
1328

1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
1340
            context.append_tool_output(tool_output)
1341
1342
1343
1344
1345
1346

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
1347
1348
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                prompt_token_ids = context.render_for_completion()
1349
                engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
1350
            elif isinstance(context, ParsableContext):
1351
                engine_prompts = await self._render_next_turn(
1352
1353
1354
1355
1356
1357
1358
1359
1360
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
1361
                prompt_text, _, _ = self._get_prompt_components(engine_prompt)
1362

1363
            # Update the sampling params.
1364
1365
1366
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
1367
1368
            # OPTIMIZATION
            priority = orig_priority - 1
1369
            sub_request += 1
1370

1371
1372
    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
        """Split `prompt` into its components via `get_prompt_components`."""
        components = get_prompt_components(prompt)
        return components
1373

1374
1375
1376
    def _log_inputs(
        self,
        request_id: str,
1377
        inputs: PromptType,
1378
1379
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1380
1381
1382
    ) -> None:
        if self.request_logger is None:
            return
1383

1384
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1385
1386
1387
1388
1389

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1390
            prompt_embeds,
1391
1392
1393
            params=params,
            lora_request=lora_request,
        )
1394

1395
1396
1397
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to propagate, if tracing is enabled.

        When the engine reports tracing as disabled but the incoming headers
        contain trace headers, a warning is emitted and ``None`` is returned.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: tell the operator if the client tried to use it.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1409
    @staticmethod
1410
    def _base_request_id(
1411
1412
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1413
        """Pulls the request id to use from a header, if provided"""
1414
1415
1416
1417
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1418

1419
        return random_uuid() if default is None else default
1420

1421
1422
1423
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
1424
        if raw_request is None:
1425
1426
1427
1428
1429
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None
1430

1431
1432
1433
1434
1435
        try:
            return int(rank_str)
        except ValueError:
            return None

1436
1437
1438
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated `content` into function calls per `tool_choice`.

        * Forced choice (`ToolChoiceFunction` or
          `ChatCompletionNamedToolChoiceParam`): the whole content becomes
          the arguments of the named function.
        * ``"required"``: the content is validated as a JSON list of
          function definitions, one call per entry.
        * ``"auto"`` / unset, with a parser class and auto tools enabled:
          the parser extracts the calls; if none were called,
          ``(None, content)`` is returned.
        * Anything else: an empty list and the untouched content.

        Returns:
            ``(function_calls, remaining_content)``. The content is ``None``
            after forced or required calls.

        Raises:
            ValueError: if automatic parsing is needed but `tokenizer` is
                ``None``.
        """
        function_calls: list[FunctionCall] = []
        tool_choice = request.tool_choice

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(name=tool_choice.name, arguments=content)
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(
                name=tool_choice.function.name, arguments=content
            )
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice == "required":
            assert content is not None
            definitions_adapter = TypeAdapter(list[FunctionDefinition])
            required_calls = definitions_adapter.validate_json(content)
            function_calls.extend(
                FunctionCall(
                    name=required_call.name,
                    arguments=json.dumps(
                        required_call.parameters, ensure_ascii=False
                    ),
                )
                for required_call in required_calls
            )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e

            parser_input = "" if content is None else content
            tool_call_info = tool_parser.extract_tool_calls(
                parser_input,
                request=request,  # type: ignore
            )
            if tool_call_info is None or not tool_call_info.tools_called:
                # No tool calls.
                return None, content

            # extract_tool_calls() returns a list of tool calls.
            for parsed_call in tool_call_info.tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=parsed_call.function.name,
                        arguments=parsed_call.function.arguments,
                    )
                )
            content = tool_call_info.content
            # Whitespace-only leftover text is reported as no content.
            if content and content.strip() == "":
                content = None

        return function_calls, content
1511

1512
    @staticmethod
1513
1514
1515
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1516
        tokenizer: TokenizerLike | None,
1517
1518
        return_as_token_id: bool = False,
    ) -> str:
1519
1520
1521
        if return_as_token_id:
            return f"token_id:{token_id}"

1522
1523
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1524
1525
1526
1527
1528
1529

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1530
        return tokenizer.decode(token_id)
1531

1532
    def _is_model_supported(self, model_name: str | None) -> bool:
1533
1534
        if not model_name:
            return True
1535
        return self.models.is_base_model(model_name)
1536

1537
1538

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are left untouched. The same
    `prompt_logprobs` object is returned after its entries are updated.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is None:
            continue
        for entry in position_logprobs.values():
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0

    return prompt_logprobs