serving.py 55.3 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from contextlib import aclosing
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast

import numpy as np
from fastapi import Request
from openai.types.responses import (
    ToolChoiceFunction,
)
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers

import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from vllm.entrypoints.openai.completion.protocol import (
    CompletionRequest,
    CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
    ErrorInfo,
    ErrorResponse,
    FunctionCall,
    FunctionDefinition,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
    ScoreRequest,
    ScoreResponse,
    ScoreTextRequest,
)
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
from vllm.entrypoints.utils import (
    _validate_truncation_size,
    get_max_tokens,
    sanitize_message,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.renderers import RendererLike
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
    AsyncMicrobatchTokenizer,
    collect_from_async_generator,
    merge_async_iterators,
)
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
    """Internal failure detected while generating a response.

    Raised when a request's ``finish_reason`` signals an internal error. Every
    instance carries ``status_code`` set to HTTP 500 so handlers can turn it
    into an error response.
    """

    def __init__(self, message: str = "Internal server error"):
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)
logger = init_logger(__name__)

# Speech-to-text endpoints (transcription and translation).
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Requests whose input is a raw prompt / token list rather than chat messages.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Requests whose input is a list of chat messages.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)

# Every request type accepted by the serving layer.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every successful response type produced by the serving layer.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)

# Request type parameter for ServeContext.
RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state passed between the stages of the serving pipeline.

    All fields are keyword-only. ``engine_prompts``, ``result_generator`` and
    ``final_res_batch`` start empty and are filled in as the request moves
    through the pipeline stages.
    """

    # The parsed API request and, if available, the raw HTTP request.
    request: RequestT
    raw_request: Request | None = field(default=None)

    # Request identity; created_time is whole seconds from time.time().
    model_name: str
    request_id: str
    created_time: int = field(default_factory=lambda: int(time.time()))

    lora_request: LoRARequest | None = field(default=None)
    engine_prompts: list[TokensPrompt] | None = field(default=None)

    # Async iterator of (prompt_index, output) pairs once generators exist.
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = field(default=None)
    # Collected outputs, one per engine prompt.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): unannotated, so it is a plain class attribute, not a
    # dataclass field; presumably a leftover from a pydantic version of this
    # class. Kept for compatibility with any code that still reads it.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class OpenAIServing:
207
208
209
210
    # Short tag prepended to every request's ID (e.g. "embd", "classify") so an
    # ID shows which endpoint produced it. Endpoint subclasses set their own
    # value. The base value is empty; previously this description string was
    # itself assigned as the prefix and would leak into request IDs.
    request_id_prefix: ClassVar[str] = ""
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the engine client, model registry and shared helpers.

        Args:
            engine_client: Client used to issue generate/encode/reset calls.
            models: Model registry; the input/IO processors, renderer and
                model config are taken from it.
            request_logger: Optional request logger (may be ``None``).
            return_tokens_as_token_ids: Stored on the instance as-is.
            log_error_stack: When true, ``create_error_response`` prints a
                traceback for every error response it builds.
        """
        super().__init__()

        # Collaborators and per-server options.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Memoized async tokenizer wrappers, keyed by tokenizer.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}

        # Processing components borrowed from the model registry.
        self.input_processor = models.input_processor
        self.io_processor = models.io_processor
        self.renderer = models.renderer

        model_config = models.model_config
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Resolve the tool parser used for "auto" tool choice.

        Args:
            tool_parser_name: Registered name of the parser.
            enable_auto_tools: Whether "auto" tool choice is enabled.

        Returns:
            The parser returned by ``ToolParserManager``, or ``None`` when auto
            tools are disabled or no parser name was given.

        Raises:
            TypeError: If looking up the parser fails for any reason.
        """
        if not enable_auto_tools or tool_parser_name is None:
            return None

        logger.info('"auto" tool choice has been enabled.')

        try:
            if tool_parser_name == "pythonic":
                # Only consult the model name for the pythonic parser.
                if self.model_config.model.startswith("meta-llama/Llama-3.2"):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic tool calls"
                    )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Resolve the reasoning parser registered under ``reasoning_parser_name``.

        Returns:
            The parser returned by ``ReasoningParserManager``, or ``None``
            when ``reasoning_parser_name`` is empty/falsy.

        Raises:
            TypeError: If the lookup raises, or if it returns ``None``.
        """
        if not reasoning_parser_name:
            return None
        not_registered = f"{reasoning_parser_name=} has not been registered"
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(not_registered) from e
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let a `None` parser slip through.
        if parser is None:
            raise TypeError(not_registered)
        return parser
    async def reset_mm_cache(self) -> None:
        """Clear multimodal caches.

        Clears the local input processor's cache first, then awaits the
        engine client's reset.
        """
        processor = self.input_processor
        processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly issuing single-token requests.

        Each step sends one ``max_tokens=1`` request per live beam, asking for
        ``2 * beam_width`` logprobs, then keeps the ``beam_width`` candidates
        with the highest cumulative logprob. Unless ``ignore_eos`` is set,
        candidates that emit EOS are moved to the completed set instead of
        being extended.

        Args:
            prompt: Prompt to search from. A plain ``str`` prompt supplies no
                token ids here (see note below).
            request_id: Base id; per-step/per-beam request ids derive from it.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter applied to every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            A single finished ``RequestOutput`` with the best ``beam_width``
            beams, or a single error output if any step finished with
            ``finish_reason == "error"``.

        Raises:
            VLLMValidationError: If the tokenizer is unavailable
                (``skip_tokenizer_init=True``).
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain-string prompt contributes no token ids, so
            # the search starts from an empty token list; callers presumably
            # pass pre-tokenized prompts. Confirm before relying on str input.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Request 2x beam_width candidates per beam (presumably so enough
        # non-EOS continuations remain after EOS candidates are set aside).
        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Flat candidate lists, one entry per (beam, candidate token): the
            # token id, its cumulative logprob, and the index of the beam it
            # extends. The beam index is recorded explicitly instead of being
            # recovered as `idx // logprobs_num`: a beam whose `logprobs` is
            # None contributes no candidates, and the engine's per-beam count
            # is not guaranteed to equal `logprobs_num`, so that arithmetic
            # can attribute a candidate to the wrong beam.
            all_beams_token_id = []
            all_beams_logprob = []
            all_beams_src = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(logprobs.keys())
                    all_beams_logprob.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    all_beams_src.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    src = all_beams_src[idx]
                    current_beam = all_beams[src]
                    result = output[src]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            num_candidates = len(all_beams_logprob)
            if num_candidates == 0:
                # No beam returned logprobs, so nothing can be extended; the
                # current beams become the final ones.
                break
            if num_candidates > beam_width:
                # Indices of the beam_width best candidates (in no order).
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                # np.argpartition requires kth < len(array); with this few
                # candidates every one of them survives anyway.
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                src = all_beams_src[idx]
                current_beam = all_beams[src]
                result = output[src]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
    def _get_completion_renderer(self) -> BaseRenderer:
        """Build a ``CompletionRenderer`` for this server's model.

        The renderer uses ``self.renderer``'s tokenizer and shares this
        instance's async tokenizer pool, so async tokenizer wrappers are reused
        rather than recreated per renderer.
        """
        shared_pool = self._async_tokenizer_pool
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=self.renderer.tokenizer,
            async_tokenizer_pool=shared_pool,
        )
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the ``RenderConfig`` that controls prompt preparation.

        The renderer uses it to decide how prompts are tokenized and how their
        length is handled. The base class has no default; each endpoint must
        override this for its own request type.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Return the ``AsyncMicrobatchTokenizer`` wrapping ``tokenizer``.

        Wrappers are created on first use and memoized in
        ``self._async_tokenizer_pool``, so later calls with the same tokenizer
        reuse the same wrapper.
        """
        pool = self._async_tokenizer_pool
        if (cached := pool.get(tokenizer)) is not None:
            return cached
        wrapper = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = wrapper
        return wrapper
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook that prepares ``ctx`` before engine generators are created.

        The base implementation changes nothing and reports no error.
        Endpoint subclasses (classification, embedding, ...) override it and
        may return an ``ErrorResponse`` to abort the pipeline.
        """
        return None
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Hook that turns a fully processed ``ctx`` into the final response.

        The base implementation knows no response format, so it returns an
        "unimplemented endpoint" error; subclasses override it.
        """
        unimplemented = "unimplemented endpoint"
        return self.create_error_response(unimplemented)
    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for ``ctx`` and return its first response.

        The pipeline's async generator is closed with ``aclosing`` as soon as
        the first response is taken, rather than being left suspended until
        garbage collection finalizes it.

        Returns:
            The first response (or error) yielded by ``_pipeline``, or an
            error response if the pipeline yields nothing.
        """
        async with aclosing(self._pipeline(ctx)) as responses:
            async for response in responses:
                return response

        return self.create_error_response("No response yielded from pipeline")
    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Run the request-processing stages and yield exactly one response.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation and batch collection. The first stage that
        reports an error short-circuits the pipeline: that error is yielded
        and no later stage runs. Previously there was no ``return`` after an
        error, so a consumer that kept iterating would run the later stages
        on an invalid request. If every stage succeeds, the output of
        ``_build_response`` is yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``.

        Requests without that attribute, or with it set to ``None``, pass.
        Returns ``None`` when the request is acceptable.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        exceeds_context = limit is not None and limit > self.max_model_len
        if not exceeds_context:
            return None
        return self.create_error_response(
            "truncate_prompt_tokens value is greater than max_model_len."
            " Please, select a smaller truncation size."
        )
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive ``PoolingParams`` from ``ctx.request``.

        Returns an ``ErrorResponse`` when the request type has no
        ``to_pooling_params`` method.
        """
        missing = object()
        to_pooling_params = getattr(ctx.request, "to_pooling_params", missing)
        if to_pooling_params is missing:
            return self.create_error_response(
                "Request type does not support pooling parameters"
            )
        return to_pooling_params()
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Create one ``engine_client.encode`` generator per engine prompt.

        Each prompt gets the request id ``"<ctx.request_id>-<index>"`` and is
        logged via ``_log_inputs``. On success, the generators are merged with
        ``merge_async_iterators`` into ``ctx.result_generator`` and ``None`` is
        returned.

        Returns:
            ``None`` on success. Otherwise an ``ErrorResponse`` when the
            request has no pooling support, when ``ctx.engine_prompts`` is
            unset, or when any step raises.
        """
        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
            for index, engine_prompt in enumerate(prompts):
                item_request_id = f"{ctx.request_id}-{index}"
                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )
                generators.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*generators)
            return None
        except Exception as e:
            return self.create_error_response(e)
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain ``ctx.result_generator`` into ``ctx.final_res_batch``.

        The generator yields ``(prompt_index, output)`` pairs; each output is
        stored at its prompt's index so the batch stays in prompt order.

        Returns:
            ``None`` on success. Otherwise an ``ErrorResponse`` when
            prerequisites are missing, when some prompt produced no output,
            or when iteration raises.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            slots: list[PoolingRequestOutput | None] = [None] * len(prompts)

            generator = ctx.result_generator
            if generator is None:
                return self.create_error_response("Result generator not available")

            async for index, output in generator:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
            return None
        except Exception as e:
            return self.create_error_response(e)
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an OpenAI-style ``ErrorResponse``.

        When ``message`` is an exception, the caller's ``err_type``,
        ``status_code`` and ``param`` are replaced by values derived from the
        exception class, and the message becomes ``str(exc)``:

        * ``VLLMValidationError`` -> 400, ``param`` taken from the exception.
        * ``NotImplementedError`` -> 501 ``NotImplementedError``.
        * ``ValueError``/``TypeError``/``RuntimeError``/``OverflowError``
          -> 400.
        * jinja2 ``TemplateError`` or any subclass, matched by class name so
          jinja2 need not be imported -> 400.
        * anything else -> 500 ``InternalServerError``.

        The message is passed through ``sanitize_message``. If
        ``self.log_error_stack`` is set, a traceback is printed.
        """
        exc: Exception | None = None

        if isinstance(message, Exception):
            exc = message

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, NotImplementedError):
                # Must precede the generic tuple below: NotImplementedError is
                # a subclass of RuntimeError and was previously mapped to 400.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif any(cls.__name__ == "TemplateError" for cls in type(exc).__mro__):
                # jinja2.TemplateError and its subclasses (UndefinedError,
                # TemplateSyntaxError, ...), matched by name to avoid
                # importing jinja2.
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            if exc is not None and exc.__traceback__ is not None:
                # Prefer the exception's own traceback; it is accurate even
                # when this is called outside the `except` block that caught it.
                traceback.print_exception(type(exc), exc, exc.__traceback__)
            elif sys.exc_info()[0] is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Same as ``create_error_response``, serialized to a JSON string.

        Intended for streaming responses, where the error must be emitted as
        text rather than as a response object.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        payload = error.model_dump()
        return json.dumps(payload)
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when ``finish_reason`` is "error".

        Any other finish reason (including ``None``) is a no-op.
        """
        if finish_reason != "error":
            return
        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")
    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Turn a ``GenerationError`` into an ``InternalServerError`` response
        that carries the exception's message and status code."""
        text, status = str(e), e.status_code
        return self.create_error_response(
            text, err_type="InternalServerError", status_code=status
        )
    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Serialize a ``GenerationError`` as a JSON ``InternalServerError``
        payload, carrying the exception's message and status code."""
        text, status = str(e), e.status_code
        return self.create_streaming_error_response(
            text, err_type="InternalServerError", status_code=status
        )
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that ``request.model`` names something this server can serve.

        Accepts, in order:
        1. Names accepted by ``_is_model_supported``.
        2. Names of already-loaded LoRA adapters.
        3. When ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled, names that
           ``models.resolve_lora`` resolves to a ``LoRARequest``.

        Returns:
            ``None`` if the model is available. If LoRA resolution returns a
            400 ``ErrorResponse``, that error is returned as-is; every other
            failure becomes a 404 ``NotFoundError`` with ``param="model"``.
        """
        if self._is_model_supported(request.model) or (
            request.model in self.models.lora_requests
        ):
            return None

        resolution_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model:
            load_result = await self.models.resolve_lora(request.model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolution_error = load_result

        return resolution_error or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

869
    def _maybe_get_adapters(
870
871
872
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
873
    ) -> LoRARequest | None:
874
        if request.model in self.models.lora_requests:
875
            return self.models.lora_requests[request.model]
876
877
878
879
880
881

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
882
                return default_mm_lora
883

884
        if self._is_model_supported(request.model):
885
            return None
886

887
        # if _check_model has been called earlier, this will be unreachable
888
        raise ValueError(f"The model `{request.model}` does not exist.")
889

890
891
892
893
894
895
896
897
898
899
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

900
901
902
903
904
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
905
906
907
908
909
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
910
911
912
913
914
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

915
    async def _normalize_prompt_text_to_input(
916
917
918
        self,
        request: AnyRequest,
        prompt: str,
919
        tokenizer: TokenizerLike,
920
        add_special_tokens: bool,
921
    ) -> TokensPrompt:
922
923
        async_tokenizer = self._get_async_tokenizer(tokenizer)

924
925
926
927
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
928
929
            prompt = prompt.lower()

930
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
931

932
        if truncate_prompt_tokens is None:
933
            encoded = await async_tokenizer(
934
935
                prompt, add_special_tokens=add_special_tokens
            )
936
937
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
938
939
940
941
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
942
943
                max_length=self.max_model_len,
            )
944
        else:
945
946
947
948
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
949
950
                max_length=truncate_prompt_tokens,
            )
951

952
953
        
        input_ids = encoded.input_ids
954
955
956
957
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

958
    async def _normalize_prompt_tokens_to_input(
959
960
        self,
        request: AnyRequest,
961
        prompt_ids: list[int],
962
        tokenizer: TokenizerLike | None,
963
    ) -> TokensPrompt:
964
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
965

966
        if truncate_prompt_tokens is None:
967
            input_ids = prompt_ids
968
        elif truncate_prompt_tokens < 0:
969
            input_ids = prompt_ids[-self.max_model_len :]
970
971
972
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

973
974
975
        if tokenizer is None:
            input_text = ""
        else:
976
            async_tokenizer = self._get_async_tokenizer(tokenizer) 
977
            input_text = await async_tokenizer.decode(input_ids)
978

979
980
981
982
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check tokenized input against the model's context window.

        Request families are handled as follows:

        * Pooling-style requests (embedding, score, rerank, classification)
          have no ``max_tokens`` and may use the full ``max_model_len``.
          Longer inputs raise.
        * Tokenize/detokenize requests are returned without any length check.
        * All other requests must have fewer than ``max_model_len`` input
          tokens, and ``input + max_tokens`` must not exceed it. For chat
          completions, ``max_completion_tokens`` takes precedence over
          ``max_tokens``.

        Returns:
            A ``TokensPrompt`` carrying ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: when a length limit is exceeded.
        """
        num_tokens = len(input_ids)
        context_len = self.max_model_len

        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # These requests do not generate tokens, so the input may span
            # the entire context window.
            if num_tokens > context_len:
                # NOTE(review): RerankRequest is not listed here and falls
                # back to "embedding generation" in the message.
                operation_names: dict[type[AnyRequest], str] = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }
                operation = operation_names.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{context_len} tokens. However, you requested "
                    f"{num_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=num_tokens,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Tokenize/detokenize requests have no max_tokens and skip
        # context-length validation entirely.
        tokenization_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenization_request_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Completion-like requests may use at most max_model_len - 1 input
        # tokens.
        if num_tokens >= context_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{context_len} tokens. However, your request has "
                f"{num_tokens} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=num_tokens,
            )

        if max_tokens is not None and num_tokens + max_tokens > context_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{context_len} tokens and your request has "
                f"{num_tokens} input tokens ({max_tokens} > {context_len}"
                f" - {num_tokens}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
1064

1065
    async def _tokenize_prompt_input_async(
1066
1067
        self,
        request: AnyRequest,
1068
        tokenizer: TokenizerLike,
1069
        prompt_input: str | list[int],
1070
        add_special_tokens: bool = True,
1071
    ) -> TokensPrompt:
1072
        """
1073
        A simpler implementation that tokenizes a single prompt input.
1074
        """
1075
        async for result in self._tokenize_prompt_inputs_async(
1076
1077
            request,
            tokenizer,
1078
            [prompt_input],
1079
            add_special_tokens=add_special_tokens,
1080
1081
1082
        ):
            return result
        raise ValueError("No results yielded from tokenization")
1083

1084
    async def _tokenize_prompt_inputs_async(
1085
1086
        self,
        request: AnyRequest,
1087
        tokenizer: TokenizerLike,
1088
        prompt_inputs: Iterable[str | list[int]],
1089
        add_special_tokens: bool = True,
1090
    ) -> AsyncGenerator[TokensPrompt, None]:
1091
        """
1092
        A simpler implementation that tokenizes multiple prompt inputs.
1093
        """
1094
1095
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
1096
                yield await self._normalize_prompt_text_to_input(
1097
                    request,
1098
1099
                    prompt=prompt,
                    tokenizer=tokenizer,
1100
1101
1102
                    add_special_tokens=add_special_tokens,
                )
            else:
1103
                yield await self._normalize_prompt_tokens_to_input(
1104
                    request,
1105
1106
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
1107
1108
                )

1109
1110
    def _validate_chat_template(
        self,
1111
1112
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
1113
        trust_request_chat_template: bool,
1114
    ) -> ErrorResponse | None:
1115
        if not trust_request_chat_template and (
1116
1117
1118
1119
1120
1121
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
1122
1123
1124
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
1125
1126
                "Refused request with untrusted chat template."
            )
1127
1128
        return None

1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1141
1142
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        renderer: RendererLike,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a single engine prompt.

        1. Template kwargs are built from the explicit arguments. Entries in
           ``chat_template_kwargs`` override those arguments, and both
           override ``default_chat_template_kwargs``.
        2. ``messages`` are rendered with ``renderer.render_messages_async``.
           Rendering tokenizes directly when a ``tokenize`` kwarg is truthy
           or the renderer's tokenizer is a ``MistralTokenizer``.
        3. If the renderer returned text only, the text is tokenized with
           ``_tokenize_prompt_input_async`` and the renderer's other keys
           are merged back in. Otherwise the renderer's token ids are checked
           with ``_validate_input``.
        4. ``mm_processor_kwargs`` and ``cache_salt`` from the request are
           attached to the prompt.
        5. If a tool parser is given and ``tool_choice`` is not ``"none"``,
           the parser's ``adjust_request`` is called on the request.

        Returns:
            The rendered conversation and a one-element list holding the
            engine prompt.

        Raises:
            NotImplementedError: tool parsing was requested for a request
                that is neither a Chat Completions nor a Responses request.
        """
        template_kwargs = self._prepare_extra_chat_template_kwargs(
            {
                "chat_template": chat_template,
                "add_generation_prompt": add_generation_prompt,
                "continue_final_message": continue_final_message,
                "tools": tool_dicts,
                "documents": documents,
                **(chat_template_kwargs or {}),
            },
            default_chat_template_kwargs,
        )

        # Use the async tokenizer in `OpenAIServing` if possible.
        # Later we can move it into the renderer so that we can return both
        # text and token IDs in the same prompt from `render_messages_async`
        # which is used for logging and `enable_response_messages`.
        from vllm.tokenizers.mistral import MistralTokenizer

        # `tokenize` is consumed here so it is not also forwarded via
        # **template_kwargs below.
        should_tokenize = template_kwargs.pop("tokenize", False) or isinstance(
            renderer.tokenizer, MistralTokenizer
        )
        conversation, engine_prompt = await renderer.render_messages_async(
            messages,
            chat_template_content_format=chat_template_content_format,
            tokenize=should_tokenize,
            **template_kwargs,
        )

        if "prompt_token_ids" in engine_prompt:
            # The renderer already tokenized; only enforce length limits.
            self._validate_input(
                request=request,
                input_ids=engine_prompt["prompt_token_ids"],  # type: ignore
                input_text="",
            )
        else:
            render_extras = engine_prompt
            engine_prompt = await self._tokenize_prompt_input_async(
                request,
                renderer.get_tokenizer(),
                engine_prompt["prompt"],
                add_special_tokens=add_special_tokens,
            )
            # Carry over the renderer's other keys (e.g. multimodal data).
            engine_prompt.update(render_extras)  # type: ignore

        engine_prompt = cast(TokensPrompt, engine_prompt)

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
        cache_salt = getattr(request, "cache_salt", None)
        if cache_salt is not None:
            engine_prompt["cache_salt"] = cache_salt

        # Tool parsing is set up only when a tool parser is configured and
        # tool_choice is not "none". With "none", a tool call hallucinated by
        # the LLM must not be parsed even if a tool parser is set.
        if tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        ):
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
            tokenizer = renderer.get_tokenizer()
            # NOTE(review): the returned request is only rebound locally, so
            # this presumably relies on adjust_request mutating `request` in
            # place; verify against the ToolParser implementations.
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        return conversation, [engine_prompt]
1229

1230
1231
1232
1233
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
        data_parallel_rank: int | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the input processor to build an ``EngineCoreRequest`` for AsyncLLM.

        A fresh tokenization-kwargs dict is handed to
        ``_validate_truncation_size`` together with ``self.max_model_len`` and
        ``params.truncate_prompt_tokens``; presumably that helper fills it in.
        The same dict is passed to ``input_processor.process_inputs`` and is
        also returned, so callers can forward it (e.g. to
        ``engine_client.generate``).

        Returns:
            ``(engine_request, tokenization_kwargs)``.
        """
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
        )

        return (
            self.input_processor.process_inputs(
                request_id,
                engine_prompt,
                params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            ),
            tokenization_kwargs,
        )

1259
1260
1261
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        renderer: RendererLike,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompt(s) for the next Responses API turn.

        ``messages`` are converted with ``construct_input_messages`` and then
        passed through ``_preprocess_chat``. The rendered conversation is
        discarded; only the list of engine prompts is returned.
        """
        chat_messages = construct_input_messages(request_input=messages)

        _conversation, prompts = await self._preprocess_chat(
            request,
            renderer,
            chat_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return prompts
1283

1284
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Drive generation, running built-in tool calls between turns.

        Each turn is submitted as its own sub-request with id
        ``f"{request_id}_{n}"``. Every engine output is appended to
        ``context``, and ``context`` is yielded after each append.

        When a turn ends and ``context.need_builtin_tool_call()`` is true:
        the tool is called and its output appended, the next prompt is
        rendered (for Harmony and parsable contexts), and
        ``sampling_params.max_tokens`` is updated in place. For other context
        types the previous prompt is resubmitted unchanged. The loop stops as
        soon as no tool call is requested.

        Every sub-request after the first is submitted with priority
        ``priority - 1``.
        """
        prompt_text, _, _ = get_prompt_components(engine_prompt)
        trace_headers = kwargs.get("trace_headers")

        turn_priority = priority
        turn = 0
        while True:
            # Each sub-request needs a unique request id.
            turn_id = f"{request_id}_{turn}"
            self._log_inputs(
                turn_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                turn_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=turn_priority,
            )

            outputs = self.engine_client.generate(
                engine_request,
                sampling_params,
                turn_id,
                lora_request=lora_request,
                priority=turn_priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )
            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and record its result in the context.
            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Render the next turn's prompt and resize max_tokens to fit it.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                # NOTE(review): prompt_text is not refreshed on this path, so
                # later turns reuse the first turn's prompt_text; confirm this
                # is intended.
                next_token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=next_token_ids)
                sampling_params.max_tokens = self.max_model_len - len(next_token_ids)
            elif isinstance(context, ParsableContext):
                rendered_prompts = await self._render_next_turn(
                    context.request,
                    context.renderer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = rendered_prompts[0]
                prompt_text, _, _ = get_prompt_components(engine_prompt)
                sampling_params.max_tokens = get_max_tokens(
                    self.max_model_len,
                    context.request,
                    engine_prompt,
                    self.default_sampling_params,  # type: ignore
                )

            # OPTIMIZATION: follow-up sub-requests are submitted with
            # priority - 1 (presumably so continuing conversations are
            # scheduled ahead of fresh requests; confirm scheduler semantics).
            turn_priority = priority - 1
            turn += 1
1374

1375
1376
1377
    def _log_inputs(
        self,
        request_id: str,
1378
        inputs: PromptType,
1379
1380
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1381
1382
1383
    ) -> None:
        if self.request_logger is None:
            return
1384

1385
        prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
1386
1387
1388
1389
1390

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1391
            prompt_embeds,
1392
1393
1394
            params=params,
            lora_request=lora_request,
        )
1395

1396
1397
1398
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract tracing headers when the engine has tracing enabled.

        If tracing is disabled, ``None`` is returned. In that case a warning
        is logged via ``log_tracing_disabled_warning`` when the request
        nevertheless carries trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1410
    @staticmethod
1411
    def _base_request_id(
1412
1413
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1414
        """Pulls the request id to use from a header, if provided"""
1415
1416
1417
1418
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1419

1420
        return random_uuid() if default is None else default
1421

1422
1423
1424
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
1425
        if raw_request is None:
1426
1427
1428
1429
1430
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None
1431

1432
1433
1434
1435
1436
        try:
            return int(rank_str)
        except ValueError:
            return None

1437
1438
1439
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from model output according to tool_choice.

        * Named tool choice (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): all of ``content`` becomes
          the arguments of that one function.
        * ``"required"``: ``content`` is parsed as a JSON list of
          ``FunctionDefinition``. Each entry becomes a call whose arguments
          are its JSON-encoded ``parameters``.
        * ``"auto"`` or ``None``, with auto tools enabled and a parser class
          given: a tool parser instance extracts the calls. If it reports
          none, ``(None, content)`` is returned. Otherwise the parser's
          leftover content is returned, with whitespace-only text replaced
          by ``None``.

        In the named and required cases the content is consumed and ``None``
        is returned in its place. If no case applies, an empty list is
        returned with ``content`` unchanged.

        Raises:
            ValueError: auto parsing applies but ``tokenizer`` is ``None``.
        """
        tool_choice = request.tool_choice
        calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced function call (named by the Responses API).
            calls.append(FunctionCall(name=tool_choice.name, arguments=content))
            return calls, None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            # Forced function call (named by the Chat Completions API).
            calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            return calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            calls.extend(
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            )
            return calls, None

        if not (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            return calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic tool call parsing.
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in tool parser creation.")
            raise e
        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        calls.extend(
            FunctionCall(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
            for tool_call in tool_call_info.tool_calls
        )
        remaining = tool_call_info.content
        # Whitespace-only leftover text is reported as no content at all.
        if remaining and remaining.strip() == "":
            remaining = None
        return calls, remaining
1513

1514
    @staticmethod
1515
1516
1517
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1518
        tokenizer: TokenizerLike | None,
1519
1520
        return_as_token_id: bool = False,
    ) -> str:
1521
1522
1523
        if return_as_token_id:
            return f"token_id:{token_id}"

1524
1525
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1526
1527
1528
1529
1530
1531

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1532
        return tokenizer.decode(token_id)
1533

1534
    def _is_model_supported(self, model_name: str | None) -> bool:
1535
1536
        if not model_name:
            return True
1537
        return self.models.is_base_model(model_name)
1538

1539
1540

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0`` in place.

    ``None`` entries (positions without logprobs) are skipped. The same,
    mutated object is returned; ``None`` input is returned unchanged.
    """
    # NOTE(review): presumably clamped because -inf cannot be represented in
    # strict JSON responses; confirm.
    if prompt_logprobs is not None:
        negative_infinity = float("-inf")
        for position_logprobs in prompt_logprobs:
            if position_logprobs is None:
                continue
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs