# serving.py (57.1 KB) -- header text left over from the web file viewer.
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
from fastapi import Request
16
17
18
from openai.types.responses import (
    ToolChoiceFunction,
)
19
20
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
21

22
import vllm.envs as envs
23
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
30
31
32
33
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
34
35
36
37
38
39
from vllm.entrypoints.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
40
from vllm.entrypoints.logger import RequestLogger
41
from vllm.entrypoints.openai.chat_completion.protocol import (
42
    ChatCompletionNamedToolChoiceParam,
43
44
    ChatCompletionRequest,
    ChatCompletionResponse,
45
46
)
from vllm.entrypoints.openai.engine.protocol import (
47
48
49
50
    CompletionRequest,
    CompletionResponse,
    ErrorInfo,
    ErrorResponse,
51
    FunctionCall,
52
    FunctionDefinition,
53
    VLLMValidationError,
54
)
55
56
57
58
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
59
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
60
61
62
63
64
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
86
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
87
88
89
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
)
90
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
91
92
93
94
95
96
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
97
from vllm.entrypoints.utils import _validate_truncation_size
98
from vllm.inputs.data import PromptType, TokensPrompt
99
100
101
102
103
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
104
from vllm.logger import init_logger
105
from vllm.logprobs import Logprob, PromptLogprobs
106
from vllm.lora.request import LoRARequest
107
from vllm.multimodal import MultiModalDataDict
108
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
109
from vllm.pooling_params import PoolingParams
110
from vllm.reasoning import ReasoningParser, ReasoningParserManager
111
from vllm.sampling_params import BeamSearchParams, SamplingParams
112
from vllm.tokenizers import TokenizerLike
113
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
114
from vllm.tokenizers.mistral import MistralTokenizer
115
from vllm.tool_parsers import ToolParser, ToolParserManager
116
117
118
119
120
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
121
from vllm.utils import random_uuid
122
from vllm.utils.async_utils import (
123
    AsyncMicrobatchTokenizer,
124
    collect_from_async_generator,
125
    make_async,
126
127
    merge_async_iterators,
)
128
from vllm.utils.collection_utils import is_list_of
129
from vllm.v1.engine import EngineCoreRequest
130

131
132
133
134
135
136
137
138
139

class GenerationError(Exception):
    """Signals that generation ended with an internal server error (HTTP 500).

    Every instance carries ``status_code`` set to
    ``HTTPStatus.INTERNAL_SERVER_ERROR``.
    """

    def __init__(self, message: str = "Internal server error"):
        # Record the HTTP status first, then let Exception store the message.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


140
141
# Module-level logger, obtained from `init_logger` using this module's name.
logger = init_logger(__name__)

142
143
144
145
146
# Union of the request models grouped here as "completion-like".
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Union of the request models grouped here as "chat-like".
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)
# Union of the transcription and translation request models.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
# Union of every request model referenced by this module's type hints.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Union of the (non-error) response models referenced by this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)


# Type variable for request-generic classes; restricted to `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


184
185
@dataclass(kw_only=True)
class RequestProcessingMixin:
186
    """
187
    Mixin for request processing,
188
189
    handling prompt preparation and engine input.
    """
190

191
    engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
192
193


194
195
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Async stream of (index, output) pairs; None until one is assigned.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Collected outputs; defaults to a fresh empty list per instance.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # Has no annotation, so the dataclass decorator leaves it as a plain class
    # attribute instead of turning it into a field.
    model_config = ConfigDict(arbitrary_types_allowed=True)


211
212
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """Keyword-only dataclass bundling the state of one request.

    Combines the fields of `RequestProcessingMixin` and
    `ResponseGenerationMixin` with the request itself and its metadata.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Whole-second timestamp taken when the context object is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Shared across most requests
    tokenizer: TokenizerLike | None = None
223
224


225
226
227
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """`ServeContext` specialised to `ClassificationRequest`; adds no fields."""
228
229


230
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` specialised to `EmbeddingRequest` with chat-template fields."""

    # Optional chat template; defaults to None.
    chat_template: str | None = None
    # Required (no default); must be passed by keyword.
    chat_template_content_format: ChatTemplateContentFormatOption


236
class OpenAIServing:
237
238
239
240
    # Class-level request-ID prefix. NOTE(review): the value assigned here is
    # the explanatory text itself; presumably each endpoint subclass overrides
    # it with a short tag -- confirm against the subclasses.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
241

242
243
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the shared serving dependencies on the instance.

        Args:
            engine_client: Engine client used to issue requests.
            models: Model registry; its `input_processor`, `io_processor`
                and `model_config` are copied onto this instance.
            request_logger: Optional request logger (may be None).
            return_tokens_as_token_ids: Stored as-is for later use.
            log_error_stack: When True, `create_error_response` prints a
                traceback or stack trace.
        """
        super().__init__()

        self.engine_client = engine_client

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        # Single-worker executor handed to the async chat-template wrapper.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Cache of async tokenizer wrappers, keyed by the tokenizer object.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        self.input_processor = self.models.input_processor
        self.io_processor = self.models.io_processor
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len

272
    def _get_tool_parser(
273
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
274
    ) -> Callable[[TokenizerLike], ToolParser] | None:
275
276
277
278
        """Get the tool parser based on the name."""
        parser = None
        if not enable_auto_tools or tool_parser_name is None:
            return parser
279
        logger.info('"auto" tool choice has been enabled.')
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
300
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
301
302
303
304
305
306
307
308
309
310
311
        """Get the reasoning parser based on the name."""
        parser = None
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
            assert parser is not None
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        return parser

312
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's MM cache, then reset the engine's."""
        self.input_processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

316
317
318
319
320
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by requesting one token per live beam per step.

        Each step sends every live beam to `self.engine_client.generate`
        with `max_tokens=1` and `2 * beam_width` logprobs, pools all
        candidate continuations, moves EOS candidates to the completed set
        (unless `ignore_eos`), and keeps the `beam_width` best remaining
        candidates as the next beams.

        Args:
            prompt: A string, or a mapping read via the keys "prompt",
                "prompt_token_ids" and "multi_modal_data".
            request_id: Base ID; per-step, per-beam IDs are derived from it.
            params: Beam search settings (beam width, max tokens, ...).
            lora_request: Optional LoRA request attached to every beam.
            trace_headers: Optional headers forwarded to the engine.

        Yields:
            Exactly one `RequestOutput`: the best beams, or an output whose
            finish reason is "error" if any engine result reported an error.

        Raises:
            VLLMValidationError: If the input processor has no tokenizer.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain string is not tokenized here, so the
            # search starts from an empty token list -- confirm that callers
            # always pass pre-tokenized prompts.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # For every candidate, the index of the beam that produced it.
            # Recorded explicitly so the mapping stays right even when a
            # result contributes a number of entries other than
            # `logprobs_num` (or none at all).
            all_beams_owner: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    all_beams_owner.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = all_beams_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            num_candidates = len(all_beams_logprob)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                # argpartition requires kth < size; with this few candidates
                # every one of them is kept.
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                owner = all_beams_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams
            if not all_beams:
                # Nothing left to extend; the next step would have no prompts.
                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
554

555
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """
        Get a Renderer instance with the provided tokenizer.
        Uses shared async tokenizer pool for efficiency.

        A new `CompletionRenderer` is constructed on every call; only the
        pool in `self._async_tokenizer_pool` is shared between them.
        """
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
565

566
567
568
569
570
571
572
573
574
575
576
577
578
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.
        """
        raise NotImplementedError

579
580
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
581
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
582
583
584
585
586
587
588
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
589

590
591
592
    async def _preprocess(
        self,
        ctx: ServeContext,
593
    ) -> ErrorResponse | None:
594
595
596
597
598
599
600
601
602
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
603
    ) -> AnyResponse | ErrorResponse:
604
605
606
607
608
609
610
611
612
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the pipeline for `ctx` and return the first response it yields.

        Returns:
            The first item yielded by `self._pipeline(ctx)`, or an error
            response if the pipeline yields nothing.
        """
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Only the first item is consumed, so close the generator here
            # instead of leaving its cleanup to garbage collection.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
625
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

646
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
647
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
648

649
650
651
652
        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
653
654
655
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
656
657
                " Please, select a smaller truncation size."
            )
658
659
        return None

660
661
662
    def _create_pooling_params(
        self,
        ctx: ServeContext,
663
    ) -> PoolingParams | ErrorResponse:
664
665
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
666
667
                "Request type does not support pooling parameters"
            )
668
669
670

        return ctx.request.to_pooling_params()

671
672
673
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        Issues one `engine_client.encode` call per entry in
        `ctx.engine_prompts` and stores the merged stream on
        `ctx.result_generator`.

        Returns:
            None on success; an error response when pooling parameters or
            engine prompts are unavailable, or when any exception is raised.
        """
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own ID derived from the request ID.
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    # Requests without a `priority` attribute use 0.
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            return self.create_error_response(e)
721
722
723
724

    async def _collect_batch(
        self,
        ctx: ServeContext,
725
    ) -> ErrorResponse | None:
726
727
728
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
729
                return self.create_error_response("Engine prompts not available")
730
731

            num_prompts = len(ctx.engine_prompts)
732
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
733
734
735
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
736
                return self.create_error_response("Result generator not available")
737
738
739
740
741
742

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
743
744
                    "Failed to generate results for all prompts"
                )
745

746
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]
747
748
749
750

            return None

        except Exception as e:
751
            return self.create_error_response(e)
752

753
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message or an exception.

        Args:
            message: Error text, or an exception to classify and stringify.
            err_type: Error type name; ignored when `message` is an
                exception (the type is derived from the exception instead).
            status_code: HTTP status; likewise overridden for exceptions.
            param: Offending parameter name; likewise overridden for
                exceptions.

        Returns:
            An `ErrorResponse` wrapping an `ErrorInfo` with the final
            message, type, numeric status code and parameter.
        """
        exc: Exception | None = None

        if isinstance(message, Exception):
            exc = message

            from vllm.exceptions import VLLMValidationError

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            # Print the active exception's traceback if there is one;
            # otherwise print the current call stack.
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(
            error=ErrorInfo(
                message=message,
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
802

803
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Build an error response and serialize it to a JSON string.

        Takes the same arguments as `create_error_response`, forwards them
        unchanged, and returns `json.dumps` of the response's `model_dump()`.
        """
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise GenerationError if finish_reason indicates an error."""
        if finish_reason == "error":
            logger.error(
                "Request %s failed with an internal error during generation",
                request_id,
            )
            raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Convert GenerationError to ErrorResponse."""
        return self.create_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Convert GenerationError to streaming error response."""
        return self.create_streaming_error_response(
            str(e),
            err_type="InternalServerError",
            status_code=e.status_code,
        )

849
    async def _check_model(
850
851
        self,
        request: AnyRequest,
852
    ) -> ErrorResponse | None:
853
854
        error_response = None

855
        if self._is_model_supported(request.model):
856
            return None
857
        if request.model in self.models.lora_requests:
858
            return None
859
860
861
862
863
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
864
865
            if isinstance(load_result, LoRARequest):
                return None
866
867
868
869
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
870
871
872
                error_response = load_result

        return error_response or self.create_error_response(
873
874
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
875
            status_code=HTTPStatus.NOT_FOUND,
876
            param="model",
877
        )
878

879
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

902
    def _maybe_get_adapters(
903
904
905
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
906
    ) -> LoRARequest | None:
907
        if request.model in self.models.lora_requests:
908
            return self.models.lora_requests[request.model]
909
910
911
912
913
914

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
915
                return default_mm_lora
916
917

        if self._is_model_supported(request.model):
918
            return None
919

920
        # if _check_model has been called earlier, this will be unreachable
921
        raise ValueError(f"The model `{request.model}` does not exist.")
922

923
924
925
926
927
928
929
930
931
932
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

933
934
935
936
937
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
938
939
940
941
942
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
943
944
945
946
947
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

948
    async def _normalize_prompt_text_to_input(
949
950
951
        self,
        request: AnyRequest,
        prompt: str,
952
        tokenizer: TokenizerLike,
953
        add_special_tokens: bool,
954
    ) -> TokensPrompt:
955
956
        async_tokenizer = self._get_async_tokenizer(tokenizer)

957
958
959
960
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
961
962
            prompt = prompt.lower()

963
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
964

965
        if truncate_prompt_tokens is None:
966
            encoded = await async_tokenizer(
967
968
                prompt, add_special_tokens=add_special_tokens
            )
969
970
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
971
972
973
974
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
975
976
                max_length=self.max_model_len,
            )
977
        else:
978
979
980
981
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
982
983
                max_length=truncate_prompt_tokens,
            )
984
985
986
987
988
989

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

990
    async def _normalize_prompt_tokens_to_input(
991
992
        self,
        request: AnyRequest,
993
        prompt_ids: list[int],
994
        tokenizer: TokenizerLike | None,
995
    ) -> TokensPrompt:
996
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
997

998
        if truncate_prompt_tokens is None:
999
            input_ids = prompt_ids
1000
        elif truncate_prompt_tokens < 0:
1001
            input_ids = prompt_ids[-self.max_model_len :]
1002
1003
1004
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

1005
1006
1007
1008
1009
        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
1010

1011
1012
1013
1014
1015
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against the model context and wrap it.

        Embedding, score, rerank and classification requests may use the
        whole context length. Tokenize/detokenize requests are not length
        checked. All other requests must leave room for at least one output
        token and, when a token budget is requested, for that budget.

        Returns:
            A TokensPrompt carrying ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: if the input, or the input plus the
                requested output tokens, exceeds ``self.max_model_len``.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= self.max_model_len:
                return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

            # Exact-type lookup; anything not listed is reported as an
            # embedding request.
            operations: dict[type[AnyRequest], str] = {
                ScoreRequest: "score",
                ClassificationCompletionRequest: "classification",
                ClassificationChatRequest: "classification",
            }
            operation = operations.get(type(request), "embedding generation")
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input.",
                parameter="input_tokens",
                value=token_num,
            )

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        tokenizer_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_request_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        budget_fits = (
            max_tokens is None or token_num + max_tokens <= self.max_model_len
        )
        if not budget_fits:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
1092

1093
    async def _tokenize_prompt_input_async(
1094
1095
        self,
        request: AnyRequest,
1096
        tokenizer: TokenizerLike,
1097
        prompt_input: str | list[int],
1098
        add_special_tokens: bool = True,
1099
    ) -> TokensPrompt:
1100
        """
1101
        A simpler implementation that tokenizes a single prompt input.
1102
        """
1103
        async for result in self._tokenize_prompt_inputs_async(
1104
1105
            request,
            tokenizer,
1106
            [prompt_input],
1107
            add_special_tokens=add_special_tokens,
1108
1109
1110
        ):
            return result
        raise ValueError("No results yielded from tokenization")
1111

1112
    async def _tokenize_prompt_inputs_async(
1113
1114
        self,
        request: AnyRequest,
1115
        tokenizer: TokenizerLike,
1116
        prompt_inputs: Iterable[str | list[int]],
1117
        add_special_tokens: bool = True,
1118
    ) -> AsyncGenerator[TokensPrompt, None]:
1119
        """
1120
        A simpler implementation that tokenizes multiple prompt inputs.
1121
        """
1122
1123
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
1124
                yield await self._normalize_prompt_text_to_input(
1125
                    request,
1126
1127
                    prompt=prompt,
                    tokenizer=tokenizer,
1128
1129
1130
                    add_special_tokens=add_special_tokens,
                )
            else:
1131
                yield await self._normalize_prompt_tokens_to_input(
1132
                    request,
1133
1134
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
1135
1136
                )

1137
1138
    def _validate_chat_template(
        self,
1139
1140
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
1141
        trust_request_chat_template: bool,
1142
    ) -> ErrorResponse | None:
1143
        if not trust_request_chat_template and (
1144
1145
1146
1147
1148
1149
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
1150
1151
1152
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
1153
1154
                "Refused request with untrusted chat template."
            )
1155
1156
        return None

1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1169
1170
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a single engine prompt.

        Steps: resolve the content format, parse the messages (multimodal
        data is awaited later), apply the chat template that matches the
        tokenizer type, optionally let the tool parser adjust the request,
        tokenize the rendered prompt, and attach multimodal data, uuids,
        processor kwargs and cache salt to the engine prompt.

        Args:
            request: The request being served; may be replaced by the tool
                parser's adjusted copy for the remaining steps.
            tokenizer: Tokenizer to render/tokenize with. When ``None`` a
                placeholder prompt with a dummy token id is produced.
            messages: Chat messages to render.
            chat_template: Template override passed to the renderer.
            chat_template_content_format: Requested content format.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Tool definitions forwarded as ``tools``.
            documents: Documents forwarded to the chat template.
            chat_template_kwargs: Request-level template kwargs; these
                override both the defaults and the explicit arguments above.
            default_chat_template_kwargs: Server-level template kwargs.
            tool_parser: Factory for the tool parser; used only when the
                request has a ``tool_choice`` other than ``"none"``.
            add_special_tokens: Forwarded to tokenization of text prompts.

        Returns:
            The parsed conversation and a one-element list with the engine
            prompt.

        Raises:
            NotImplementedError: if tool parsing is requested for a request
                that is neither a chat completion nor a responses request.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Merged last, so extra kwargs win over the explicit arguments.
        _chat_template_kwargs |= self._prepare_extra_chat_template_kwargs(
            chat_template_kwargs,
            default_chat_template_kwargs,
        )

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        elif isinstance(tokenizer, DeepseekV32Tokenizer):
            request_prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                messages=messages,
                model_config=model_config,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if tokenizer is None:
            # One implicitly concatenated string: a comma between the parts
            # would turn the assertion message into a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised"
            )
            prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if "prompt" in prompt_inputs:
            engine_prompt["prompt"] = prompt_inputs["prompt"]

        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        # Read defensively, like cache_salt below: not every request type is
        # guaranteed here to define the attribute.
        mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None)
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [engine_prompt]
1295

1296
1297
1298
1299
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
        data_parallel_rank: int | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        Returns the engine request built by ``self.input_processor`` together
        with the tokenization kwargs that were passed to it.
        """
        # Handed to the truncation check first (presumably filled in place
        # there), then forwarded to the processor and returned to the caller.
        truncation_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            truncation_kwargs,
        )

        processed_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=truncation_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
        return processed_request, truncation_kwargs

1325
1326
1327
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for the next turn of a conversation.

        The response items are converted to input messages with
        ``construct_input_messages`` and run through ``_preprocess_chat``;
        only the engine prompts of its result are returned.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        _conversation, next_turn_prompts = await self._preprocess_chat(
            request,
            tokenizer,
            next_turn_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return next_turn_prompts
1349

1350
1351
1352
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Run generation, looping while the context requests built-in tools.

        This is an async generator. Each loop iteration submits one
        sub-request (id ``{request_id}_{n}``) to the engine, appends every
        engine output to ``context`` and yields ``context`` after each one.
        When the context reports that a built-in tool call is needed, the
        tool is called, its output is appended to the context, a new engine
        prompt is rendered from the context, and the loop continues;
        otherwise the generator ends.

        Note that ``sampling_params.max_tokens`` is overwritten on the
        passed-in object before every follow-up turn.

        Args:
            request_id: Base id; sub-request ids are derived from it.
            engine_prompt: Prompt for the first turn.
            sampling_params: Sampling parameters, mutated between turns.
            context: Conversation state; receives outputs and tool results.
            lora_request: Optional LoRA adapter applied to every turn.
            priority: Priority of the first turn.
            **kwargs: Forwarded to ``engine_client.generate``; the
                ``trace_headers`` entry, if present, is also passed to
                ``_process_inputs``.
        """
        prompt_text, _, _ = self._get_prompt_components(engine_prompt)

        orig_priority = priority
        sub_request = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"
            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            trace_headers = kwargs.get("trace_headers")
            engine_request, tokenization_kwargs = await self._process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                # NOTE(review): prompt_text is not refreshed on this branch,
                # so the value from the previous turn is reused below -
                # confirm this is intended.
                prompt_token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
            elif isinstance(context, ParsableContext):
                engine_prompts = await self._render_next_turn(
                    context.request,
                    context.tokenizer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
                prompt_text, _, _ = self._get_prompt_components(engine_prompt)

            # Update the sampling params.
            # The next turn may use whatever context is left after the new
            # prompt; this mutates the caller's sampling_params object.
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )
            # OPTIMIZATION
            # Every follow-up turn uses the original priority minus one
            # (presumably so it is scheduled ahead of new requests - confirm
            # against the scheduler's priority ordering).
            priority = orig_priority - 1
            sub_request += 1
1435

1436
1437
    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
        """Split ``prompt`` into its components via ``get_prompt_components``."""
        components = get_prompt_components(prompt)
        return components
1438

1439
1440
1441
    def _log_inputs(
        self,
        request_id: str,
1442
        inputs: PromptType,
1443
1444
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1445
1446
1447
    ) -> None:
        if self.request_logger is None:
            return
1448

1449
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1450
1451
1452
1453
1454

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1455
            prompt_embeds,
1456
1457
1458
            params=params,
            lora_request=lora_request,
        )
1459

1460
1461
1462
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when the engine has tracing enabled.

        With tracing disabled, ``None`` is returned; if the request still
        carries trace headers, the tracing-disabled warning is logged.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1474
    @staticmethod
1475
    def _base_request_id(
1476
1477
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1478
        """Pulls the request id to use from a header, if provided"""
1479
1480
1481
1482
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1483

1484
        return random_uuid() if default is None else default
1485

1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1501
1502
1503
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Derive function calls from generated content per ``tool_choice``.

        * A named tool choice (either protocol flavour) turns the whole
          content into the arguments of that single function.
        * ``"required"`` parses the content as a JSON list of function
          definitions.
        * ``"auto"``/``None`` with auto tools enabled and a parser class
          runs the tool parser over the content.

        Returns:
            ``(function_calls, remaining_content)``. The calls are ``None``
            when the automatic parser found no tool call, and an empty list
            when no branch applied.

        Raises:
            ValueError: if automatic parsing is needed but no tokenizer is
                available.
        """
        tool_choice = request.tool_choice
        function_calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(name=tool_choice.name, arguments=content)
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            forced_call = FunctionCall(
                name=tool_choice.function.name, arguments=content
            )
            function_calls.append(forced_call)
            content = None  # Clear content since tool is called.
        elif tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            for definition in definitions:
                function_calls.append(
                    FunctionCall(
                        name=definition.name,
                        arguments=json.dumps(
                            definition.parameters, ensure_ascii=False
                        ),
                    )
                )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError as e:
                logger.exception("Error in tool parser creation.")
                raise e

            parser_input = "" if content is None else content
            tool_call_info = tool_parser.extract_tool_calls(
                parser_input,
                request=request,  # type: ignore
            )
            if tool_call_info is None or not tool_call_info.tools_called:
                # No tool calls.
                return None, content

            # extract_tool_calls() returns a list of tool calls.
            function_calls.extend(
                [
                    FunctionCall(
                        name=parsed_call.function.name,
                        arguments=parsed_call.function.arguments,
                    )
                    for parsed_call in tool_call_info.tool_calls
                ]
            )
            content = tool_call_info.content
            # Whitespace-only leftovers are reported as no content.
            if content and content.strip() == "":
                content = None

        return function_calls, content

1577
    @staticmethod
1578
1579
1580
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1581
        tokenizer: TokenizerLike | None,
1582
1583
        return_as_token_id: bool = False,
    ) -> str:
1584
1585
1586
        if return_as_token_id:
            return f"token_id:{token_id}"

1587
1588
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1589
1590
1591
1592
1593
1594

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1595
        return tokenizer.decode(token_id)
1596

1597
    def _is_model_supported(self, model_name: str | None) -> bool:
1598
1599
        if not model_name:
            return True
1600
        return self.models.is_base_model(model_name)
1601

1602
1603

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input and ``None`` entries are passed over untouched. The same
    (mutated) object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs