serving.py 35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Mapping
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
15
16
from openai.types.responses import (
    ToolChoiceFunction,
)
17
from pydantic import ConfigDict, TypeAdapter, ValidationError
18
from starlette.datastructures import Headers
19

20
import vllm.envs as envs
21
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
22
from vllm.config import ModelConfig
23
from vllm.engine.protocol import EngineClient
24
25
26
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
)
27
from vllm.entrypoints.logger import RequestLogger
28
from vllm.entrypoints.openai.chat_completion.protocol import (
29
    ChatCompletionNamedToolChoiceParam,
30
31
    ChatCompletionRequest,
    ChatCompletionResponse,
32
)
33
from vllm.entrypoints.openai.completion.protocol import (
34
35
    CompletionRequest,
    CompletionResponse,
36
37
)
from vllm.entrypoints.openai.engine.protocol import (
38
    ErrorResponse,
39
    FunctionCall,
40
    FunctionDefinition,
41
    GenerationError,
42
)
43
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
44
45
46
from vllm.entrypoints.openai.responses.protocol import (
    ResponsesRequest,
)
47
from vllm.entrypoints.openai.speech_to_text.protocol import (
48
49
50
51
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
52
53
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
54
55
    PoolingChatRequest,
    PoolingCompletionRequest,
56
57
58
59
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
60
61
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
62
63
    ScoreRequest,
    ScoreResponse,
64
    ScoreTextRequest,
65
)
66
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
67
68
69
70
71
72
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
73
from vllm.entrypoints.utils import create_error_response
74
from vllm.exceptions import VLLMValidationError
75
from vllm.inputs import EngineInput, PromptType, TokensPrompt
76
from vllm.logger import init_logger
77
from vllm.logprobs import Logprob, PromptLogprobs
78
from vllm.lora.request import LoRARequest
79
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
80
from vllm.pooling_params import PoolingParams
81
from vllm.renderers import ChatParams, TokenizeParams
82
83
84
85
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
86
from vllm.sampling_params import BeamSearchParams, SamplingParams
87
from vllm.tokenizers import TokenizerLike
88
from vllm.tool_parsers import ToolParser
89
90
91
92
93
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
94
from vllm.utils import random_uuid
95
from vllm.utils.async_utils import (
96
    collect_from_async_generator,
97
98
    merge_async_iterators,
)
99
100
101

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

102
103
104
105
106
107
108
109
110
111
112
113
114
115
116

class RendererRequest(Protocol):
    """Structural interface for requests that can describe their tokenization.

    Any object providing a compatible ``build_tok_params`` satisfies it.
    """

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the ``TokenizeParams`` for this request under ``model_config``."""
        raise NotImplementedError()

class RendererChatRequest(RendererRequest, Protocol):
    """``RendererRequest`` that can additionally build chat-template params."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the ``ChatParams`` for this request.

        Args:
            default_template: Server-side default chat template, if any.
            default_template_content_format: Server-side default content
                format for the chat template.
        """
        raise NotImplementedError()


117
118
# Request unions used to type the serving handlers. Member order is kept as-is
# because union argument order can matter to consumers that inspect it.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | TokenizeChatRequest | PoolingChatRequest
)

SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by ``OpenAIServing`` methods typed ``AnyRequest``.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Non-error response types a handler's ``_build_response`` may return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)

# Type parameter for the request carried by ``ServeContext``.
RequestT = TypeVar("RequestT", bound=AnyRequest)


154
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state passed through ``OpenAIServing._pipeline``.

    The pipeline stages read and fill these fields in turn:
    ``engine_inputs`` must be set before ``_prepare_generators``, which stores
    ``result_generator``; ``_collect_batch`` then drains that generator of
    ``(index, output)`` pairs into ``final_res_batch``.
    """

    request: RequestT
    raw_request: Request | None = None

    # kw_only=True lets these required fields follow defaulted ones.
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds, taken when the context is built.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_inputs: list[EngineInput] | None = None

    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): ``ConfigDict`` is a pydantic setting and has no effect on a
    # stdlib dataclass. Without an annotation this is a plain class attribute,
    # not a dataclass field. Kept for backward compatibility; confirm nothing
    # reads it before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)
170
171


172
class OpenAIServing:
173
    # Class-level prefix for request IDs.
    # NOTE(review): the base-class *value* below is the descriptive text
    # itself (it reads like a docstring assigned as the value). Concrete
    # subclasses presumably override it with a real prefix; confirm before
    # relying on the base value, which contains newlines and indentation.
    request_id_prefix: ClassVar[str] = """
174
    A short string prepended to every request’s ID.
175
    """
176

177
178
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        """Store the engine client and serving options.

        Args:
            engine_client: Engine client. Its ``model_config``, ``renderer``,
                ``io_processor`` and ``input_processor`` are cached on the
                instance.
            models: Served-model registry, used here for LoRA lookups
                (``lora_requests``, ``resolve_lora``).
            request_logger: Optional input logger; ``None`` disables
                ``_log_inputs``.
            return_tokens_as_token_ids: Stored as-is for response builders
                (not read in this class section).
        """
        super().__init__()

        # Caller-supplied collaborators and options.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Shortcuts to engine-client components used throughout the class.
        client = self.engine_client
        self.model_config = client.model_config
        self.renderer = client.renderer
        self.io_processor = client.io_processor
        self.input_processor = client.input_processor
198
199
200

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by extending every live beam one token at a time.

        Each step submits every live beam to the engine with ``max_tokens=1``
        and ``logprobs=2 * beam_width``. It pools the returned candidates
        across beams and moves EOS candidates to ``completed`` (unless
        ``ignore_eos``). The best ``beam_width`` remaining candidates become
        the next set of live beams.

        Args:
            prompt: Engine input to search from. For encoder-decoder inputs
                the decoder prompt is the starting sequence.
            request_id: Base request id; per-beam engine request ids are
                derived from it.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter applied to every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            One finished ``RequestOutput`` with the best ``beam_width`` beams,
            or an error output if any engine step finished with ``"error"``.

        Raises:
            NotImplementedError: If ``prompt`` is an embedding prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every remaining candidate ended in EOS: nothing to extend.
                break

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Flattened candidate pool across all beams. ``candidate_beam_idx``
            # records which beam each candidate extends. The number of logprob
            # entries per beam is not guaranteed to equal ``logprobs_num`` (and
            # is zero when logprobs is None), so the owning beam cannot be
            # recovered as ``idx // logprobs_num``.
            candidate_token_ids: list[int] = []
            candidate_logprobs: list[float] = []
            candidate_beam_idx: list[int] = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        candidate_token_ids.append(token_id)
                        candidate_logprobs.append(
                            current_beam.cum_logprob + logprob_obj.logprob
                        )
                        candidate_beam_idx.append(i)

            all_beams_token_id = np.array(candidate_token_ids)
            all_beams_logprob = np.array(candidate_logprobs, dtype=np.float64)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = candidate_beam_idx[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            lora_request=current_beam.lora_request,
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # Mask EOS candidates so they are not extended as live beams.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens: keep the best ``beam_width`` live
            # candidates. Masked (-inf) candidates are excluded, since
            # extending them would keep decoding past EOS.
            live_idx = np.flatnonzero(np.isfinite(all_beams_logprob))
            if len(live_idx) > beam_width:
                best = np.argpartition(-all_beams_logprob[live_idx], beam_width)[
                    :beam_width
                ]
                topn_idx = live_idx[best]
            else:
                # argpartition requires kth < len; with this few candidates
                # every one of them survives.
                topn_idx = live_idx

            new_beams = []
            for idx in topn_idx:
                beam_idx = candidate_beam_idx[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
397

398
399
400
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Prepare ``ctx`` before generators are scheduled; a no-op by default.

        Subclasses override this to fill in request state, e.g.
        ``ctx.engine_inputs``, which ``_prepare_generators`` requires.
        Returning an ``ErrorResponse`` reports a failure to the pipeline.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn the collected results in ``ctx`` into the endpoint's response.

        Subclasses are expected to override this. The base version always
        answers with an "unimplemented endpoint" error.
        """
        message = "unimplemented endpoint"
        return self.create_error_response(message)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the request pipeline and return the first response it yields.

        The pipeline generator is wrapped in ``contextlib.aclosing`` so it is
        closed as soon as the first response is taken. Otherwise it would be
        left suspended until the event loop's async-generator finalizer
        reclaims it.

        Returns:
            The first yielded response, or an error if the pipeline yields
            nothing.
        """
        async with contextlib.aclosing(self._pipeline(ctx)) as pipeline:
            async for response in pipeline:
                return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation,
        ``_preprocess``, ``_prepare_generators``, ``_collect_batch`` and
        ``_build_response``. The first stage that reports an
        ``ErrorResponse`` yields it and ends the pipeline, so later stages
        never run on a rejected request even if the caller keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

450
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject requests whose ``truncate_prompt_tokens`` exceeds the context.

        Requests without that attribute, or with it unset, always pass.
        """
        truncation = getattr(ctx.request, "truncate_prompt_tokens", None)
        if truncation is None or not truncation > self.model_config.max_model_len:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size."
        )

464
465
466
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive pooling parameters from the request.

        Returns an ``ErrorResponse`` when the request type has no
        ``to_pooling_params`` method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

475
476
477
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        Submits one ``engine_client.encode`` call per entry in
        ``ctx.engine_inputs``, using request ids ``"<request_id>-<index>"``.
        The calls are merged into ``ctx.result_generator``.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` if pooling params
            cannot be built or ``ctx.engine_inputs`` is missing.
        """
        raw_request = ctx.raw_request
        if raw_request is None:
            trace_headers = None
        else:
            trace_headers = await self._get_trace_headers(raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        engine_inputs = ctx.engine_inputs
        if engine_inputs is None:
            return self.create_error_response("Engine prompts not available")

        # Same for every sub-request, so read it once.
        priority = getattr(ctx.request, "priority", 0)
        lora_request = ctx.lora_request

        streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        for index, engine_input in enumerate(engine_inputs):
            sub_request_id = f"{ctx.request_id}-{index}"
            self._log_inputs(
                sub_request_id,
                engine_input,
                params=pooling_params,
                lora_request=lora_request,
            )
            streams.append(
                self.engine_client.encode(
                    engine_input,
                    pooling_params,
                    sub_request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                )
            )

        ctx.result_generator = merge_async_iterators(*streams)
        return None
519
520
521
522

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains ``ctx.result_generator`` of ``(index, output)`` pairs into
        ``ctx.final_res_batch``, ordered by prompt index.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` if inputs or the
            generator are missing, or if any prompt produced no result.
        """
        engine_inputs = ctx.engine_inputs
        if engine_inputs is None:
            return self.create_error_response("Engine prompts not available")

        if ctx.result_generator is None:
            return self.create_error_response("Result generator not available")

        slots: list[PoolingRequestOutput | None] = [None] * len(engine_inputs)
        async for position, output in ctx.result_generator:
            slots[position] = output

        if None in slots:
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )

        ctx.final_res_batch = [item for item in slots if item is not None]
        return None
546

547
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` (default: 400 ``BadRequestError``).

        Thin wrapper over the module-level ``create_error_response`` helper
        imported from ``vllm.entrypoints.utils``.
        """
        response = create_error_response(message, err_type, status_code, param)
        return response
555

556
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Build an error response and serialize it to a JSON string.

        Same arguments as ``create_error_response``. The result is
        ``json.dumps`` of the response's ``model_dump()``.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        payload = error.model_dump()
        return json.dumps(payload)

573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when ``finish_reason`` is "error".

        Any other finish reason, including ``None``, is a no-op.
        """
        if finish_reason != "error":
            return

        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Serialize ``e`` as a streaming (JSON string) error response.

        Uses the error's message and ``status_code`` with the
        ``InternalServerError`` type.
        """
        text = str(e)
        return self.create_streaming_error_response(
            text,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

592
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Return ``None`` when ``request.model`` can be served, else an error.

        A model is accepted if it is a supported base model or a registered
        LoRA adapter. It is also accepted, when
        ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is set, if ``resolve_lora``
        yields a ``LoRARequest``. A 400 ``ErrorResponse`` from
        ``resolve_lora`` is returned as-is. Every other case becomes a 404
        ``NotFoundError`` on ``param="model"``.
        """
        model_name = request.model
        if self._is_model_supported(model_name) or (
            model_name in self.models.lora_requests
        ):
            return None

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            resolved = await self.models.resolve_lora(model_name)
            if resolved:
                if isinstance(resolved, LoRARequest):
                    return None
                # Surface the resolver's own bad-request error verbatim.
                if (
                    isinstance(resolved, ErrorResponse)
                    and resolved.error.code == HTTPStatus.BAD_REQUEST.value
                ):
                    return resolved

        return self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
621

622
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the single default multimodal LoRA matching the request.

        An adapter matches when its ``lora_name`` equals one of the content
        type prefixes found by ``_get_message_types`` (e.g. ``audio_url`` ->
        ``audio``). A default is chosen only if exactly one adapter matches;
        otherwise ``None`` is returned.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best-effort name match; there is probably a better way to identify
        # default multimodal adapters.
        candidates = {
            adapter
            for adapter in self.models.lora_requests.values()
            if adapter.lora_name in message_types
        }

        # Ambiguous (several) or absent matches yield no default adapter.
        return candidates.pop() if len(candidates) == 1 else None

645
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter to use for ``request``, if any.

        Lookup order:
        1. an adapter registered under ``request.model``;
        2. if ``supports_default_mm_loras``, a unique default multimodal
           adapter;
        3. ``None`` for a supported base model.

        Raises:
            ValueError: If the model is unknown (unreachable when
                ``_check_model`` has already accepted the request).
        """
        adapters = self.models.lora_requests
        if request.model in adapters:
            return adapters[request.model]

        if supports_default_mm_loras:
            mm_adapter = self._get_active_default_mm_loras(request)
            if mm_adapter is not None:
                return mm_adapter

        if not self._is_model_supported(request.model):
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
665

666
667
668
669
670
671
672
673
674
675
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose ``content`` is a list are inspected. Within
        those, only dict content parts with a string ``type`` contribute.
        Anything else (plain-string content, non-dict parts, non-string
        ``type`` values) is skipped rather than raising.

        Returns:
            The prefixes of the content-part ``type`` values before the first
            ``_``, e.g. ``"audio_url"`` -> ``"audio"``. Empty if the request
            has no usable ``messages``.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for content_part in content:
                if not isinstance(content_part, dict):
                    # e.g. a bare string part: `"type" in part` would be a
                    # substring test and `part["type"]` would then raise.
                    continue
                part_type = content_part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

691
692
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length against the model context and wrap it.

        - Score/rerank requests may use the whole context window.
        - Tokenize/detokenize requests are not length-checked.
        - Every other request must leave at least one free position. If it
          sets an output budget (``max_completion_tokens`` or ``max_tokens``
          for chat, ``max_tokens`` otherwise), prompt plus budget must fit.

        Returns:
            A ``TokensPrompt`` carrying ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: If a length limit above is exceeded.
        """
        num_tokens = len(input_ids)
        context_len = self.model_config.max_model_len
        score_types = (
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
        )

        # Note: ScoreRequest doesn't have max_tokens
        if isinstance(request, (*score_types, RerankRequest)):
            # These requests generate no tokens, so the input may fill the
            # entire model context.
            if num_tokens > context_len:
                # Exact-type match: anything else (e.g. RerankRequest) is
                # reported as embedding generation.
                operation = (
                    "score" if type(request) in score_types else "embedding generation"
                )
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{context_len} tokens. However, you requested "
                    f"{num_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=num_tokens,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Tokenize/detokenize requests carry no max_tokens and skip the
        # context-length check entirely.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            output_budget = request.max_completion_tokens or request.max_tokens
        else:
            output_budget = getattr(request, "max_tokens", None)

        # Completion-like requests may use at most max_model_len - 1 input
        # tokens.
        if num_tokens >= context_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{context_len} tokens. However, your request has "
                f"{num_tokens} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=num_tokens,
            )

        if output_budget is None:
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        total = num_tokens + output_budget
        if total > context_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{context_len} tokens. However, you requested "
                f"{output_budget} output tokens and your prompt contains "
                f"{num_tokens} input tokens, for a total of "
                f"{total} tokens "
                f"({num_tokens} + {output_budget} = "
                f"{total} > {context_len}). "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="max_tokens",
                value=output_budget,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
772

773
774
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Refuse request-supplied chat templates unless they are trusted.

        A template counts as supplied if ``request_chat_template`` is set or
        ``chat_template_kwargs`` contains a non-None ``"chat_template"``.
        """
        if trust_request_chat_template:
            return None

        kwargs_template = (chat_template_kwargs or {}).get("chat_template")
        if request_chat_template is None and kwargs_template is None:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

793
794
795
796
797
798
799
800
801
802
803
804
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Combine server-default and per-request chat template kwargs.

        Request values take precedence over server defaults. When there are
        no defaults, the request kwargs (or a fresh empty dict) are returned
        unchanged.
        """
        overrides = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is not None:
            # Defaults first, so request entries win on key collisions.
            return default_chat_template_kwargs | overrides
        return overrides

805
    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        """Split ``prompt`` into its components using this server's model config.

        Thin wrapper around the module-level ``extract_prompt_components``
        helper. The returned object is whatever that helper produces; within
        this class its ``text``, ``token_ids`` and ``embeds`` attributes are
        read.
        """
        return extract_prompt_components(self.model_config, prompt)

808
    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
809
810
        return self._extract_prompt_components(prompt).text

811
    def _extract_prompt_len(self, prompt: EngineInput):
        """Return the prompt length computed by ``extract_prompt_len``.

        Passes this server's ``model_config`` to the module-level helper. How
        the length is measured is defined by that helper, which is not visible
        here.
        """
        return extract_prompt_len(self.model_config, prompt)

814
815
816
    def _log_inputs(
        self,
        request_id: str,
817
        inputs: PromptType | EngineInput,
818
819
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
820
821
822
    ) -> None:
        if self.request_logger is None:
            return
823

824
        components = self._extract_prompt_components(inputs)
825
826
827

        self.request_logger.log_inputs(
            request_id,
828
829
830
            components.text,
            components.token_ids,
            components.embeds,
831
832
833
            params=params,
            lora_request=lora_request,
        )
834

835
836
837
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to forward, or None when tracing is off.

        Awaits ``self.engine_client.is_tracing_enabled()``. If it is enabled,
        returns ``extract_trace_headers(headers)``. Otherwise returns None,
        first calling ``log_tracing_disabled_warning()`` when
        ``contains_trace_headers(headers)`` is truthy.
        """
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        # Presumably warns that incoming trace context is being ignored.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

849
    @staticmethod
850
    def _base_request_id(
851
852
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
853
        """Pulls the request id to use from a header, if provided"""
854
855
856
857
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
858

859
        return random_uuid() if default is None else default
860

861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

876
877
878
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Extract function calls from generated ``content`` per ``tool_choice``.

        Branches on ``request.tool_choice``:

        * ``ToolChoiceFunction`` or ``ChatCompletionNamedToolChoiceParam``
          (a forced, named tool): the whole ``content`` becomes the arguments
          of a single call to that tool. The returned content is None.
          ``content`` is asserted to be non-None.
        * ``"required"``: ``content`` is validated as a JSON list of
          ``FunctionDefinition``. Each entry becomes a call whose arguments
          are its ``parameters`` serialized with ``json.dumps``. Content that
          fails validation yields no calls. The returned content is None.
        * ``"auto"`` or None, when both ``enable_auto_tools`` and
          ``tool_parser_cls`` are set: a parser instance extracts the calls.
          If it reports no tool calls, ``(None, content)`` is returned.
          Otherwise the parser's remaining content is returned, with
          non-empty, whitespace-only content replaced by None.
        * Any other case: ``([], content)`` with ``content`` unchanged.

        Raises:
            ValueError: Automatic parsing applies but ``tokenizer`` is None.
            RuntimeError: Re-raised after logging if ``tool_parser_cls``
                cannot be instantiated.
        """
        function_calls = list[FunctionCall]()
        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            # Forced Function Call
            function_calls.append(
                FunctionCall(name=request.tool_choice.function.name, arguments=content)
            )
            content = None  # Clear content since tool is called.
        elif request.tool_choice == "required":
            tool_calls = []
            # Content that does not validate is tolerated and yields no calls.
            with contextlib.suppress(ValidationError):
                content = content or ""
                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content
                )
            for tool_call in tool_calls:
                function_calls.append(
                    FunctionCall(
                        name=tool_call.name,
                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
                    )
                )
            content = None  # Clear content since tool is called.
        elif (
            tool_parser_cls
            and enable_auto_tools
            and (request.tool_choice == "auto" or request.tool_choice is None)
        ):
            if tokenizer is None:
                raise ValueError(
                    "Tokenizer not available when `skip_tokenizer_init=True`"
                )

            # Automatic Tool Call Parsing
            try:
                tool_parser = tool_parser_cls(tokenizer)
            except RuntimeError:
                logger.exception("Error in tool parser creation.")
                # Bare `raise` re-raises the active exception untouched.
                raise
            tool_call_info = tool_parser.extract_tool_calls(
                content if content is not None else "",
                request=request,  # type: ignore
            )
            if tool_call_info is not None and tool_call_info.tools_called:
                # extract_tool_calls() returns a list of tool calls.
                function_calls.extend(
                    FunctionCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                    for tool_call in tool_call_info.tool_calls
                )
                content = tool_call_info.content
                # Drop non-empty leftovers that are only whitespace.
                if content and content.strip() == "":
                    content = None
            else:
                # No tool calls.
                return None, content

        return function_calls, content

955
    @staticmethod
956
957
958
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
959
        tokenizer: TokenizerLike | None,
960
961
        return_as_token_id: bool = False,
    ) -> str:
962
963
964
        if return_as_token_id:
            return f"token_id:{token_id}"

965
966
        if logprob.decoded_token is not None:
            return logprob.decoded_token
967
968
969
970
971
972

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

973
        return tokenizer.decode([token_id])
974

975
    def _is_model_supported(self, model_name: str | None) -> bool:
976
977
        if not model_name:
            return True
978
        return self.models.is_base_model(model_name)
979

980
981

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0`` in place.

    ``prompt_logprobs`` is iterated as a sequence whose entries are either
    None, which is skipped, or a mapping whose values carry a ``logprob``
    attribute. Every ``logprob`` equal to ``float("-inf")`` is overwritten
    with ``-9999.0``. The same object is returned, and None passes through
    unchanged.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    # NOTE(review): presumably -inf cannot be emitted in the JSON responses
    # built from these values, so a large finite negative stands in for it.
    # TODO confirm.
    neg_inf = float("-inf")
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == neg_inf:
                logprob_values.logprob = -9999.0
    return prompt_logprobs