serving_engine.py 55.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
import torch
16
from fastapi import Request
17
from pydantic import ConfigDict, TypeAdapter
18
from starlette.datastructures import Headers
19
20
from typing_extensions import TypeIs

21
22
23
24
25
26
27
28
29
30
from vllm.entrypoints.context import (
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
    FunctionCall,
    ResponseInputOutputItem,
    ResponsesRequest,
)
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
52
from vllm.transformers_utils.tokenizer import AnyTokenizer
53

54
55
56
57
58
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

59
60
61
62
from openai.types.responses import (
    ToolChoiceFunction,
)

63
import vllm.envs as envs
64
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
65
from vllm.engine.protocol import EngineClient
66
67
68
69
70
71
72
73
74
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
75
from vllm.entrypoints.context import ConversationContext
76
from vllm.entrypoints.logger import RequestLogger
77
from vllm.entrypoints.openai.protocol import (
78
    ChatCompletionNamedToolChoiceParam,
79
80
81
82
83
84
85
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
86
    FunctionDefinition,
87
88
89
90
91
92
93
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
94
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
95
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
96
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
97
98
99
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
100
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
101
from vllm.entrypoints.utils import _validate_truncation_size
102
from vllm.inputs.data import PromptType
103
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
104
105
106
107
108
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
109
from vllm.logger import init_logger
110
from vllm.logprobs import Logprob, PromptLogprobs
111
from vllm.lora.request import LoRARequest
112
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
113
114
115
    MultiModalDataDict,
    MultiModalUUIDDict,
)
116
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
117
from vllm.pooling_params import PoolingParams
118
from vllm.reasoning import ReasoningParser, ReasoningParserManager
119
from vllm.sampling_params import BeamSearchParams, SamplingParams
120
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
121
122
123
124
125
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
126
from vllm.utils import random_uuid
127
from vllm.utils.async_utils import (
128
    AsyncMicrobatchTokenizer,
129
    collect_from_async_generator,
130
    make_async,
131
132
    merge_async_iterators,
)
133
from vllm.utils.collection_utils import is_list_of
134
from vllm.v1.engine import EngineCoreRequest
135

136
137
138
139
140
141
142
143
144

class GenerationError(Exception):
    """Signal that generation ended with an internal server error.

    Raised when a ``finish_reason`` reports an error. The instance carries
    the HTTP 500 status in ``status_code`` so callers can build a matching
    error response.
    """

    def __init__(self, message: str = "Internal server error"):
        # HTTPStatus(500) is the INTERNAL_SERVER_ERROR member.
        self.status_code = HTTPStatus(500)
        super().__init__(message)


145
146
logger = init_logger(__name__)

147
148
149
150
151
# Union of the completion-style request models accepted by the serving layer.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Union of the chat-style request models accepted by the serving layer.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)
# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
# Every request model the serving layer handles; used as the bound of RequestT.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every successful (non-error) response model a handler may return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)
184

185
186
187

class TextTokensPrompt(TypedDict):
    """Prompt carried as text together with its token ids."""

    prompt: str
    prompt_token_ids: list[int]
189
190


191
192
193
194
class EmbedsPrompt(TypedDict):
    """Prompt carried as an embeddings tensor instead of text or token ids."""

    # Tensor holding the prompt embeddings (shape is not fixed here).
    prompt_embeds: torch.Tensor


195
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
196
197
198


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when *prompt* is a dict with token ids and no embeddings.

    A prompt qualifies only if it is a ``dict`` that contains the
    ``"prompt_token_ids"`` key and does not contain ``"prompt_embeds"``.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )
204
205
206


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when *prompt* is a dict with embeddings and no token ids.

    A prompt qualifies only if it is a ``dict`` that contains the
    ``"prompt_embeds"`` key and does not contain ``"prompt_token_ids"``.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )
212

213

214
215
216
RequestT = TypeVar("RequestT", bound=AnyRequest)


217
218
@dataclass(kw_only=True)
class RequestProcessingMixin:
219
    """
220
    Mixin for request processing,
221
222
    handling prompt preparation and engine input.
    """
223

224
225
    request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
    engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
226
227


228
229
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Async stream of ``(prompt_index, output)`` pairs; unset until the
    # request has been scheduled.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Final outputs collected from ``result_generator``; fresh list per instance.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # Not annotated, so this is a plain class attribute rather than a
    # dataclass field.
    # NOTE(review): presumably consumed by pydantic to accept the
    # non-pydantic field types above — confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)


245
246
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """Per-request state passed through the serving pipeline.

    Combines the prompt fields of ``RequestProcessingMixin`` and the result
    fields of ``ResponseGenerationMixin`` with the request itself and its
    metadata. All fields are keyword-only.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Integer Unix timestamp taken when the context is constructed.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Shared across most requests
    tokenizer: TokenizerLike | None = None
257
258


259
260
261
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """``ServeContext`` specialised for classification requests; adds no fields."""

    pass
262
263


264
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """``ServeContext`` specialised for embedding requests.

    Adds the chat-template settings on top of the shared context fields.
    """

    # Optional chat template; ``None`` when not supplied.
    chat_template: str | None = None
    # Required (no default): the content format option for the chat template.
    chat_template_content_format: ChatTemplateContentFormatOption


270
class OpenAIServing:
271
272
273
274
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
275

276
277
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the shared serving dependencies and derive per-model settings.

        Args:
            engine_client: Client used to submit work to the engine.
            models: Model registry; also the source of ``input_processor``,
                ``io_processor`` and ``model_config``.
            request_logger: Optional logger for incoming requests.
            return_tokens_as_token_ids: Stored as-is on the instance.
            log_error_stack: When true, ``create_error_response`` prints a
                traceback or stack for every error response it builds.
        """
        super().__init__()

        self.engine_client = engine_client

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        # Single worker thread on which the Mistral chat template is applied.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Cache of async tokenizer wrappers, keyed by the wrapped tokenizer.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        self.input_processor = self.models.input_processor
        self.io_processor = self.models.io_processor
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len

306
    def _get_tool_parser(
307
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
308
    ) -> Callable[[TokenizerLike], ToolParser] | None:
309
310
311
312
        """Get the tool parser based on the name."""
        parser = None
        if not enable_auto_tools or tool_parser_name is None:
            return parser
313
        logger.info('"auto" tool choice has been enabled.')
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
334
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
335
336
337
338
339
340
341
342
343
344
345
        """Get the reasoning parser based on the name."""
        parser = None
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
            assert parser is not None
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        return parser

346
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's multimodal cache, then the engine's."""
        self.input_processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

350
351
352
353
354
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for one prompt using single-token generate calls.

        For up to ``params.max_tokens`` steps, every live beam is sent to the
        engine for exactly one token with ``2 * beam_width`` logprobs. All
        candidate continuations are pooled; EOS candidates are moved to the
        completed set (unless ``ignore_eos``), and the ``beam_width`` best
        remaining candidates become the next beams.

        Yields:
            One finished ``RequestOutput`` holding up to ``beam_width``
            outputs. If any step reports ``finish_reason == "error"``, a
            single error ``RequestOutput`` is yielded instead and the search
            stops.

        Raises:
            ValueError: If the input processor has no tokenizer.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init=True`"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        # Number of prompt tokens; used to slice the generated suffix out of
        # each beam's token list.
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        EngineTokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index of the beam each candidate came from. A result is not
            # guaranteed to hold exactly `logprobs_num` entries (presumably
            # the engine may add the sampled token on top of the top-k —
            # confirm against SamplingParams.logprobs), and results without
            # logprobs are skipped, so `idx // logprobs_num` is not reliable.
            all_beams_owner: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    all_beams_owner.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = all_beams_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
                :beam_width
            ]

            for idx in topn_idx:
                owner = all_beams_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        # Beams still alive after the last step compete with the finished ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
586

587
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """
        Get a Renderer instance with the provided tokenizer.
        Uses shared async tokenizer pool for efficiency.

        A new ``CompletionRenderer`` is built on every call; only the
        ``_async_tokenizer_pool`` dict is shared between them.
        """
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
597

598
599
600
601
602
603
604
605
606
607
608
609
610
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.
        """
        raise NotImplementedError

611
612
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
613
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
614
615
616
617
618
619
620
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
621

622
623
624
    async def _preprocess(
        self,
        ctx: ServeContext,
625
    ) -> ErrorResponse | None:
626
627
628
629
630
631
632
633
634
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
635
    ) -> AnyResponse | ErrorResponse:
636
637
638
639
640
641
642
643
644
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
645
646
    ) -> AnyResponse | ErrorResponse:
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
647
648
649
650
651
652
653
654
655
656
        generation = self._pipeline(ctx)

        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
657
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

678
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
679
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
680

681
682
683
684
        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
685
686
687
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
688
689
                " Please, select a smaller truncation size."
            )
690
691
        return None

692
693
694
    def _create_pooling_params(
        self,
        ctx: ServeContext,
695
    ) -> PoolingParams | ErrorResponse:
696
697
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
698
699
                "Request type does not support pooling parameters"
            )
700
701
702

        return ctx.request.to_pooling_params()

703
704
705
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        One ``engine_client.encode`` call is issued per engine prompt, with
        request id ``"<ctx.request_id>-<index>"``. The per-prompt generators
        are merged and stored on ``ctx.result_generator``.

        Returns:
            ``None`` on success, or an error response when pooling params
            cannot be built, engine prompts are missing, or any exception is
            raised while scheduling.
        """
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    # Requests without a ``priority`` attribute use 0.
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
758
    ) -> ErrorResponse | None:
759
760
761
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
762
                return self.create_error_response("Engine prompts not available")
763
764

            num_prompts = len(ctx.engine_prompts)
765
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
766
767
768
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
769
                return self.create_error_response("Result generator not available")
770
771
772
773
774
775

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
776
777
                    "Failed to generate results for all prompts"
                )
778

779
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]
780
781
782
783
784
785

            return None

        except Exception as e:
            return self.create_error_response(str(e))

786
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` for the given message, type and status.

        When ``self.log_error_stack`` is set, the traceback of the exception
        currently being handled is printed; if no exception is active, the
        current call stack is printed instead.

        Args:
            message: Error text placed in the response.
            err_type: Error type string placed in the response.
            status_code: HTTP status whose numeric value becomes the code.
        """
        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(
            error=ErrorInfo(message=message, type=err_type, code=status_code.value)
        )
801

802
    def create_streaming_error_response(
803
804
805
806
807
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
808
        json_str = json.dumps(
809
810
811
812
            self.create_error_response(
                message=message, err_type=err_type, status_code=status_code
            ).model_dump()
        )
813
814
        return json_str

815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Build an "InternalServerError" `ErrorResponse` from `e`.

        The message is `str(e)` and the status code is taken from
        `e.status_code`.
        """
        error_message = str(e)
        return self.create_error_response(
            error_message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Build the JSON-encoded "InternalServerError" payload for `e`.

        The message is `str(e)` and the status code is taken from
        `e.status_code`; serialization is delegated to
        `create_streaming_error_response`.
        """
        error_message = str(e)
        return self.create_streaming_error_response(
            error_message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

844
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Verify that `request.model` can be served.

        Returns:
            `None` when the model is supported, is a known LoRA adapter, or
            is resolved to a `LoRARequest` at runtime. Otherwise an
            `ErrorResponse`: the resolver's own response when it reported a
            400, else a 404 "NotFoundError".
        """
        if (
            self._is_model_supported(request.model)
            or request.model in self.models.lora_requests
        ):
            return None

        resolver_error = None

        # Runtime resolution is only attempted when the env flag allows it
        # and the request actually names a model.
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model:
            load_result = await self.models.resolve_lora(request.model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                # Only a BAD_REQUEST from the resolver is surfaced verbatim;
                # anything else falls through to the generic 404 below.
                if (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                ):
                    resolver_error = load_result

        return resolver_error or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
872

873
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the single registered LoRA matching the request's content types.

        A registered LoRA matches when its `lora_name` equals one of the
        message types returned by `_get_message_types` (e.g. `audio` for an
        `audio_url` part). `None` is returned unless exactly one matches.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best-effort match for default multimodal LoRA adapters: compare
        # adapter names against the content-part type prefixes.
        matching_loras = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Default modality-specific LoRAs are only supported when exactly
        # one adapter matches the request.
        return matching_loras.pop() if len(matching_loras) == 1 else None

896
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter (if any) to use for `request`.

        Resolution order: an adapter registered under `request.model`, then
        (when `supports_default_mm_loras` is set) the single default
        multimodal adapter matching the request, then no adapter for a
        supported base model.

        Raises:
            ValueError: If none of the above applies.
        """
        registered_loras = self.models.lora_requests
        if request.model in registered_loras:
            return registered_loras[request.model]

        # Default modality-specific LoRAs are only supported when exactly
        # one adapter matches the request.
        if (
            supports_default_mm_loras
            and (default_mm_lora := self._get_active_default_mm_loras(request))
            is not None
        ):
            return default_mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
916

917
918
919
920
921
922
923
924
925
926
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-part type prefixes used in the request messages.

        For every dict message whose `content` is a list, each dict part
        with a string `type` contributes the text before the first `_`
        (e.g. `audio_url` -> `audio`). The result is used to match
        multimodal data against default per-modality LoRAs.

        Requests without `messages`, or whose `messages` is `None`, a `str`
        or `bytes`, yield an empty set. Parts that are not dicts, or whose
        `type` is not a string, are ignored.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue

            content = message.get("content")
            if not isinstance(content, list):
                continue

            for content_part in content:
                # A non-dict part (e.g. a bare string) must be skipped:
                # `"type" in "some type text"` is a substring test that can
                # succeed, after which indexing the string by "type" raises
                # TypeError.
                if not isinstance(content_part, dict):
                    continue

                part_type = content_part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])

        return message_types

942
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: TokenizerLike,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate the result.

        The prompt is lower-cased first when the model's encoder config sets
        `do_lower_case`. Truncation follows the request's
        `truncate_prompt_tokens`: `None` disables it, a negative value caps
        at `self.max_model_len`, any other value is used as the cap.

        Returns:
            The `TextTokensPrompt` produced by `_validate_input`.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        tokenize_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        if truncate_prompt_tokens is not None:
            tokenize_kwargs["truncation"] = True
            # Negative means we cap at the model's max length
            tokenize_kwargs["max_length"] = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )

        encoded = await async_tokenizer(prompt, **tokenize_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

984
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: TokenizerLike | None,
    ) -> TextTokensPrompt:
        """Left-truncate pre-tokenized input, decode it, and validate it.

        Truncation follows the request's `truncate_prompt_tokens`:
        `None` keeps every token, a negative value keeps the last
        `self.max_model_len` tokens, `0` keeps no tokens, and a positive
        `k` keeps the last `k` tokens.

        When `tokenizer` is `None` the prompt text is left empty; otherwise
        the kept ids are decoded through the async tokenizer.

        Returns:
            The `TextTokensPrompt` produced by `_validate_input`.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len :]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` is `prompt_ids[0:]`, i.e. the *whole* prompt,
            # so a limit of zero must be handled explicitly to actually
            # keep zero tokens.
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the tokenized prompt against the model context length.

        Three families of requests are distinguished:

        * embedding / score / rerank / classification requests may use up to
          `self.max_model_len` input tokens;
        * tokenize / detokenize requests are not length-checked;
        * every other request must leave room for generation: the input must
          be shorter than `self.max_model_len`, and input plus `max_tokens`
          (when set) must not exceed it.

        Returns:
            A `TextTokensPrompt` with `input_text` and `input_ids`.

        Raises:
            ValueError: If the applicable length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= self.max_model_len:
                return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

            # The operation name is chosen by exact type, not isinstance;
            # anything not listed is reported as embedding generation.
            request_type = type(request)
            if request_type is ScoreRequest:
                operation = "score"
            elif request_type in (
                ClassificationCompletionRequest,
                ClassificationChatRequest,
            ):
                operation = "classification"
            else:
                operation = "embedding generation"

            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input."
            )

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenizer_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_request_types):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if not isinstance(request, ChatCompletionRequest):
            max_tokens = getattr(request, "max_tokens", None)
        else:
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        exceeds_context = (
            max_tokens is not None and token_num + max_tokens > self.max_model_len
        )
        if exceeds_context:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

1081
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_input: str | list[int],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """Tokenize one prompt input (text or token ids).

        Wraps `prompt_input` in a one-element list, delegates to
        `_tokenize_prompt_inputs_async`, and returns the first result.

        Raises:
            ValueError: If the underlying generator yields nothing.
        """
        tokenized_prompts = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for tokenized in tokenized_prompts:
            # Only the first item is needed.
            return tokenized

        raise ValueError("No results yielded from tokenization")
1099

1100
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """Tokenize several prompt inputs, yielding one result per input.

        String inputs go through `_normalize_prompt_text_to_input`; anything
        else is treated as token ids and goes through
        `_normalize_prompt_tokens_to_input`. `add_special_tokens` only
        applies to string inputs.
        """
        for prompt_input in prompt_inputs:
            if isinstance(prompt_input, str):
                tokenized = await self._normalize_prompt_text_to_input(
                    request,
                    prompt=prompt_input,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            else:
                tokenized = await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt_input,
                    tokenizer=tokenizer,
                )
            yield tokenized

1125
1126
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as request-supplied when `request_chat_template`
        is set, or when `chat_template_kwargs` carries a non-`None`
        `chat_template` entry.

        Returns:
            `None` when the request is acceptable, otherwise an
            `ErrorResponse` explaining the refusal.
        """
        if trust_request_chat_template:
            return None

        if request_chat_template is None:
            # No explicit template: the kwargs are the only other place one
            # could be smuggled in.
            if not chat_template_kwargs:
                return None
            if chat_template_kwargs.get("chat_template") is None:
                return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

1145
1146
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a prompt and build the engine input.

        Steps, in order: resolve the chat-template content format, parse the
        messages (multimodal data is awaited afterwards), apply the chat
        template appropriate for the tokenizer type, optionally let the tool
        parser adjust the request, tokenize the rendered prompt, and assemble
        an `EngineTokensPrompt` carrying any multimodal data, multimodal
        UUIDs, `mm_processor_kwargs` and `cache_salt`.

        Entries in `chat_template_kwargs` override the template arguments
        derived from the explicit parameters.

        Returns:
            A tuple `(conversation, [request_prompt], [engine_prompt])`.

        Raises:
            NotImplementedError: If tool parsing applies but `request` is
                neither a `ChatCompletionRequest` nor a `ResponsesRequest`.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            # No tokenizer: nothing can be rendered, use a stand-in prompt.
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        elif isinstance(tokenizer, DeepseekV32Tokenizer):
            request_prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                messages=messages,
                model_config=model_config,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # The message is a single string; a comma between the two
            # literals would silently turn it into a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

1273
1274
1275
1276
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Run the input processor on a prompt for AsyncLLM.

        Returns:
            The processed engine request together with the tokenization
            kwargs dict that was handed to `_validate_truncation_size` and
            then to the input processor.
        """
        tok_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        processed_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return processed_request, tok_kwargs

1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: AnyTokenizer,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the prompts for the next turn of a Responses conversation.

        Converts `messages` with `construct_input_messages`, runs them
        through `_preprocess_chat`, and returns the resulting
        `(request_prompts, engine_prompts)` pair; the parsed conversation is
        discarded.
        """
        _conversation, next_request_prompts, next_engine_prompts = (
            await self._preprocess_chat(
                request,
                tokenizer,
                construct_input_messages(request_input=messages),
                tool_dicts=tool_dicts,
                tool_parser=tool_parser,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
            )
        )
        return next_request_prompts, next_engine_prompts

1325
1326
1327
1328
1329
1330
1331
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate, running built-in tool calls between turns.

        Each turn is submitted to the engine as its own sub-request
        (`"{request_id}_{n}"`). Every engine output is appended to `context`,
        which is then yielded. When the context asks for a built-in tool
        call, the tool is invoked, its output is appended, the next prompt is
        re-rendered from the context, and another turn starts; otherwise the
        generator finishes.

        Note that `sampling_params.max_tokens` is mutated between turns.
        """
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        trace_headers = kwargs.get("trace_headers")
        # Turns after the first run with the original priority value minus
        # one (marked as an optimization in the original code).
        followup_priority = priority - 1
        turn_index = 0

        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{turn_index}"

            self._log_inputs(
                sub_request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            output_stream = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for engine_output in output_stream:
                context.append_output(engine_output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and update the context with the result.
            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = EngineTokensPrompt(prompt_token_ids=next_token_ids)
                request_prompt = next_token_ids
            elif isinstance(context, ParsableContext):
                request_prompts, engine_prompts = await self._render_next_turn(
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
                request_prompt = request_prompts[0]
                prompt_text, _, _ = self._get_prompt_components(request_prompt)

            # Update the sampling params: the next turn may generate up to
            # whatever is left of the context window after the new prompt.
            remaining_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
            sampling_params.max_tokens = remaining_tokens

            # OPTIMIZATION
            priority = followup_priority
            turn_index += 1
1412

1413
1414
    def _get_prompt_components(
        self,
        prompt: RequestPrompt | PromptType,
    ) -> PromptComponents:
        """Split a prompt into its `PromptComponents`.

        A list is treated as token ids and wrapped directly; any other
        prompt is handed to `get_prompt_components`.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]

        return PromptComponents(token_ids=prompt)
1421

1422
1423
1424
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt | PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward the request's prompt components to the request logger.

        Does nothing when no request logger is configured.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        components = self._get_prompt_components(inputs)
        prompt_text, token_ids, embeds = components

        request_logger.log_inputs(
            request_id,
            prompt_text,
            token_ids,
            embeds,
            params=params,
            lora_request=lora_request,
        )
1442

1443
1444
1445
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when the engine has tracing enabled.

        With tracing disabled, `None` is returned; if the request still
        carries trace headers, the tracing-disabled warning is logged.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1457
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Choose the request id: header first, then `default`, then random.

        The `X-Request-Id` header of `raw_request` wins when present. Failing
        that, `default` is used if given, otherwise a fresh `random_uuid()`.
        """
        if raw_request is not None:
            header_request_id = raw_request.headers.get("X-Request-Id")
            if header_request_id is not None:
                return header_request_id

        if default is not None:
            return default
        return random_uuid()
1468

1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the data parallel rank from the `X-data-parallel-rank` header.

        Returns `None` when there is no request, the header is absent, or
        its value cannot be parsed by `int()`.
        """
        rank_header = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if rank_header is None:
            return None

        try:
            rank = int(rank_header)
        except ValueError:
            # A malformed header is treated the same as a missing one.
            return None
        return rank

1484
1485
1486
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn model output into function calls according to `tool_choice`.

        * A forced tool (`ToolChoiceFunction` or
          `ChatCompletionNamedToolChoiceParam`): the whole `content` becomes
          the arguments of that one function.
        * `"required"`: `content` is validated as a JSON list of function
          definitions, each becoming one call.
        * `"auto"` / unset, with a tool parser and auto tools enabled: the
          parser extracts the calls; if it finds none, `(None, content)` is
          returned.
        * Anything else: an empty call list and the untouched `content`.

        Returns:
            `(function_calls, remaining_content)`. The remaining content is
            `None` whenever the content was consumed by a tool call.
        """
        tool_choice = request.tool_choice

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call; content is cleared since tool is called.
            forced_call = FunctionCall(name=tool_choice.name, arguments=content)
            return [forced_call], None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            # Forced Function Call; content is cleared since tool is called.
            forced_call = FunctionCall(
                name=tool_choice.function.name, arguments=content
            )
            return [forced_call], None

        if tool_choice == "required":
            assert content is not None
            required_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                content
            )
            function_calls = [
                FunctionCall(
                    name=required_call.name,
                    arguments=json.dumps(required_call.parameters, ensure_ascii=False),
                )
                for required_call in required_calls
            ]
            # Content is cleared since tool is called.
            return function_calls, None

        auto_parsing_enabled = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_parsing_enabled:
            return [], content

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in tool parser creation.")
            raise e

        tool_call_info = tool_parser.extract_tool_calls(
            "" if content is None else content,
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        function_calls = [
            FunctionCall(
                name=parsed_call.function.name,
                arguments=parsed_call.function.arguments,
            )
            for parsed_call in tool_call_info.tool_calls
        ]

        remaining_content = tool_call_info.content
        # Whitespace-only leftovers are reported as no content at all.
        if remaining_content and remaining_content.strip() == "":
            remaining_content = None

        return function_calls, remaining_content

1555
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the textual form of a token for logprob output.

        With `return_as_token_id` the result is `"token_id:<id>"`. Otherwise
        the token already decoded on `logprob` is preferred, falling back to
        decoding `token_id` with `tokenizer`.

        Raises:
            ValueError: If decoding is needed but `tokenizer` is `None`.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        already_decoded = logprob.decoded_token
        if already_decoded is not None:
            return already_decoded

        if tokenizer is not None:
            return tokenizer.decode(token_id)

        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
1574

1575
    def _is_model_supported(self, model_name: str | None) -> bool:
1576
1577
        if not model_name:
            return True
1578
        return self.models.is_base_model(model_name)
1579

1580
1581

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace `-inf` prompt logprobs with `-9999.0`, in place.

    `None` input and `None` entries are left untouched. The same object
    that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is None:
            continue
        for entry in position_logprobs.values():
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0

    return prompt_logprobs