serving.py 56.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.engine.protocol import EngineClient
24
25
26
27
28
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
29
from vllm.entrypoints.logger import RequestLogger
30
from vllm.entrypoints.openai.chat_completion.protocol import (
31
    ChatCompletionNamedToolChoiceParam,
32
33
    ChatCompletionRequest,
    ChatCompletionResponse,
34
)
35
from vllm.entrypoints.openai.completion.protocol import (
36
37
    CompletionRequest,
    CompletionResponse,
38
39
)
from vllm.entrypoints.openai.engine.protocol import (
40
41
    ErrorInfo,
    ErrorResponse,
42
    FunctionCall,
43
    FunctionDefinition,
44
)
45
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
46
47
48
49
50
51
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
52
53
54
55
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
56
57
58
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
59
60
61
62
63
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
78
79
    PoolingChatRequest,
    PoolingCompletionRequest,
80
81
82
83
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
84
85
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
86
87
    ScoreRequest,
    ScoreResponse,
88
    ScoreTextRequest,
89
)
90
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
91
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
92
93
94
95
96
97
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
98
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
99
from vllm.exceptions import VLLMValidationError
100
from vllm.inputs.data import PromptType, TokensPrompt
101
102
103
104
105
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
106
from vllm.logger import init_logger
107
from vllm.logprobs import Logprob, PromptLogprobs
108
from vllm.lora.request import LoRARequest
109
from vllm.multimodal import MultiModalDataDict
110
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
111
from vllm.pooling_params import PoolingParams
112
from vllm.reasoning import ReasoningParser, ReasoningParserManager
113
from vllm.renderers import RendererLike
114
from vllm.sampling_params import BeamSearchParams, SamplingParams
115
from vllm.tokenizers import TokenizerLike
116
from vllm.tool_parsers import ToolParser, ToolParserManager
117
118
119
120
121
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
122
from vllm.utils import random_uuid
123
from vllm.utils.async_utils import (
124
    AsyncMicrobatchTokenizer,
125
    collect_from_async_generator,
126
127
    merge_async_iterators,
)
128
from vllm.v1.engine import EngineCoreRequest
129

130
131
132
133
134
135
136
137
138

class GenerationError(Exception):
    """Signals that generation finished with ``finish_reason == "error"``.

    The instance carries ``status_code`` (HTTP 500) so callers can turn it
    into an internal-server-error response.
    """

    def __init__(self, message: str = "Internal server error"):
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


139
140
logger = init_logger(__name__)

# Completion-style request types (prompt given directly rather than as chat
# messages), grouped under one alias for annotations.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Chat-style request types, grouped under one alias for annotations.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)

SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving classes in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every successful (non-error) response type produced by those classes.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)

# Type parameter used to specialise `ServeContext` per request type.
RequestT = TypeVar("RequestT", bound=AnyRequest)


185
186
@dataclass(kw_only=True)
class RequestProcessingMixin:
    """Dataclass mixin holding the prompts prepared for the engine.

    ``engine_prompts`` defaults to a fresh empty list per instance. The
    annotation also admits ``None``, which the pipeline stages in this module
    report as "Engine prompts not available".
    """

    # Tokenized prompts; each one is submitted as a separate engine request.
    engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
193
194


195
196
@dataclass(kw_only=True)
class ResponseGenerationMixin:
    """Dataclass mixin holding the engine-output state of a request.

    ``result_generator`` is the merged async stream of
    ``(prompt_index, output)`` pairs, and ``final_res_batch`` holds the
    outputs once that stream has been consumed.
    """

    # Merged stream of (prompt index, engine output); None until scheduled.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = None
    # Collected outputs, one per prompt; a fresh list per instance.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
        default_factory=list
    )

    # NOTE(review): pydantic-style config. On a stdlib dataclass this is an
    # unannotated class attribute (not a field) and does not affect the
    # generated __init__/__repr__/__eq__; kept in case something reads it.
    model_config = ConfigDict(arbitrary_types_allowed=True)


212
213
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
    """Per-request state threaded through the serving pipeline.

    Adds the request and bookkeeping fields to the prompt/output fields
    inherited from the two mixins. Every field is keyword-only, which is what
    allows required fields to follow defaulted ones.
    """

    # The parsed API request object.
    request: RequestT
    # Underlying HTTP request, if any; its headers feed trace extraction.
    raw_request: Request | None = None
    model_name: str
    # Base ID; per-prompt engine requests use "<request_id>-<index>".
    request_id: str
    # Unix time in whole seconds, captured when the context is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    # Optional LoRA adapter forwarded to the engine.
    lora_request: LoRARequest | None = None
220
221


222
223
224
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
    """Serve context specialised for classification requests (no extra fields)."""
225
226


227
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """Serve context for embedding requests, plus chat-template settings."""

    # Optional chat template string.
    chat_template: str | None = None
    # Content-format option for chat templates; required (keyword-only).
    chat_template_content_format: ChatTemplateContentFormatOption


233
class OpenAIServing:
234
235
236
237
    # NOTE(review): the triple-quoted text below is the attribute's actual
    # default *value* (a multi-line string), not a docstring. Judging by its
    # wording, subclasses are presumably expected to override it with a short
    # prefix such as "embd" -- confirm every subclass that uses it does so.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
238

239
240
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Bind the serving layer to an engine client and a model registry.

        Args:
            engine_client: Client used to submit ``generate``/``encode`` calls.
            models: Supplies the model config, renderer and input/IO
                processors, which are cached as attributes on this instance.
            request_logger: Optional ``RequestLogger``; stored as given.
            return_tokens_as_token_ids: Stored flag (not read in this part of
                the class).
            log_error_stack: When true, ``create_error_response`` prints a
                traceback or stack for every error it builds.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Lazily filled cache: one AsyncMicrobatchTokenizer per tokenizer.
        self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}

        # Shortcuts to components owned by the model registry.
        self.input_processor = models.input_processor
        self.io_processor = models.io_processor
        self.renderer = models.renderer
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len

266
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Look up the tool parser registered under ``tool_parser_name``.

        Returns ``None`` unless auto tool choice is enabled *and* a parser
        name was supplied. Any failure during the lookup (including the
        Llama-3.2 check) is re-raised as ``TypeError``.
        """
        if not (enable_auto_tools and tool_parser_name is not None):
            return None

        logger.info('"auto" tool choice has been enabled.')

        try:
            if tool_parser_name == "pythonic":
                if self.model_config.model.startswith("meta-llama/Llama-3.2"):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic tool calls"
                    )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Look up the reasoning parser registered under ``reasoning_parser_name``.

        Args:
            reasoning_parser_name: Registered parser name. A falsy value
                (e.g. empty string) means no reasoning parser is wanted.

        Returns:
            The parser returned by ``ReasoningParserManager``, or ``None`` when
            no name was given.

        Raises:
            TypeError: If the lookup raises or returns ``None``.
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let a `None` parser slip through silently.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

306
    async def reset_mm_cache(self) -> None:
        """Clear multimodal caches: the input processor's first, then the engine's."""
        processor = self.input_processor
        processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

310
311
312
313
314
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by issuing one single-token engine request per beam.

        Each step submits every live beam as its own request with
        ``max_tokens=1`` and ``logprobs=2 * beam_width``. Every returned
        candidate token extends its beam, and the ``beam_width`` candidates
        with the highest cumulative logprob survive. Unless ``ignore_eos`` is
        set, candidates that emit EOS are moved to the completed set instead.
        After ``max_tokens`` steps, the completed and live beams are sorted
        with the key from ``create_sort_beams_key_function(eos_token_id,
        length_penalty)``. The best ``beam_width`` of them are returned.

        Args:
            prompt: A text prompt or a prompt dict. For dicts, ``prompt``,
                ``prompt_token_ids`` and ``multi_modal_data`` are read.
            request_id: ID of the yielded output; per-beam engine request IDs
                are derived from it.
            params: Beam width, max tokens, EOS handling, temperature,
                length penalty and ``include_stop_str_in_output``.
            lora_request: Optional LoRA adapter used for every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            Exactly one finished ``RequestOutput``. If any engine step ends
            with ``finish_reason == "error"``, it is a single error output and
            the search stops.

        Raises:
            VLLMValidationError: If no tokenizer is available.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain-text prompt is not tokenized here, so the
            # search starts from an empty token sequence. Callers presumably
            # pass token prompts -- confirm.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.
        # NOTE(review): the preprocessing this comment refers to is not
        # visible in this method; every beam simply carries
        # `mm_processor_kwargs=None` -- confirm the comment is still accurate.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Twice as many candidates per beam as are kept, presumably so enough
        # non-EOS candidates remain after EOS candidates are set aside.
        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No candidate survived the previous step; nothing to expand
                # (and `zip(*[])` below would not unpack into two names).
                break

            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        TokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            # One single-token request per live beam, run concurrently.
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Parallel flat lists, one entry per (beam, candidate token):
            # the token id, its cumulative logprob, and the index (into
            # `all_beams` / `output`) of the beam it extends. The source index
            # is tracked explicitly because the number of candidates per beam
            # is not fixed. The engine may return `logprobs_num + 1` entries
            # (the sampled token in addition to the top ones), and a beam
            # without logprobs contributes none. So `idx // logprobs_num` is
            # not a reliable way to recover the beam.
            all_beams_token_id: list[int] = []
            all_beams_logprob: list[float] = []
            all_beams_source: list[int] = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(logprobs.keys())
                    all_beams_logprob.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    all_beams_source.extend([i] * len(logprobs))

            token_ids_arr = np.array(all_beams_token_id)
            cum_logprobs_arr = np.array(all_beams_logprob)

            if not ignore_eos:
                # Candidates that emit EOS finish here instead of expanding.
                eos_idx = np.where(token_ids_arr == eos_token_id)[0]
                for idx in eos_idx:
                    src = all_beams_source[idx]
                    current_beam = all_beams[src]
                    result = output[src]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(cum_logprobs_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # Exclude finished candidates from the selection below.
                cum_logprobs_arr[eos_idx] = -np.inf

            # Indices of the `beam_width` best candidates (in no particular
            # order). `argpartition` requires kth < len, so when there are
            # that few candidates, all of them are kept.
            num_candidates = len(cum_logprobs_arr)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(cum_logprobs_arr), beam_width)[
                    :beam_width
                ]
            else:
                topn_idx = np.arange(num_candidates)

            for idx in topn_idx:
                src = all_beams_source[idx]
                current_beam = all_beams[src]
                result = output[src]
                token_id = int(token_ids_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(cum_logprobs_arr[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
548

549
    def _get_completion_renderer(self) -> BaseRenderer:
        """Build a ``CompletionRenderer`` for this server's model.

        A new renderer object is created per call, but every renderer shares
        ``self._async_tokenizer_pool``, so async tokenizers are reused.
        """
        shared_pool = self._async_tokenizer_pool
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=self.renderer.tokenizer,
            async_tokenizer_pool=shared_pool,
        )
559

560
561
562
563
564
565
566
567
568
569
570
571
572
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the ``RenderConfig`` the renderer should use for ``request``.

        Abstract hook with no base implementation. Endpoint subclasses
        override it to decide how prompts are prepared for their request type
        (e.g. tokenization and length handling).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

573
574
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Return the pooled ``AsyncMicrobatchTokenizer`` wrapping ``tokenizer``.

        The wrapper is created on first use and memoized in
        ``self._async_tokenizer_pool`` (keyed by the tokenizer object), so
        later calls with the same tokenizer return the same instance.
        """
        pool = self._async_tokenizer_pool
        cached = pool.get(tokenizer)
        if cached is not None:
            return cached
        cached = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = cached
        return cached
583

584
585
586
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook for preparing ``ctx`` before engine requests are scheduled.

        The base version does nothing and returns ``None`` (success).
        Subclasses (classification, embedding, ...) override it and may return
        an ``ErrorResponse``, which the pipeline then reports.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Hook that turns the collected results in ``ctx`` into a response.

        Subclasses override this. The base class serves no endpoint, so it
        returns an "unimplemented endpoint" error response.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for ``ctx`` and return its first response.

        If the pipeline yields nothing, an error response is returned instead.
        """
        pipeline: AsyncGenerator[AnyResponse | ErrorResponse, None]
        pipeline = self._pipeline(ctx)

        async for first_response in pipeline:
            return first_response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline, yielding one response.

        Stages run in order: model check, request validation,
        ``_preprocess``, ``_prepare_generators``, ``_collect_batch`` and
        finally ``_build_response``. The first stage that reports an error has
        its ``ErrorResponse`` yielded, and the pipeline stops there. Later
        stages never run on an invalid request, and at most one item is
        yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

640
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``.

        Requests without that attribute, or with it set to ``None``, pass.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None or limit <= self.max_model_len:
            return None
        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size."
        )

654
655
656
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive ``PoolingParams`` from the request via ``to_pooling_params``.

        Returns an error response when the request type has no such method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()
        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

665
666
667
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Submit one ``encode`` request per engine prompt in ``ctx``.

        On success the per-prompt result streams are merged into
        ``ctx.result_generator`` and ``None`` is returned. Every failure,
        including exceptions, comes back as an ``ErrorResponse``.
        """
        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            generators: list[
                AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
            ] = []
            for index, engine_prompt in enumerate(engine_prompts):
                # Each prompt becomes its own engine request "<id>-<index>".
                item_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generators.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*generators)
        except Exception as e:
            return self.create_error_response(e)

        return None
715
716
717
718

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain ``ctx.result_generator`` into ``ctx.final_res_batch``.

        Results are placed by the prompt index reported with each item. An
        error response is returned if prompts or the generator are missing,
        if any prompt produced no result, or if an exception occurs.
        """
        try:
            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            slots: list[RequestOutput | PoolingRequestOutput | None]
            slots = [None] * len(engine_prompts)

            result_generator = ctx.result_generator
            if result_generator is None:
                return self.create_error_response("Result generator not available")

            async for prompt_index, output in result_generator:
                slots[prompt_index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
        except Exception as e:
            return self.create_error_response(e)

        return None
746

747
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an OpenAI-style ``ErrorResponse``.

        When ``message`` is a string, ``err_type``, ``status_code`` and
        ``param`` are used as given. When it is an exception, those arguments
        are ignored and derived from the exception type:

        * ``VLLMValidationError`` -> 400, ``param`` taken from the exception.
        * ``NotImplementedError`` -> 501.
        * ``ValueError`` / ``TypeError`` / ``RuntimeError`` /
          ``OverflowError`` -> 400.
        * a class named ``TemplateError`` (jinja2) -> 400.
        * anything else -> 500.

        If ``log_error_stack`` is set, the active traceback is printed (or the
        current stack when no exception is being handled). The message is
        passed through ``sanitize_message`` before it is returned.
        """
        if isinstance(message, Exception):
            exc = message
            param = None

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, NotImplementedError):
                # Must be tested before the RuntimeError group below:
                # NotImplementedError subclasses RuntimeError, so with the
                # previous ordering this branch was unreachable and such
                # errors were reported as 400.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input.
                # NOTE(review): RuntimeError is also used for internal
                # failures; mapping it to 400 may mislabel server errors.
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR

            message = str(exc)

        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
801

802
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Like ``create_error_response``, but serialized to a JSON string.

        Intended for streaming responses, where the error is sent as text.
        """
        error_response = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        return json.dumps(error_response.model_dump())

819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when ``finish_reason`` is "error".

        Any other finish reason (including ``None``) is a no-op.
        """
        if finish_reason != "error":
            return
        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Translate a ``GenerationError`` into a (non-streaming) ErrorResponse.

        The error is always typed ``InternalServerError`` and carries the
        status code stored on the exception.
        """
        status = e.status_code
        return self.create_error_response(
            str(e), err_type="InternalServerError", status_code=status
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Translate a ``GenerationError`` into a JSON error string for streams.

        Mirrors ``_convert_generation_error_to_response`` but goes through
        ``create_streaming_error_response``, so the result is a string.
        """
        status = e.status_code
        return self.create_streaming_error_response(
            str(e), err_type="InternalServerError", status_code=status
        )

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Verify that ``request.model`` names something this server can serve.

        Returns ``None`` when the name is empty or a supported base model, a
        registered LoRA adapter, or, with runtime LoRA updating enabled, an
        adapter that ``models.resolve_lora`` returns as a ``LoRARequest``.

        Otherwise an ErrorResponse is returned. If the resolver itself
        reported a ``400 Bad Request``, that error is returned as-is.
        Otherwise a ``404 NotFoundError`` is returned for the ``model`` param.
        """
        model_name = request.model
        if (
            self._is_model_supported(model_name)
            or model_name in self.models.lora_requests
        ):
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                # Only surface resolver errors that are client errors (400).
                if (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                ):
                    resolver_error = load_result

        return resolver_error or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Pick the default multimodal LoRA adapter that applies to ``request``.

        An adapter matches when its ``lora_name`` equals one of the content
        type prefixes returned by ``_get_message_types`` (e.g. ``audio_url``
        -> ``audio``). An adapter is returned only when exactly one matches.
        Zero or several matches yield ``None``.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        present_types = self._get_message_types(request)

        # Best-effort match of adapter names against content-part type
        # prefixes; there is probably a better way to do this.
        candidates = {
            adapter
            for adapter in self.models.lora_requests.values()
            if adapter.lora_name in present_types
        }

        # Ambiguity (more than one match) is treated the same as no match.
        if len(candidates) != 1:
            return None
        return candidates.pop()

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Return the LoRA adapter to apply to ``request``, if any.

        Resolution order:
        1. an adapter registered under ``request.model``;
        2. if ``supports_default_mm_loras``, the single default multimodal
           adapter matching the request's content types;
        3. ``None`` when ``request.model`` is a supported base model.

        Raises:
            ValueError: if none of the above applies (normally prevented by
                calling ``_check_model`` first).
        """
        lora_requests = self.models.lora_requests
        if request.model in lora_requests:
            return lora_requests[request.model]

        if supports_default_mm_loras:
            # Only used when exactly one modality-specific adapter matches.
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(request.model):
            # Unreachable if _check_model has already validated the request.
            raise ValueError(f"The model `{request.model}` does not exist.")
        return None

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the type prefixes of the request's message content parts.

        For every message whose ``content`` is a list, each content part's
        ``type`` is cut at the first ``_`` (e.g. ``"audio_url"`` ->
        ``"audio"``) and added to the result. The set is used to match the
        request against default per-modality LoRA adapters.

        Requests without ``messages``, or whose ``messages`` is ``None``,
        ``str`` or ``bytes``, yield an empty set. Content parts that are not
        dicts, or whose ``type`` is not a string, are skipped rather than
        raising.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for content_part in content:
                # A bare string part would otherwise be probed with a
                # substring check and then subscripted (TypeError), and a
                # non-str ``type`` would fail on ``.split``.
                if not isinstance(content_part, dict):
                    continue
                part_type = content_part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: TokenizerLike,
        add_special_tokens: bool,
    ) -> TokensPrompt:
        """Tokenize a text prompt and validate the resulting length.

        The text is lower-cased first when the model's encoder config sets
        ``do_lower_case``. If ``request`` carries ``truncate_prompt_tokens``,
        the tokenizer is asked to truncate to that many tokens. A negative
        value means ``max_model_len``. Otherwise no truncation is requested.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
        if truncate_prompt_tokens is not None:
            tokenizer_kwargs["truncation"] = True
            # Negative means "cap at the model's max length".
            tokenizer_kwargs["max_length"] = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )

        encoded = await async_tokenizer(prompt, **tokenizer_kwargs)
        return self._validate_input(request, encoded.input_ids, prompt)

    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: TokenizerLike | None,
    ) -> TokensPrompt:
        """Optionally truncate pre-tokenized input, then validate it.

        ``truncate_prompt_tokens`` (read from ``request`` when present) keeps
        the *last* N token ids. A negative value keeps the last
        ``max_model_len`` ids. The prompt text is decoded from the kept ids
        when a tokenizer is available, otherwise it is left empty.
        """
        limit = getattr(request, "truncate_prompt_tokens", None)
        if limit is None:
            input_ids = prompt_ids
        else:
            keep = self.max_model_len if limit < 0 else limit
            # NOTE(review): a limit of 0 slices as ``[-0:]`` and keeps every id.
            input_ids = prompt_ids[-keep:]

        if tokenizer is not None:
            input_text = await self._get_async_tokenizer(tokenizer).decode(input_ids)
        else:
            input_text = ""

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against the model's context window.

        * Pooling-style requests (embedding, score, rerank, classification)
          do not generate tokens, so any length up to ``max_model_len`` is
          accepted.
        * Tokenize/detokenize requests are not length-checked at all.
        * All other requests generate tokens. The prompt may hold at most
          ``max_model_len - 1`` tokens. If a ``max_tokens`` /
          ``max_completion_tokens`` budget is set, prompt plus budget must
          not exceed ``max_model_len``.

        Returns:
            A ``TokensPrompt`` carrying ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: if any of the limits above is exceeded.
        """
        num_tokens = len(input_ids)
        context_len = self.max_model_len

        # Request types without max_tokens that never generate output.
        non_generating_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, non_generating_types):
            if num_tokens > context_len:
                # Exact-type lookup; unlisted types (embeddings, rerank) fall
                # back to "embedding generation".
                operation_by_type = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }
                operation = operation_by_type.get(
                    type(request), "embedding generation"
                )
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{context_len} tokens. However, you requested "
                    f"{num_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=num_tokens,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Tokenize/detokenize requests have no max_tokens and need no
        # context-length validation.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        if num_tokens >= context_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{context_len} tokens. However, your request has "
                f"{num_tokens} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=num_tokens,
            )

        if max_tokens is not None and num_tokens + max_tokens > context_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{context_len} tokens and your request has "
                f"{num_tokens} input tokens ({max_tokens} > {context_len}"
                f" - {num_tokens}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_input: str | list[int],
        add_special_tokens: bool = True,
    ) -> TokensPrompt:
        """Tokenize and validate a single prompt (text or token ids).

        Delegates to ``_tokenize_prompt_inputs_async`` with a one-element
        list and returns the first prompt it yields.

        Raises:
            ValueError: if the delegate yields nothing.
        """
        results = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for tokens_prompt in results:
            return tokens_prompt
        raise ValueError("No results yielded from tokenization")

    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: TokenizerLike,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TokensPrompt, None]:
        """Tokenize and validate several prompts, yielding one result each.

        String inputs are tokenized as text. Anything else is treated as a
        list of token ids and goes through the token-id path, which ignores
        ``add_special_tokens``.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt_input,
                    tokenizer=tokenizer,
                )
                continue

            yield await self._normalize_prompt_text_to_input(
                request,
                prompt=prompt_input,
                tokenizer=tokenizer,
                add_special_tokens=add_special_tokens,
            )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as supplied if ``request_chat_template`` is set or
        ``chat_template_kwargs`` contains a non-``None`` ``"chat_template"``.
        Returns an error response in that case when
        ``trust_request_chat_template`` is false. Otherwise returns ``None``.
        """
        if trust_request_chat_template:
            return None

        template_in_kwargs = bool(
            chat_template_kwargs
            and chat_template_kwargs.get("chat_template") is not None
        )
        if request_chat_template is None and not template_in_kwargs:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1172
1173
    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        renderer: RendererLike,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
        """Render chat messages into a single validated engine prompt.

        Steps:
        1. Build the template kwargs from the explicit arguments plus
           ``chat_template_kwargs``. Then merge in the server-level
           ``default_chat_template_kwargs``, where request values win.
        2. Render ``messages`` with ``renderer.render_messages_async``.
        3. If the renderer returned text only, tokenize it here, which also
           validates its length, and merge back the renderer's other keys.
           Otherwise validate the renderer's token ids directly.
        4. Copy ``mm_processor_kwargs`` / ``cache_salt`` from the request.
        5. If tool parsing applies, let the tool parser adjust the request.

        Returns:
            The rendered conversation and a one-element list with the prompt.

        Raises:
            NotImplementedError: if tool parsing applies to a request that
                is neither a ChatCompletionRequest nor a ResponsesRequest.
            VLLMValidationError: propagated from prompt length validation.
        """
        chat_template_kwargs = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
            **(chat_template_kwargs or {}),
        }
        chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
            chat_template_kwargs,
            default_chat_template_kwargs,
        )

        # Use the async tokenizer in `OpenAIServing` if possible.
        # Later we can move it into the renderer so that we can return both
        # text and token IDs in the same prompt from `render_messages_async`
        # which is used for logging and `enable_response_messages`.
        from vllm.tokenizers.mistral import MistralTokenizer

        # Keyword arguments are evaluated left to right, so "tokenize" is
        # popped out of chat_template_kwargs before the dict is unpacked.
        conversation, engine_prompt = await renderer.render_messages_async(
            messages,
            chat_template_content_format=chat_template_content_format,
            tokenize=(
                chat_template_kwargs.pop("tokenize", False)
                or isinstance(renderer.tokenizer, MistralTokenizer)
            ),
            **chat_template_kwargs,
        )

        if "prompt_token_ids" not in engine_prompt:
            # The renderer produced text only: tokenize (and validate) here.
            extra_data = engine_prompt
            engine_prompt = await self._tokenize_prompt_input_async(
                request,
                renderer.get_tokenizer(),
                engine_prompt["prompt"],
                add_special_tokens=add_special_tokens,
            )

            # Fill in other keys like MM data
            engine_prompt.update(extra_data)  # type: ignore
        else:
            # The renderer already tokenized; only the length check remains.
            self._validate_input(
                request=request,
                input_ids=engine_prompt["prompt_token_ids"],  # type: ignore
                input_text="",
            )

        engine_prompt = cast(TokensPrompt, engine_prompt)

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
        if (cache_salt := getattr(request, "cache_salt", None)) is not None:
            engine_prompt["cache_salt"] = cache_salt

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)

            tokenizer = renderer.get_tokenizer()
            # NOTE(review): the result is only rebound to the local name
            # `request`; any effect on the caller presumably relies on
            # adjust_request mutating the request in place. Confirm.
            request = tool_parser(tokenizer).adjust_request(request=request)  # type: ignore

        return conversation, [engine_prompt]

    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
        data_parallel_rank: int | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Turn a prompt plus params into an engine request via the processor.

        Returns the engine request together with the tokenization kwargs
        that were passed to ``input_processor.process_inputs``.
        """
        tokenization_kwargs: dict[str, Any] = {}
        # The dict is handed to _validate_truncation_size and then returned,
        # so it presumably gets the truncation settings filled in there.
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
        )

        engine_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
        return engine_request, tokenization_kwargs

    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        renderer: RendererLike,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompts for the next turn of a Responses request.

        The response items are first converted to chat messages with
        ``construct_input_messages`` and then passed through
        ``_preprocess_chat``; only the prompts are returned, not the
        conversation.
        """
        chat_messages = construct_input_messages(request_input=messages)

        _conversation, prompts = await self._preprocess_chat(
            request,
            renderer,
            chat_messages,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
        )
        return prompts

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate a response, running built-in tool calls between turns.

        This is an async generator. Each turn:
        1. logs and processes ``engine_prompt`` under a unique sub-request id
           (``f"{request_id}_{n}"``);
        2. streams engine outputs into ``context``, yielding ``context`` after
           every output;
        3. stops if ``context.need_builtin_tool_call()`` is false. Otherwise it
           calls the tool, appends the tool output and renders the next prompt.

        Side effects: ``sampling_params.max_tokens`` is overwritten in place
        before every follow-up turn, and follow-up turns run with
        ``priority - 1``.
        """
        prompt_text, _, _ = self._get_prompt_components(engine_prompt)

        orig_priority = priority
        sub_request = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"
            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            trace_headers = kwargs.get("trace_headers")
            engine_request, tokenization_kwargs = await self._process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                prompt_token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
                # NOTE(review): prompt_text is not refreshed on this path and
                # still holds the first turn's text. Confirm this is intended.
            elif isinstance(context, ParsableContext):
                engine_prompts = await self._render_next_turn(
                    context.request,
                    context.renderer,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
                prompt_text, _, _ = self._get_prompt_components(engine_prompt)
            # NOTE(review): any other context type re-submits the previous
            # engine_prompt unchanged.

            # Update the sampling params: the next turn may use whatever
            # context remains after the new prompt.
            # NOTE(review): nothing here guards against this being <= 0.
            sampling_params.max_tokens = self.max_model_len - len(
                engine_prompt["prompt_token_ids"]
            )

            # OPTIMIZATION: presumably a lower value favors follow-up turns
            # under priority scheduling. Confirm the scheduler's ordering.
            priority = orig_priority - 1
            sub_request += 1

    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
        """Split ``prompt`` into its (text, token ids, embeds) components.

        Instance-level wrapper around the module-level
        ``get_prompt_components`` helper.
        """
        components = get_prompt_components(prompt)
        return components

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Send the prompt components of ``inputs`` to the request logger.

        Does nothing when no request logger is configured.
        """
        request_logger = self.request_logger
        if request_logger is not None:
            text, token_ids, embeds = self._get_prompt_components(inputs)
            request_logger.log_inputs(
                request_id,
                text,
                token_ids,
                embeds,
                params=params,
                lora_request=lora_request,
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to propagate, if tracing is enabled.

        When tracing is disabled but ``headers`` still contain trace headers,
        ``log_tracing_disabled_warning`` is called. In both disabled cases
        ``None`` is returned.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Choose the request id for a call.

        Precedence:
        1. the ``X-Request-Id`` header of ``raw_request``, if the request
           and header are present;
        2. ``default``;
        3. a freshly generated ``random_uuid()``.
        """
        if raw_request is not None:
            header_value = raw_request.headers.get("X-Request-Id")
            if header_value is not None:
                return header_value
        if default is not None:
            return default
        return random_uuid()

    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the requested data-parallel rank from the request headers.

        Returns the integer value of the ``X-data-parallel-rank`` header.
        Returns ``None`` when there is no request, no such header, or the
        value does not parse as an integer.
        """
        header_value = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if header_value is None:
            return None

        try:
            return int(header_value)
        except ValueError:
            return None

    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content``.

        Handling depends on ``request.tool_choice``:
        * a named function (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole content becomes
          the arguments of that one call;
        * ``"required"``: content is parsed as a JSON list of function
          definitions, each becoming one call;
        * ``"auto"`` / ``None`` with auto tools enabled and a parser class:
          the tool parser extracts calls from the content.

        In the first two cases, and when the auto parser finds calls, the
        returned content is the leftover text (``None`` if nothing is left).

        Returns:
            ``(function_calls, content)``. ``function_calls`` is ``None`` only
            when the auto parser found no tool calls. It is an empty list when
            no branch applies (e.g. ``tool_choice == "none"``).

        Raises:
            ValueError: if auto parsing is needed but ``tokenizer`` is None.
            RuntimeError: re-raised (after logging) if the tool parser
                cannot be constructed.
        """
        function_calls = list[FunctionCall]()
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            assert content is not None
            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            function_calls.extend(
                [
                    FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                    )
                    for tool_call in tool_calls
                ]
            )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError:
                logger.exception("Error in tool parser creation.")
                # Bare raise re-raises the active exception unchanged.
                raise
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                # Whitespace-only leftover text is reported as no content.
                if content and content.strip() == "":
                    content = None
            else:
                # No tool calls.
                return None, content

        return function_calls, content

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the text for ``token_id`` as used in logprob outputs.

        With ``return_as_token_id`` the placeholder ``"token_id:<id>"`` is
        returned. Otherwise the already-decoded text on ``logprob`` is
        preferred, and only if that is missing is the tokenizer used.

        Raises:
            ValueError: if decoding is needed but ``tokenizer`` is None.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached_text = logprob.decoded_token
        if cached_text is not None:
            return cached_text

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )
        return tokenizer.decode(token_id)

    def _is_model_supported(self, model_name: str | None) -> bool:
1564
1565
        if not model_name:
            return True
1566
        return self.models.is_base_model(model_name)
1567

1568
1569

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0`` in place.

    ``None`` input and ``None`` per-position entries are left alone. The same
    (mutated) object is returned. The finite sentinel presumably keeps the
    values serializable as strict JSON.
    """
    if prompt_logprobs is not None:
        negative_infinity = float("-inf")
        for position in prompt_logprobs:
            if position is None:
                continue
            for entry in position.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs