"vllm/vscode:/vscode.git/clone" did not exist on "a263aa614060f8e6be52ed3de9995450d6c02892"
serving.py 35.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import contextlib
5
import json
6
import time
7
from collections.abc import AsyncGenerator, Mapping
8
from dataclasses import dataclass, field
9
from http import HTTPStatus
10
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
11

12
import numpy as np
13
from fastapi import Request
14
15
16
from openai.types.responses import (
    ToolChoiceFunction,
)
17
from pydantic import ConfigDict, TypeAdapter, ValidationError
18
from starlette.datastructures import Headers
19

20
import vllm.envs as envs
21
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
22
from vllm.config import ModelConfig
23
from vllm.engine.protocol import EngineClient
24
25
26
from vllm.entrypoints.chat_utils import (
    ChatTemplateContentFormatOption,
)
27
from vllm.entrypoints.logger import RequestLogger
28
from vllm.entrypoints.openai.chat_completion.protocol import (
29
    BatchChatCompletionRequest,
30
    ChatCompletionNamedToolChoiceParam,
31
32
    ChatCompletionRequest,
    ChatCompletionResponse,
33
)
34
from vllm.entrypoints.openai.completion.protocol import (
35
36
    CompletionRequest,
    CompletionResponse,
37
38
)
from vllm.entrypoints.openai.engine.protocol import (
39
    ErrorResponse,
40
    FunctionCall,
41
    FunctionDefinition,
42
    GenerationError,
43
)
44
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
45
46
47
from vllm.entrypoints.openai.responses.protocol import (
    ResponsesRequest,
)
48
from vllm.entrypoints.openai.speech_to_text.protocol import (
49
50
51
52
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
53
54
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
55
56
    PoolingChatRequest,
    PoolingCompletionRequest,
57
58
59
60
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
61
62
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
63
64
    ScoreRequest,
    ScoreResponse,
65
    ScoreTextRequest,
66
)
67
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
68
69
70
71
72
73
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
74
from vllm.entrypoints.utils import create_error_response
75
from vllm.exceptions import VLLMValidationError
76
from vllm.inputs import EngineInput, PromptType, TokensPrompt
77
from vllm.logger import init_logger
78
from vllm.logprobs import Logprob, PromptLogprobs
79
from vllm.lora.request import LoRARequest
80
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
81
from vllm.pooling_params import PoolingParams
82
from vllm.renderers import ChatParams, TokenizeParams
83
84
85
86
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
)
87
from vllm.sampling_params import BeamSearchParams, SamplingParams
88
from vllm.tokenizers import TokenizerLike
89
from vllm.tool_parsers import ToolParser
90
91
92
93
94
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
95
from vllm.utils import random_uuid
96
from vllm.utils.async_utils import (
97
    collect_from_async_generator,
98
99
    merge_async_iterators,
)
100
101
102

logger = init_logger(__name__)

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

class RendererRequest(Protocol):
    """Structural interface for requests that can build tokenization params."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the ``TokenizeParams`` for this request under ``model_config``."""
        raise NotImplementedError()


class RendererChatRequest(RendererRequest, Protocol):
    """Renderer request that can additionally build chat-template params."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the ``ChatParams`` for this request given server defaults."""
        raise NotImplementedError()


118
119
# Union of the completion-style (non-chat) request types.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Union of the chat-style request types.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | BatchChatCompletionRequest
    | TokenizeChatRequest
    | PoolingChatRequest
)

# Audio transcription / translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer (used e.g. by
# `_check_model` and as the bound of `RequestT`).
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type the serving layer can return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ScoreResponse
    | GenerateResponse
)

# Request type parameter for `ServeContext`, bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


158
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state passed through the `OpenAIServing` pipeline stages."""

    # The parsed request being served.
    request: RequestT
    # The originating HTTP request, if any; its headers feed trace extraction.
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Creation time as whole seconds since the epoch.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Inputs to submit to the engine. `_prepare_generators` and `_collect_batch`
    # return an error when this is still None; presumably a `_preprocess`
    # override is responsible for filling it — confirm in subclasses.
    engine_inputs: list[EngineInput] | None = None

    # Merged stream of (prompt index, output) pairs, set by `_prepare_generators`.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Outputs ordered by prompt index, filled by `_collect_batch`.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): this is a pydantic-style config on a stdlib dataclass. As an
    # unannotated class attribute it is not a dataclass field and appears to have
    # no effect here — confirm nothing reads it before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)
174
175


176
class OpenAIServing:
177
    # NOTE(review): the value assigned below reads like documentation of the
    # attribute rather than a usable prefix; presumably concrete subclasses
    # override it — confirm before relying on the base-class value.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID.
    """
180

181
182
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
    ):
        """Wire up the engine client, model registry and logging options.

        The model config, renderer, IO processor and input processor are read
        from ``engine_client`` once and kept as attributes for quick access.
        """
        super().__init__()

        # Collaborators and options supplied by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        # Shortcuts to components exposed by the engine client.
        client = engine_client
        self.model_config = client.model_config
        self.renderer = client.renderer
        self.io_processor = client.io_processor
        self.input_processor = client.input_processor
202
203
204

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by issuing one single-token request per live beam.

        Each step submits every live beam to the engine with ``max_tokens=1``
        and ``logprobs=2 * beam_width``. It then pools the returned candidate
        tokens of all beams and keeps the ``beam_width`` candidates with the
        highest cumulative log-probability as the next live beams. Unless
        ``ignore_eos`` is set, EOS candidates are moved to the completed set
        instead of being extended.

        Args:
            prompt: Engine input to search from; embedding prompts are rejected.
            request_id: ID of the final output; per-beam sub-request IDs are
                derived from it.
            params: Beam search configuration.
            lora_request: Optional LoRA adapter used for every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            A single ``RequestOutput``. It is an error output if any sub-request
            finishes with ``finish_reason == "error"``; otherwise it holds the
            best ``beam_width`` beams.

        Raises:
            NotImplementedError: If ``prompt`` is an embeddings prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # No live beam survived the previous step; nothing to extend.
                break

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Flattened candidate pool across all beams. `cand_beam_idx[k]` is
            # the index (into `all_beams` / `output`) of the beam that candidate
            # `k` extends. It is recorded explicitly instead of being derived as
            # `k // logprobs_num`, because nothing here guarantees that every
            # beam contributes exactly `logprobs_num` entries. A beam's logprobs
            # may be None; the engine may also report the sampled token in
            # addition to the top-k — confirm.
            cand_token_ids: list[int] = []
            cand_cum_logprobs: list[float] = []
            cand_beam_idx: list[int] = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob in logprobs.items():
                        cand_token_ids.append(token_id)
                        cand_cum_logprobs.append(
                            current_beam.cum_logprob + logprob.logprob
                        )
                        cand_beam_idx.append(i)

            token_ids = np.array(cand_token_ids)
            cum_logprobs = np.array(cand_cum_logprobs, dtype=np.float64)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(token_ids == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = cand_beam_idx[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(cum_logprobs[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # EOS candidates are now completed; mask them out of the
                # live-beam selection below.
                cum_logprobs[eos_idx] = -np.inf

            # Keep the `beam_width` best remaining candidates as live beams.
            # `np.argpartition` needs kth < len(array), so when the pool is not
            # larger than `beam_width` every candidate is considered.
            if len(cum_logprobs) > beam_width:
                topn_idx = np.argpartition(np.negative(cum_logprobs), beam_width)[
                    :beam_width
                ]
            else:
                topn_idx = np.arange(len(cum_logprobs))

            new_beams = []
            for idx in topn_idx:
                cum_logprob = float(cum_logprobs[idx])
                if cum_logprob == -np.inf:
                    # Masked EOS candidate (already in `completed`) or a
                    # zero-probability token: not a viable live beam.
                    continue
                beam_idx = cand_beam_idx[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(token_ids[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=cum_logprob,
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if not ignore_eos and beam.tokens and beam.tokens[-1] == eos_token_id:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
401

402
403
404
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook for preparing ``ctx`` before generators are scheduled.

        The base implementation leaves ``ctx`` untouched and reports no error.
        Subclasses override it, e.g. to fill ``ctx.engine_inputs``, which
        later pipeline stages require.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn the collected results in ``ctx`` into the final response.

        The base class has nothing to build and answers with an
        "unimplemented endpoint" error; endpoint subclasses override this.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for ``ctx`` and return its first response.

        After the first response is taken, the pipeline generator is closed
        explicitly via ``contextlib.aclosing``. It is not left suspended until
        garbage collection finalizes it, so any cleanup in ``_pipeline``
        overrides runs promptly.

        Returns:
            The first response yielded by ``_pipeline``, or an error response
            if the pipeline yields nothing.
        """
        async with contextlib.aclosing(self._pipeline(ctx)) as responses:
            async for response in responses:
                return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection and response building. The
        first stage that reports an ``ErrorResponse`` yields it and ends the
        pipeline. Later stages therefore never run for a request that has
        already failed, even when the caller keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

454
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject a ``truncate_prompt_tokens`` larger than the model context.

        Returns an error response when the request sets
        ``truncate_prompt_tokens`` above ``max_model_len``; otherwise ``None``.
        """
        requested = getattr(ctx.request, "truncate_prompt_tokens", None)
        if requested is None:
            return None

        if requested > self.model_config.max_model_len:
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please request a smaller truncation size."
            )
        return None

468
469
470
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive pooling parameters from the request.

        Returns ``ctx.request.to_pooling_params()`` when the request type has
        that method; otherwise returns an error response.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

479
480
481
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule one encode call per engine input and merge their streams.

        Each input is logged and submitted under the sub-request ID
        ``"{ctx.request_id}-{index}"``. On success, ``ctx.result_generator``
        becomes the merged stream and ``None`` is returned. Otherwise an
        error response is returned.
        """
        raw_request = ctx.raw_request
        if raw_request is None:
            trace_headers = None
        else:
            trace_headers = await self._get_trace_headers(raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        engine_inputs = ctx.engine_inputs
        if engine_inputs is None:
            return self.create_error_response("Engine prompts not available")

        streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        for index, engine_input in enumerate(engine_inputs):
            sub_request_id = f"{ctx.request_id}-{index}"

            self._log_inputs(
                sub_request_id,
                engine_input,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            streams.append(
                self.engine_client.encode(
                    engine_input,
                    pooling_params,
                    sub_request_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )
            )

        ctx.result_generator = merge_async_iterators(*streams)
        return None
523
524
525
526

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain ``ctx.result_generator`` into ``ctx.final_res_batch``.

        Each merged item is ``(prompt index, output)``, and outputs are placed
        by index, so the batch follows prompt order. Returns an error response
        if a prerequisite is missing or any prompt produced no output.
        """
        if ctx.engine_inputs is None:
            return self.create_error_response("Engine prompts not available")

        slots: list[PoolingRequestOutput | None] = [None for _ in ctx.engine_inputs]

        stream = ctx.result_generator
        if stream is None:
            return self.create_error_response("Result generator not available")

        async for prompt_index, output in stream:
            slots[prompt_index] = output

        if None in slots:
            return self.create_error_response(
                "Failed to generate results for all prompts"
            )

        ctx.final_res_batch = [output for output in slots if output is not None]
        return None
550

551
    @staticmethod
    def create_error_response(
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` with the shared entrypoint helper.

        Static wrapper around the module-level ``create_error_response`` so it
        can be reached as ``self.create_error_response``.
        """
        response = create_error_response(message, err_type, status_code, param)
        return response
559

560
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Return an error response serialized to a JSON string for streaming."""
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        payload = error.model_dump()
        return json.dumps(payload)

577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when ``finish_reason`` is "error"."""
        if finish_reason != "error":
            return

        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Serialize ``e`` as a streaming ``InternalServerError`` payload.

        The status code is taken from ``e.status_code``.
        """
        message = str(e)
        return self.create_streaming_error_response(
            message,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

596
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Confirm that ``request.model`` is something this server can serve.

        Returns ``None`` in three cases: a supported base model, an already
        registered LoRA adapter, or (when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING``
        is enabled) an adapter that ``self.models.resolve_lora`` returns as a
        ``LoRARequest``. A resolver ``ErrorResponse`` with code ``BAD_REQUEST``
        is returned unchanged. Every other case produces a ``NotFoundError``
        on the ``model`` parameter.
        """
        model_name = request.model

        if self._is_model_supported(model_name) or (
            model_name in self.models.lora_requests
        ):
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolver_error = load_result

        if resolver_error:
            return resolver_error
        return self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
625

626
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the one default multimodal LoRA matching the request, if any.

        A registered adapter matches when its ``lora_name`` equals one of the
        content-type prefixes reported by ``_get_message_types`` (for example
        ``audio_url`` -> ``audio``). A match is returned only when exactly one
        adapter matches; otherwise the result is ``None``.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        present_types = self._get_message_types(request)

        # Best-effort name match between adapters and content-type prefixes.
        matching = {
            adapter
            for adapter in self.models.lora_requests.values()
            if adapter.lora_name in present_types
        }

        # Only an unambiguous (single) match is applied.
        if len(matching) != 1:
            return None
        (only_match,) = matching
        return only_match

649
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter to apply to ``request``.

        The checks run in this order:

        1. The adapter registered under ``request.model``.
        2. If ``supports_default_mm_loras`` is set, a single matching default
           multimodal adapter.
        3. ``None`` when ``request.model`` is a supported base model.

        Raises:
            ValueError: If ``request.model`` matches none of the above.
        """
        registered = self.models.lora_requests
        if request.model in registered:
            return registered[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")
        return None
669

670
671
672
673
674
675
676
677
678
679
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect content-part type prefixes from the request's messages.

        For every dict message whose ``content`` is a list, each dict part
        with a string ``type`` contributes the text before its first ``_``
        (e.g. ``image_url`` -> ``image``). These prefixes are matched against
        default per-modality LoRA names.

        Requests without ``messages``, or whose ``messages`` is ``None`` or a
        plain string/bytes, yield an empty set. Parts that are not dicts or
        whose ``type`` is not a string are skipped rather than raising. This
        matches the existing tolerance for non-dict messages.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                # A string part would otherwise do a substring test on
                # `"type" in part` and then fail on `part["type"]`.
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

695
696
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Validate a tokenized prompt against the model's context length.

        The rules depend on the request type:

        - Score and rerank requests may use the whole context.
        - Tokenize/detokenize requests are not length-checked.
        - All other requests must stay below ``max_model_len``. When an output
          budget is set (``max_tokens``, or ``max_completion_tokens`` for chat
          completions), prompt plus budget must also fit.

        Args:
            request: The originating request; its type selects the rules above.
            input_ids: Prompt token IDs.
            input_text: Prompt text, carried into the returned prompt.

        Returns:
            ``TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)``.

        Raises:
            VLLMValidationError: If a length limit is exceeded.
        """
        token_num = len(input_ids)
        max_model_len = self.model_config.max_model_len

        # Note: ScoreRequest doesn't have max_tokens
        if isinstance(
            request,
            (
                ScoreDataRequest,
                ScoreTextRequest,
                ScoreQueriesDocumentsRequest,
                RerankRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                # Rerank requests share this branch with score requests. Name
                # them explicitly rather than with the unrelated "embedding
                # generation" fallback. isinstance also covers subclasses,
                # which an exact type() lookup missed.
                operation = "rerank" if isinstance(request, RerankRequest) else "score"
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input prompt.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{max_tokens} output tokens and your prompt contains "
                f"{token_num} input tokens, for a total of "
                f"{token_num + max_tokens} tokens "
                f"({token_num} + {max_tokens} = "
                f"{token_num + max_tokens} > {max_model_len}). "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
776

777
778
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Refuse request-supplied chat templates unless they are trusted.

        A template counts as supplied when ``request_chat_template`` is set,
        or when ``chat_template_kwargs`` holds a non-``None``
        ``"chat_template"`` entry.
        """
        if trust_request_chat_template:
            return None

        supplied = request_chat_template is not None or bool(
            chat_template_kwargs
            and chat_template_kwargs.get("chat_template") is not None
        )
        if not supplied:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

797
798
799
800
801
802
803
804
805
806
807
808
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

809
    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
        """Break ``prompt`` into components (text, token ids, embeds) using
        this server's model config."""
        model_config = self.model_config
        return extract_prompt_components(model_config, prompt)

812
    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
        """Return only the text component of ``prompt``."""
        components = self._extract_prompt_components(prompt)
        return components.text

815
    def _extract_prompt_len(self, prompt: EngineInput):
        """Return the prompt length as computed by ``extract_prompt_len``
        for this server's model config."""
        model_config = self.model_config
        return extract_prompt_len(model_config, prompt)

818
819
820
    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | EngineInput,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Pass the prompt's text, token ids and embeds to the request logger.

        This is a no-op when ``self.request_logger`` is None. In that case
        the prompt components are not extracted at all.
        """
        request_logger = self.request_logger
        if request_logger is not None:
            components = self._extract_prompt_components(inputs)
            request_logger.log_inputs(
                request_id,
                components.text,
                components.token_ids,
                components.embeds,
                params=params,
                lora_request=lora_request,
            )
838

839
840
841
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers extracted from ``headers`` when the engine
        has tracing enabled; otherwise ``None``.

        When tracing is disabled but ``headers`` contain trace headers,
        ``log_tracing_disabled_warning()`` is called before returning None.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

853
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Choose the base request id for a request.

        The ``X-Request-Id`` header wins when present. Otherwise ``default``
        is used, and when that is also None a fresh ``random_uuid()`` is
        generated.
        """
        header_id = (
            raw_request.headers.get("X-Request-Id")
            if raw_request is not None
            else None
        )
        if header_id is not None:
            return header_id
        if default is not None:
            return default
        return random_uuid()
864

865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Read the data parallel rank from the ``X-data-parallel-rank`` header.

        Returns None when there is no request, the header is absent, or its
        value cannot be parsed by ``int()``.
        """
        header_value = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if header_value is None:
            return None

        # A malformed header value is ignored rather than rejected.
        with contextlib.suppress(ValueError):
            return int(header_value)
        return None

880
881
882
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: type[ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated ``content`` into function calls based on tool_choice.

        ``request.tool_choice`` selects the strategy:

        * a named tool (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole ``content`` is
          used as the arguments of a single call to that tool;
        * ``"required"``: ``content`` is validated as a JSON list of
          ``FunctionDefinition``; if validation fails, no calls are produced;
        * ``"auto"`` or ``None``, with ``enable_auto_tools`` set and a
          ``tool_parser_cls`` given: the parser extracts the calls.

        Returns:
            ``(calls, remaining_content)``. For the named and ``"required"``
            strategies the remaining content is always ``None``. For the auto
            strategy, ``(None, content)`` means no tool was called, and
            whitespace-only leftover content is turned into ``None``. When no
            strategy applies, ``([], content)`` is returned.

        Raises:
            ValueError: auto parsing applies but ``tokenizer`` is None.
            RuntimeError: re-raised (after logging) when the tool parser
                cannot be constructed.
        """
        choice = request.tool_choice

        if choice and isinstance(
            choice, (ToolChoiceFunction, ChatCompletionNamedToolChoiceParam)
        ):
            # Forced function call: the raw output is the argument payload.
            assert content is not None
            tool_name = (
                choice.name
                if isinstance(choice, ToolChoiceFunction)
                else choice.function.name
            )
            return [FunctionCall(name=tool_name, arguments=content)], None

        if choice == "required":
            definitions: list[FunctionDefinition] = []
            # Output that fails validation simply yields no calls.
            with contextlib.suppress(ValidationError):
                definitions = TypeAdapter(list[FunctionDefinition]).validate_json(
                    content or ""
                )
            required_calls = [
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            ]
            return required_calls, None

        auto_choice = choice == "auto" or choice is None
        if not (tool_parser_cls and enable_auto_tools and auto_choice):
            return [], content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic tool call parsing.
        try:
            parser = tool_parser_cls(tokenizer, request.tools)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        extracted = parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if extracted is None or not extracted.tools_called:
            # No tool calls: signalled with None rather than an empty list.
            return None, content

        parsed_calls = [
            FunctionCall(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
            )
            for call in extracted.tool_calls
        ]
        leftover = extracted.content
        if leftover and leftover.strip() == "":
            leftover = None
        return parsed_calls, leftover

959
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the text representation of ``token_id``.

        Resolution order: a ``"token_id:<id>"`` placeholder when
        ``return_as_token_id`` is set, then ``logprob.decoded_token`` when it
        is not None, and finally ``tokenizer.decode([token_id])``.

        Raises:
            ValueError: decoding is needed but ``tokenizer`` is None.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached_text = logprob.decoded_token
        if cached_text is not None:
            return cached_text

        if tokenizer is not None:
            return tokenizer.decode([token_id])

        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
978

979
    def _is_model_supported(self, model_name: str | None) -> bool:
980
981
        if not model_name:
            return True
982
        return self.models.is_base_model(model_name)
983

984
985

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, mutating in place.

    ``None`` input is returned as-is, and ``None`` entries inside the
    sequence are skipped. The same (mutated) object is returned.
    """
    # NOTE(review): presumably clamped so the values stay JSON-serializable
    # (-inf is not valid JSON). Confirm against the response serializer.
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    present = (entry for entry in prompt_logprobs if entry is not None)
    for candidates in present:
        for candidate in candidates.values():
            if candidate.logprob == negative_infinity:
                candidate.logprob = -9999.0
    return prompt_logprobs