serving_engine.py 51.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
9
from concurrent.futures import ThreadPoolExecutor
10
from dataclasses import dataclass, field
11
from http import HTTPStatus
12
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
13

14
import numpy as np
15
import torch
16
from fastapi import Request
17
from pydantic import ConfigDict, TypeAdapter
18
from starlette.datastructures import Headers
19
20
from typing_extensions import TypeIs

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)

43
44
45
46
47
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

48
49
50
51
from openai.types.responses import (
    ToolChoiceFunction,
)

52
import vllm.envs as envs
53
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
54
from vllm.engine.protocol import EngineClient
55
56
57
58
59
60
61
62
63
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
64
from vllm.entrypoints.context import ConversationContext
65
from vllm.entrypoints.logger import RequestLogger
66
from vllm.entrypoints.openai.protocol import (
67
    ChatCompletionNamedToolChoiceParam,
68
69
70
71
72
73
74
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    ErrorInfo,
    ErrorResponse,
75
76
    FunctionCall,
    FunctionDefinition,
77
78
79
80
81
82
83
84
    ResponsesRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
85
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
86
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
87
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
88
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
89
from vllm.entrypoints.utils import _validate_truncation_size
90
from vllm.inputs.data import PromptType
91
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
92
93
94
95
96
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
97
from vllm.logger import init_logger
98
from vllm.logprobs import Logprob, PromptLogprobs
99
from vllm.lora.request import LoRARequest
100
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
101
102
103
    MultiModalDataDict,
    MultiModalUUIDDict,
)
104
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
105
from vllm.pooling_params import PoolingParams
106
from vllm.reasoning import ReasoningParser, ReasoningParserManager
107
from vllm.sampling_params import BeamSearchParams, SamplingParams
108
from vllm.tokenizers import MistralTokenizer, TokenizerLike
109
110
111
112
113
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
114
from vllm.utils import random_uuid
115
from vllm.utils.async_utils import (
116
    AsyncMicrobatchTokenizer,
117
    collect_from_async_generator,
118
    make_async,
119
120
    merge_async_iterators,
)
121
from vllm.utils.collection_utils import is_list_of
122
from vllm.v1.engine import EngineCoreRequest
123
124
125

# Module-level logger, created via vLLM's `init_logger` for this module name.
logger = init_logger(__name__)

126
127
128
129
130
# Union of the completion-style request models handled by this module.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationCompletionRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Union of the chat-style request models handled by this module.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
    | ClassificationChatRequest
)

# Transcription and translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request model accepted here; used as the bound of `RequestT`.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response model that a serving handler may return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)
163

164
165
166

class TextTokensPrompt(TypedDict):
    """Prompt dict carrying a text prompt together with its token ids."""

    prompt: str
    prompt_token_ids: list[int]
168
169


170
171
172
173
class EmbedsPrompt(TypedDict):
    """Prompt dict carrying a `prompt_embeds` tensor."""

    prompt_embeds: torch.Tensor


174
# A request prompt is one of: a list of token ids, a plain string,
# a `TextTokensPrompt` dict, or an `EmbedsPrompt` dict.
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
175
176
177


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """
    Return True when `prompt` is a dict that has a `prompt_token_ids` key
    and no `prompt_embeds` key.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )
183
184
185


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """
    Return True when `prompt` is a dict that has a `prompt_embeds` key
    and no `prompt_token_ids` key.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )
191

192

193
194
195
# Type variable for the request model carried by a `ServeContext`;
# bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


196
197
@dataclass(kw_only=True)
class RequestProcessingMixin:
198
    """
199
    Mixin for request processing,
200
201
    handling prompt preparation and engine input.
    """
202

203
204
    request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
    engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
205
206


207
208
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Async generator of `(prompt_index, output)` pairs; None until set.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Collected outputs; each instance gets its own empty list.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # Un-annotated, so this is a plain class attribute and not a dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)


224
225
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """
    Per-request state passed between the serving pipeline steps.

    Combines the prompt fields of `RequestProcessingMixin` and the result
    fields of `ResponseGenerationMixin` with the request itself. All fields
    are keyword-only.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Unix timestamp in whole seconds, taken when the context is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Shared across most requests
    tokenizer: TokenizerLike | None = None
236
237


238
239
240
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """`ServeContext` specialised for `ClassificationRequest`; adds no fields."""

    pass
241
242


243
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` specialised for `EmbeddingRequest` with chat-template settings."""

    # Optional chat template; None when not provided.
    chat_template: str | None = None
    # Required keyword argument (no default).
    chat_template_content_format: ChatTemplateContentFormatOption


249
class OpenAIServing:
250
251
252
253
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
254

255
256
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """
        Store the engine client and model registry and derive shared helpers.

        Args:
            engine_client: Client used to talk to the engine.
            models: Model registry; its `input_processor`, `io_processor`
                and `model_config` are cached on this instance.
            request_logger: Optional request logger.
            return_tokens_as_token_ids: Stored as-is on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback or stack trace.
        """
        super().__init__()

        self.engine_client = engine_client

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        # Single-worker executor backing the async Mistral chat-template wrapper.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Cache of async tokenizers keyed by the tokenizer they wrap.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        self.input_processor = self.models.input_processor
        self.io_processor = self.models.io_processor
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len

285
    def _get_tool_parser(
286
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
287
    ) -> Callable[[TokenizerLike], ToolParser] | None:
288
289
290
291
        """Get the tool parser based on the name."""
        parser = None
        if not enable_auto_tools or tool_parser_name is None:
            return parser
292
        logger.info('"auto" tool choice has been enabled.')
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
313
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
314
315
316
317
318
319
320
321
322
323
324
        """Get the reasoning parser based on the name."""
        parser = None
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
            assert parser is not None
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        return parser

325
    async def reset_mm_cache(self) -> None:
326
        self.input_processor.clear_mm_cache()
327
328
        await self.engine_client.reset_mm_cache()

329
330
331
332
333
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Run beam search by asking the engine for one token per beam per step.

        Each step submits every live beam to `engine_client.generate` with
        `max_tokens=1` and `2 * beam_width` logprobs, pools the candidate
        next tokens of all beams, moves EOS candidates to the completed list
        (unless `ignore_eos`), and keeps the `beam_width` best remaining
        candidates as the next beams. At most `max_tokens` steps are run.

        Args:
            prompt: A string, or a dict-like prompt from which `prompt`,
                `prompt_token_ids` and `multi_modal_data` are read.
            request_id: Base id; per-step, per-beam ids are derived from it.
            params: Beam search settings (width, max tokens, penalties, ...).
            lora_request: Optional LoRA adapter applied to every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            One finished `RequestOutput` holding up to `beam_width`
            completions, best first according to the beam sort key.

        Raises:
            ValueError: if the input processor has no tokenizer.
            NotImplementedError: for explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init=True`"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain string prompt starts the search from an
            # empty token list; confirm callers always pass token ids.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No live beam left to expand.
                break

            # One single-token generation request per live beam.
            request_id_batch = f"{request_id}-{random_uuid()}"
            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            EngineTokensPrompt(
                                prompt_token_ids=beam.tokens,
                                multi_modal_data=beam.multi_modal_data,
                                mm_processor_kwargs=beam.mm_processor_kwargs,
                            ),
                            beam_search_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                for i, beam in enumerate(all_beams)
            ]

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Flat candidate lists across all beams. `candidate_beam_idx`
            # records which beam produced each candidate, so the mapping stays
            # correct even if a logprobs dict does not hold exactly
            # `logprobs_num` entries or a beam yields no logprobs.
            candidate_token_ids: list[int] = []
            candidate_logprobs: list[float] = []
            candidate_beam_idx: list[int] = []
            for i, result in enumerate(output):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                logprobs = step_logprobs[0]
                base_logprob = all_beams[i].cum_logprob
                candidate_token_ids.extend(logprobs.keys())
                candidate_logprobs.extend(
                    base_logprob + obj.logprob for obj in logprobs.values()
                )
                candidate_beam_idx.extend([i] * len(logprobs))

            all_beams_token_id = np.array(candidate_token_ids)
            all_beams_logprob = np.array(candidate_logprobs, dtype=float)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Positions of the EOS token among all candidates
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = candidate_beam_idx[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # EOS candidates are finished; exclude them from the top-k
                # selection below.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens: indices of the `beam_width` highest
            # cumulative logprobs. `argpartition` requires kth < len.
            num_candidates = len(all_beams_logprob)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(
                    np.negative(all_beams_logprob), beam_width
                )[:beam_width]
            else:
                topn_idx = np.arange(num_candidates)

            new_beams = []
            for idx in topn_idx:
                owner = candidate_beam_idx[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        # Beams still live after the last step compete with the completed ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
542

543
    def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
        """
        Get a Renderer instance with the provided tokenizer.
        Uses shared async tokenizer pool for efficiency.

        A new `CompletionRenderer` is built on every call; only the async
        tokenizer pool is shared with this instance.
        """
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
553

554
555
556
557
558
559
560
561
562
563
564
565
566
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.
        """
        raise NotImplementedError

567
568
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
569
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
570
571
572
573
574
575
576
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
577

578
579
580
    async def _preprocess(
        self,
        ctx: ServeContext,
581
    ) -> ErrorResponse | None:
582
583
584
585
586
587
588
589
590
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
591
    ) -> AnyResponse | ErrorResponse:
592
593
594
595
596
597
598
599
600
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
601
602
    ) -> AnyResponse | ErrorResponse:
        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
603
604
605
606
607
608
609
610
611
612
        generation = self._pipeline(ctx)

        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
613
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

634
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
635
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
636

637
638
639
640
        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
641
642
643
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
644
645
                " Please, select a smaller truncation size."
            )
646
647
        return None

648
649
650
    def _create_pooling_params(
        self,
        ctx: ServeContext,
651
    ) -> PoolingParams | ErrorResponse:
652
653
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
654
655
                "Request type does not support pooling parameters"
            )
656
657
658

        return ctx.request.to_pooling_params()

659
660
661
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Schedule the request and get the result generator.

        Submits one `engine_client.encode` call per engine prompt and stores
        the merged stream on `ctx.result_generator`.

        Returns None on success. Returns an error response when pooling
        params cannot be built, when `ctx.engine_prompts` is None, or when
        any exception is raised along the way.
        """
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own id, derived from the request id.
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
714
    ) -> ErrorResponse | None:
715
716
717
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
718
                return self.create_error_response("Engine prompts not available")
719
720

            num_prompts = len(ctx.engine_prompts)
721
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
722
723
724
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
725
                return self.create_error_response("Result generator not available")
726
727
728
729
730
731

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
732
733
                    "Failed to generate results for all prompts"
                )
734

735
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]
736
737
738
739
740
741

            return None

        except Exception as e:
            return self.create_error_response(str(e))

742
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an `ErrorResponse` from a message, error type and HTTP status.

        When `self.log_error_stack` is set, also prints the active
        exception's traceback, or the current call stack if no exception is
        being handled.
        """
        if self.log_error_stack:
            if sys.exc_info()[0] is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(
            error=ErrorInfo(message=message, type=err_type, code=status_code.value)
        )
757

758
    def create_streaming_error_response(
759
760
761
762
763
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
764
        json_str = json.dumps(
765
766
767
768
            self.create_error_response(
                message=message, err_type=err_type, status_code=status_code
            ).model_dump()
        )
769
770
        return json_str

771
    async def _check_model(
772
773
        self,
        request: AnyRequest,
774
    ) -> ErrorResponse | None:
775
776
        error_response = None

777
        if self._is_model_supported(request.model):
778
            return None
779
        if request.model in self.models.lora_requests:
780
            return None
781
782
783
784
785
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
786
787
            if isinstance(load_result, LoRARequest):
                return None
788
789
790
791
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
792
793
794
                error_response = load_result

        return error_response or self.create_error_response(
795
796
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
797
798
            status_code=HTTPStatus.NOT_FOUND,
        )
799

800
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

823
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Resolve the LoRA adapter (if any) that should serve `request`.

        Resolution order:
          1. A LoRA registered under the requested model name.
          2. When `supports_default_mm_loras` is set, the single default
             multimodal LoRA matched by `_get_active_default_mm_loras`.
          3. `None` when `_is_model_supported` accepts the model name.

        Raises:
            ValueError: If none of the above applies.
        """
        registered_loras = self.models.lora_requests
        model_name = request.model

        if model_name in registered_loras:
            return registered_loras[model_name]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if (
            supports_default_mm_loras
            and (mm_lora := self._get_active_default_mm_loras(request)) is not None
        ):
            return mm_lora

        if not self._is_model_supported(model_name):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
843

844
845
846
847
848
849
850
851
852
853
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-part types used by the request's messages.

        Each `type` value is cut at the first `_` (e.g. `audio_url` ->
        `audio`); the result is used to match potential multimodal data
        with default per-modality LoRAs.

        Args:
            request: Any request object. Requests without a `messages`
                attribute, or whose `messages` is `None` / `str` / `bytes`,
                yield an empty set.

        Returns:
            The set of type prefixes found in list-valued `content` entries
            of dict messages.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue

            content = message.get("content")
            if not isinstance(content, list):
                continue

            for part in content:
                # Only mapping-style parts carry a "type" key. A plain string
                # part must be skipped explicitly: `"type" in some_str` is a
                # substring test, and indexing the string with "type" would
                # raise TypeError.
                if not isinstance(part, Mapping):
                    continue

                part_type = part.get("type")
                # Ignore missing / non-string types instead of failing on
                # `.split`.
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])

        return message_types

869
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: TokenizerLike,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate it against the model limits.

        The prompt is lower-cased first when the model's encoder config sets
        `do_lower_case`. Truncation follows the request's optional
        `truncate_prompt_tokens` attribute:

          * `None` (or absent): no truncation.
          * negative: truncate to `self.max_model_len`.
          * otherwise: truncate to that many tokens.

        Returns:
            The result of `_validate_input` for the encoded ids and the
            (possibly lower-cased) prompt text.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        encode_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        if truncate_prompt_tokens is not None:
            encode_kwargs["truncation"] = True
            # Negative means we cap at the model's max length
            encode_kwargs["max_length"] = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )

        encoded = await async_tokenizer(prompt, **encode_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

911
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: TokenizerLike | None,
    ) -> TextTokensPrompt:
        """Truncate pre-tokenized input and validate it against model limits.

        Truncation keeps the *last* tokens and follows the request's optional
        `truncate_prompt_tokens` attribute:

          * `None` (or absent): keep all tokens.
          * negative: keep the last `self.max_model_len` tokens.
          * `0`: keep no tokens.
          * positive `k`: keep the last `k` tokens.

        Args:
            request: The originating request.
            prompt_ids: Token ids supplied by the caller.
            tokenizer: Used to decode the kept ids back to text; when `None`
                the prompt text is reported as an empty string.

        Returns:
            The result of `_validate_input` for the kept ids and decoded text.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            input_ids = prompt_ids[-self.max_model_len :]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` is `prompt_ids[0:]`, i.e. the whole list, so
            # zero has to be handled explicitly to truncate to no tokens
            # (matching `max_length=0` on the text path).
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model context window.

        Rules, by request kind:
          * Embedding / score / rerank / classification requests: the input
            may use up to `self.max_model_len` tokens.
          * Tokenize / detokenize requests: no length validation.
          * Everything else: the input must be shorter than
            `self.max_model_len`, and input plus the requested `max_tokens`
            (when set) must not exceed it.

        Returns:
            A `TextTokensPrompt` built from `input_text` and `input_ids`.

        Raises:
            ValueError: If the applicable limit is exceeded.
        """
        num_input_tokens = len(input_ids)
        max_len = self.max_model_len

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if num_input_tokens <= max_len:
                return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

            operation_names: dict[type[AnyRequest], str] = {
                ScoreRequest: "score",
                ClassificationCompletionRequest: "classification",
                ClassificationChatRequest: "classification",
            }
            # Exact-type lookup; any other pooling request is reported as
            # embedding generation.
            operation = operation_names.get(type(request), "embedding generation")
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_len} tokens. However, you requested "
                f"{num_input_tokens} tokens in the input for {operation}. "
                f"Please reduce the length of the input."
            )

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenization_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenization_request_types):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        # TODO(#9845): remove max_tokens when field dropped from OpenAI API
        max_tokens = (
            request.max_completion_tokens or request.max_tokens
            if isinstance(request, ChatCompletionRequest)
            else getattr(request, "max_tokens", None)
        )

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if num_input_tokens >= max_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_len} tokens. However, your request has "
                f"{num_input_tokens} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and num_input_tokens + max_tokens > max_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_len} tokens and your request has "
                f"{num_input_tokens} input tokens ({max_tokens} > {max_len}"
                f" - {num_input_tokens})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

1008
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_input: str | list[int],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Tokenize a single prompt input by delegating to
        `_tokenize_prompt_inputs_async` with a one-element list and
        returning the first result it yields.

        Raises:
            ValueError: If the underlying generator yields nothing.
        """
        single_input_results = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for tokenized in single_input_results:
            # Only the first yielded item is needed.
            return tokenized

        raise ValueError("No results yielded from tokenization")
1026

1027
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        Tokenize multiple prompt inputs, yielding one result per input in
        order. String inputs go through `_normalize_prompt_text_to_input`;
        anything else is treated as token ids and goes through
        `_normalize_prompt_tokens_to_input`.
        """
        for item in prompt_inputs:
            if not isinstance(item, str):
                normalized = await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=item,
                    tokenizer=tokenizer,
                )
            else:
                normalized = await self._normalize_prompt_text_to_input(
                    request,
                    prompt=item,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            yield normalized

1052
1053
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as request-supplied when `request_chat_template`
        is set, or when `chat_template_kwargs` carries a non-`None`
        `"chat_template"` entry.

        Returns:
            An error response when an untrusted template is present,
            otherwise `None`.
        """
        if trust_request_chat_template:
            return None

        has_request_template = request_chat_template is not None
        if not has_request_template and chat_template_kwargs:
            has_request_template = (
                chat_template_kwargs.get("chat_template") is not None
            )

        if not has_request_template:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

1072
1073
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: TokenizerLike | None,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a prompt and build the engine input.

        Steps:
          1. Resolve the content format and parse `messages` into a
             conversation plus a future for the multimodal data.
          2. Apply the chat template (Mistral or HF path depending on the
             tokenizer type). Entries in `chat_template_kwargs` override
             the defaults assembled here.
          3. Optionally let the tool parser adjust the request.
          4. Tokenize the rendered prompt (or decode it, when the template
             already produced token ids) and assemble the engine prompt,
             attaching multimodal data/uuids, `mm_processor_kwargs` and
             `cache_salt` when present.

        Returns:
            A tuple `(conversation, [request_prompt], [engine_prompt])`.

        Raises:
            ValueError: If `tokenizer` is `None`.
            NotImplementedError: If tool parsing is requested for a request
                that is neither a `ChatCompletionRequest` nor a
                `ResponsesRequest`.
        """
        # A tokenizer is mandatory from here on, so no later branch needs
        # to handle `tokenizer is None`.
        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

1199
1200
1201
1202
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        A fresh tokenization-kwargs dict is handed to
        `_validate_truncation_size` together with `self.max_model_len` and
        `params.truncate_prompt_tokens`, then forwarded to
        `self.input_processor.process_inputs`.

        Returns:
            A tuple of the processed engine request and the
            tokenization-kwargs dict.
        """
        tok_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = params.truncate_prompt_tokens
        _validate_truncation_size(
            self.max_model_len, truncate_prompt_tokens, tok_kwargs
        )

        processed_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return processed_request, tok_kwargs

1226
1227
1228
1229
1230
1231
1232
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Run generation turns until no built-in tool call is requested.

        Each turn is submitted under the id `"{request_id}_{turn}"`. Every
        engine output is appended to `context`, which is then yielded. When
        the context asks for a built-in tool call, the tool is invoked, its
        output appended, and a new turn is started from the prompt token
        ids rendered by the context.

        Note that `sampling_params.max_tokens` is updated in place for
        follow-up turns, and that `prompt_text` is derived once from the
        initial `request_prompt`.
        """
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        initial_priority = priority
        # `kwargs` is not modified below, so this lookup is the same for
        # every turn.
        trace_headers = kwargs.get("trace_headers")

        turn = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            turn_request_id = f"{request_id}_{turn}"

            self._log_inputs(
                turn_request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                turn_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            outputs = self.engine_client.generate(
                engine_request,
                sampling_params,
                turn_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn from the re-rendered token ids.
            next_prompt_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=next_prompt_ids)
            request_prompt = next_prompt_ids
            # Update the sampling params: the remaining budget is whatever
            # is left of the context window after the new prompt.
            sampling_params.max_tokens = self.max_model_len - len(next_prompt_ids)
            # OPTIMIZATION: follow-up turns use the original priority minus 1.
            priority = initial_priority - 1
            turn += 1
1296

1297
1298
    def _get_prompt_components(
        self,
        prompt: RequestPrompt | PromptType,
    ) -> PromptComponents:
        """Split a prompt into its components.

        A `list` prompt is treated as token ids and wrapped directly; any
        other prompt is delegated to `get_prompt_components`.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]

        return PromptComponents(token_ids=prompt)
1305

1306
1307
1308
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt | PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward the request inputs to the request logger, if one is set.

        The prompt is decomposed with `_get_prompt_components` into text,
        token ids and embeddings before being logged. Does nothing when
        `self.request_logger` is `None`.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        text, token_ids, embeds = self._get_prompt_components(inputs)

        request_logger.log_inputs(
            request_id,
            text,
            token_ids,
            embeds,
            params=params,
            lora_request=lora_request,
        )
1326

1327
1328
1329
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when the engine has tracing enabled.

        Returns:
            The extracted trace headers when tracing is enabled; otherwise
            `None`. In the disabled case a warning is logged if the request
            nevertheless carries trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: let the user know their trace headers are ignored.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1341
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pick the request id: the `X-Request-Id` header when present,
        else `default`, else a freshly generated random UUID."""
        if raw_request is not None:
            header_id = raw_request.headers.get("X-Request-Id")
            if header_id is not None:
                return header_id

        if default is None:
            return random_uuid()
        return default
1352

1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the data parallel rank from the `X-data-parallel-rank`
        header.

        Returns:
            The header value as an `int`, or `None` when there is no
            request, no such header, or the value is not a valid integer.
        """
        header_value = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if header_value is None:
            return None

        try:
            rank = int(header_value)
        except ValueError:
            # A malformed header is treated the same as a missing one.
            return None
        return rank

1368
1369
1370
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated content into function calls based on `tool_choice`.

        Handling, in order of precedence:
          * Named tool choice (`ToolChoiceFunction` or
            `ChatCompletionNamedToolChoiceParam`): the whole content becomes
            the arguments of that function.
          * `"required"`: the content is parsed as a JSON list of function
            definitions, one call per entry.
          * `"auto"` / unset, with auto tools enabled and a parser class
            given: the tool parser extracts the calls.
          * Anything else: no parsing is attempted.

        Returns:
            `(function_calls, content)`. `function_calls` is `None` when the
            tool parser ran and found no calls, and an empty list when no
            parsing was attempted. `content` is the remaining text, or
            `None` when it was consumed by tool calls or is blank after
            tool extraction.

        Raises:
            AssertionError: If a forced/required tool choice is used with
                `content=None`.
            RuntimeError: Re-raised if the tool parser cannot be created.
        """
        function_calls = list[FunctionCall]()
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            assert content is not None
            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            function_calls.extend(
                FunctionCall(
                    name=tool_call.name,
                    arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                )
                for tool_call in tool_calls
            )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError:
                logger.exception("Error in tool parser creation.")
                raise
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                # Normalize blank leftover text to None. The check must not
                # rely on truthiness, otherwise an empty string would be
                # kept while whitespace-only text is dropped.
                if content is not None and content.strip() == "":
                    content = None
            else:
                # No tool calls.
                return None, content

        return function_calls, content

1439
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the textual form of a token for logprob reporting.

        Preference order: the `token_id:<id>` form when
        `return_as_token_id` is set, then the token already decoded on
        `logprob`, then a fresh `tokenizer.decode(token_id)`.

        Raises:
            ValueError: If decoding is needed but no tokenizer is available.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached_text = logprob.decoded_token
        if cached_text is not None:
            return cached_text

        if tokenizer is not None:
            return tokenizer.decode(token_id)

        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
1458

1459
    def _is_model_supported(self, model_name: str | None) -> bool:
1460
1461
        if not model_name:
            return True
1462
        return self.models.is_base_model(model_name)
1463

1464
1465

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace `-inf` prompt logprobs with `-9999.0`, in place.

    `None` input and `None` entries are left untouched. The same
    (mutated) object that was passed in is returned.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    for position_logprobs in prompt_logprobs:
        if position_logprobs is not None:
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs