serving.py 46.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.config import ModelConfig
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
30
from vllm.entrypoints.logger import RequestLogger
31
from vllm.entrypoints.openai.chat_completion.protocol import (
32
    ChatCompletionNamedToolChoiceParam,
33
34
    ChatCompletionRequest,
    ChatCompletionResponse,
35
)
36
from vllm.entrypoints.openai.completion.protocol import (
37
38
    CompletionRequest,
    CompletionResponse,
39
40
)
from vllm.entrypoints.openai.engine.protocol import (
41
42
    ErrorInfo,
    ErrorResponse,
43
    FunctionCall,
44
    FunctionDefinition,
45
)
46
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
47
48
49
50
51
52
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
53
54
55
56
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
57
58
59
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
60
from vllm.entrypoints.openai.speech_to_text.protocol import (
61
62
63
64
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
66
67
68
69
70
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
71
    EmbeddingBytesResponse,
72
73
74
75
76
77
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
78
79
    PoolingChatRequest,
    PoolingCompletionRequest,
80
81
82
83
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
84
85
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
86
87
    ScoreRequest,
    ScoreResponse,
88
    ScoreTextRequest,
89
)
90
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
91
92
93
94
95
96
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
97
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
98
from vllm.exceptions import VLLMValidationError
99
100
101
102
103
104
105
from vllm.inputs.data import (
    ProcessorInputs,
    PromptType,
    SingletonPrompt,
    TokensPrompt,
    token_inputs,
)
106
from vllm.logger import init_logger
107
from vllm.logprobs import Logprob, PromptLogprobs
108
from vllm.lora.request import LoRARequest
109
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
110
from vllm.pooling_params import PoolingParams
111
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
112
113
114
115
116
117
from vllm.renderers.inputs.preprocess import (
    extract_prompt_components,
    extract_prompt_len,
    parse_model_prompt,
    prompt_to_seq,
)
118
from vllm.sampling_params import BeamSearchParams, SamplingParams
119
from vllm.tokenizers import TokenizerLike
120
from vllm.tool_parsers import ToolParser
121
122
123
124
125
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
126
from vllm.utils import random_uuid
127
from vllm.utils.async_utils import (
128
    collect_from_async_generator,
129
130
    merge_async_iterators,
)
131
from vllm.utils.mistral import is_mistral_tokenizer
132

133
134
135
136
137
138
139
140
141

class GenerationError(Exception):
    """Signals that generation finished with an internal-error finish reason.

    The instance carries ``status_code`` (HTTP 500) so request handlers can
    translate it into a server-error response.

    Args:
        message: Human-readable description; defaults to
            ``"Internal server error"``.
    """

    def __init__(self, message: str = "Internal server error"):
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


142
143
# Module logger; `_raise_if_error` reports generation failures through it.
logger = init_logger(__name__)

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

class RendererRequest(Protocol):
    """Structural type for requests that can describe their own tokenization."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the ``TokenizeParams`` to use for this request under *model_config*."""
        raise NotImplementedError()


class RendererChatRequest(RendererRequest, Protocol):
    """``RendererRequest`` that can also describe how to apply a chat template."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the ``ChatParams`` for this request.

        Args:
            default_template: Server-side fallback chat template, if any.
            default_template_content_format: Server-side fallback content
                format for the template.
        """
        raise NotImplementedError()


159
160
# Request types grouped with the completion-style (prompt-based) APIs.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Request types grouped with the chat-style (message-based) APIs.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)

# Audio transcription / translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type this serving layer accepts.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type this serving layer can produce.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)

# Type parameter for ServeContext, restricted to the accepted request types.
RequestT = TypeVar("RequestT", bound=AnyRequest)


206
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state threaded through the ``OpenAIServing`` pipeline.

    All fields are keyword-only. ``engine_prompts`` must be populated
    (presumably by a subclass ``_preprocess``) before generators are
    prepared. ``result_generator`` is set by ``_prepare_generators`` and
    ``final_res_batch`` by ``_collect_batch``.
    """

    # The parsed API request, plus the raw HTTP request when one exists
    # (its headers are consulted for tracing).
    request: RequestT
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Creation time, truncated to whole seconds since the epoch.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None

    # Merged stream of (prompt_index, output) pairs from the engine.
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None
    # Outputs re-ordered by prompt index once collection finishes.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): pydantic-style setting on a plain dataclass. It is an
    # unannotated class attribute, so dataclasses does not treat it as a field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
222
223


224
class OpenAIServing:
225
226
227
228
    # NOTE(review): the default *value* of this attribute is the descriptive
    # text below, not a usable short prefix. Subclasses presumably override it
    # with a real tag (e.g. "embd", "classify"); confirm every subclass does
    # before relying on it in request IDs.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
229

230
231
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Bind the serving layer to an engine client and its model registry.

        Args:
            engine_client: Engine used for generation/encoding. Its model
                config, renderer, IO processor and input processor are also
                cached on ``self``.
            models: Registry of served models and LoRA adapters.
            request_logger: Optional logger for incoming requests.
            return_tokens_as_token_ids: Stored for use by subclasses.
            log_error_stack: When true, ``create_error_response`` prints the
                current traceback/stack.
        """
        super().__init__()

        # Collaborators supplied by the caller.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        # Behaviour flags.
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Shortcuts to engine-owned components.
        client = engine_client
        self.model_config = client.model_config
        self.renderer = client.renderer
        self.io_processor = client.io_processor
        self.input_processor = client.input_processor
254
255
256

    async def beam_search(
        self,
        prompt: ProcessorInputs,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for one prompt by repeatedly sampling a single token.

        Each step works as follows:

        1. Submit one ``max_tokens=1`` request per live beam, asking for
           ``2 * beam_width`` logprobs.
        2. Pool every returned candidate token across beams.
        3. Unless ``ignore_eos`` is set, move candidates equal to EOS into the
           completed set.
        4. Keep the ``beam_width`` candidates with the highest cumulative
           logprob as the next beams.

        Args:
            prompt: Processed engine input. Embedding and encoder-decoder
                prompts are rejected.
            request_id: ID of the returned output; per-beam sub-request IDs
                are derived from it.
            params: Beam search configuration.
            lora_request: Optional LoRA adapter applied to the initial beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            Exactly one ``RequestOutput``. It is either a single ``"error"``
            output, if any beam request finished with reason ``"error"``, or
            the ``beam_width`` best beams.

        Raises:
            NotImplementedError: For ``"embeds"`` or ``"enc_dec"`` prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")
        if prompt["type"] == "enc_dec":
            raise NotImplementedError(
                "Encoder-decoder prompt not supported for beam search"
            )

        prompt_text = prompt.get("prompt")
        prompt_token_ids = prompt["prompt_token_ids"]
        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, beam in enumerate(all_beams):
                prompt_item = beam.get_prompt()
                lora_request_item = beam.lora_request
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            prompt_item,
                            sampling_params,
                            request_id_item,
                            lora_request=lora_request_item,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Flattened candidate pool across every beam. These three lists
            # stay in lockstep: entry k holds a candidate token id, its
            # cumulative logprob, and the index of the beam it would extend.
            all_beams_token_id = []
            all_beams_logprob = []
            all_beams_index = []
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    # Record the owning beam explicitly rather than recovering
                    # it later as `idx // logprobs_num`. A beam can contribute
                    # a count other than `logprobs_num`: per the SamplingParams
                    # `logprobs` contract the sampled token may be returned in
                    # addition to the top-k, and a beam without logprobs
                    # contributes nothing. Either case would shift every later
                    # candidate onto the wrong beam.
                    all_beams_token_id.extend(logprobs.keys())
                    all_beams_logprob.extend(
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    )
                    all_beams_index.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            token_ids_arr = np.array(all_beams_token_id)
            logprob_arr = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(token_ids_arr == eos_token_id)[0]
                for idx in eos_idx:
                    beam_idx = all_beams_index[idx]
                    current_beam = all_beams[beam_idx]
                    result = output[beam_idx]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(logprob_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # EOS candidates are now completed; give them -inf so they are
                # not also picked as live beams below.
                logprob_arr[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            topn_idx = np.argpartition(np.negative(logprob_arr), beam_width)[
                :beam_width
            ]

            for idx in topn_idx:
                beam_idx = all_beams_index[idx]
                current_beam = all_beams[beam_idx]
                result = output[beam_idx]
                token_id = int(token_ids_arr[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(logprob_arr[idx]),
                    )
                )

            all_beams = new_beams

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
452

453
454
455
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook for preparing ``ctx`` before generators are scheduled.

        The base version does nothing and reports success with ``None``.
        Subclasses (classification, embedding, ...) override it. They may
        return an ``ErrorResponse`` to abort the pipeline.
        """
        outcome: ErrorResponse | None = None
        return outcome

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn a fully processed ``ctx`` into the endpoint's response.

        The base version has no endpoint-specific knowledge and returns an
        "unimplemented endpoint" error; subclasses override it.
        """
        reason = "unimplemented endpoint"
        return self.create_error_response(reason)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline and return the first response it yields.

        If the pipeline finishes without yielding anything, an error response
        is returned instead.
        """
        exhausted = object()
        first = await anext(aiter(self._pipeline(ctx)), exhausted)
        if first is exhausted:
            return self.create_error_response("No response yielded from pipeline")
        return first

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline, yielding exactly one response.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, then response building. The
        first stage that reports an ``ErrorResponse`` yields it and ends the
        generator. Later stages must never run against a request that has
        already failed, even if a consumer keeps iterating. If no stage
        fails, the built response is yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

506
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``.

        Requests without that attribute, or with it unset, always pass.
        Returns an ``ErrorResponse`` on failure and ``None`` otherwise.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None or limit <= self.model_config.max_model_len:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is greater than max_model_len."
            " Please, select a smaller truncation size."
        )

520
521
522
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Build ``PoolingParams`` through the request's ``to_pooling_params()``.

        Returns an ``ErrorResponse`` when the request type has no such method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()
        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

531
532
533
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Submit each engine prompt for encoding and store the merged stream.

        One ``engine_client.encode`` call is made per entry of
        ``ctx.engine_prompts``, with sub-request IDs ``"<request_id>-<i>"``.
        The resulting generators are merged into ``ctx.result_generator``.

        Returns ``None`` on success. Otherwise returns an ``ErrorResponse``:
        pooling params unavailable, ``engine_prompts`` missing, or any
        exception raised along the way (converted via
        ``create_error_response``).
        """
        try:
            if ctx.raw_request is not None:
                trace_headers = await self._get_trace_headers(ctx.raw_request.headers)
            else:
                trace_headers = None

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            priority = getattr(ctx.request, "priority", 0)
            streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
            for index, prompt in enumerate(prompts):
                sub_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    sub_request_id,
                    prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        prompt,
                        pooling_params,
                        sub_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=priority,
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
            return None
        except Exception as e:
            return self.create_error_response(e)
579
580
581
582

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain ``ctx.result_generator`` into ``ctx.final_res_batch``.

        Each streamed ``(prompt_index, output)`` pair is stored at its prompt
        index, so the final batch follows ``ctx.engine_prompts`` order.

        Returns ``None`` on success. Otherwise returns an ``ErrorResponse``:
        prompts or generator missing, some prompt produced no output, or an
        exception was raised while collecting.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            slots: list[PoolingRequestOutput | None] = [None] * len(prompts)

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response("Result generator not available")

            async for position, output in stream:
                slots[position] = output

            if any(slot is None for slot in slots):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [slot for slot in slots if slot is not None]
            return None
        except Exception as e:
            return self.create_error_response(e)
610

611
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an OpenAI-style ``ErrorResponse``.

        When *message* is a string, ``err_type``, ``status_code`` and
        ``param`` are used as given. When it is an exception, they are
        derived from the exception type and the passed-in values are ignored:

        * ``VLLMValidationError``: 400 ``BadRequestError``, with ``param``
          taken from the exception.
        * ``NotImplementedError``: 501 ``NotImplementedError``.
        * ``ValueError`` / ``TypeError`` / ``RuntimeError`` /
          ``OverflowError``: 400 ``BadRequestError``.
        * jinja2 ``TemplateError``, matched by class name: 400
          ``BadRequestError``.
        * Anything else: 500 ``InternalServerError``.

        The message text passes through ``sanitize_message``. With
        ``log_error_stack`` enabled, the exception being handled (or, if
        none, the current stack) is printed.
        """
        if isinstance(message, Exception):
            exc = message

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, NotImplementedError):
                # Must be tested before RuntimeError: NotImplementedError is a
                # RuntimeError subclass and would otherwise be reported as 400.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
665

666
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Serialize an ``ErrorResponse`` to a JSON string for streamed replies.

        All arguments are forwarded unchanged to ``create_error_response``.
        """
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        payload = error.model_dump()
        return json.dumps(payload)

683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise ``GenerationError`` when *finish_reason* is ``"error"``.

        Any other finish reason, including ``None``, is a no-op.
        """
        if finish_reason != "error":
            return

        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Map a ``GenerationError`` onto a non-streaming ``ErrorResponse``.

        The response is typed ``InternalServerError`` and uses the
        exception's own ``status_code``.
        """
        text = str(e)
        status = e.status_code
        return self.create_error_response(
            text, err_type="InternalServerError", status_code=status
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Map a ``GenerationError`` onto a JSON error payload for streaming.

        The payload is typed ``InternalServerError`` and uses the exception's
        own ``status_code``.
        """
        text = str(e)
        status = e.status_code
        return self.create_streaming_error_response(
            text, err_type="InternalServerError", status_code=status
        )

712
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Verify that ``request.model`` names something this server can serve.

        The following are accepted:

        * models recognized by ``_is_model_supported``;
        * already-registered LoRA adapters;
        * when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is enabled, adapters that
          ``models.resolve_lora`` resolves to a ``LoRARequest``.

        Returns ``None`` when the model is usable. Otherwise, if the resolver
        produced a 400 ``ErrorResponse``, that response is returned; else a
        404 ``NotFoundError``.
        """
        model_name = request.model
        if self._is_model_supported(model_name):
            return None
        if model_name in self.models.lora_requests:
            return None

        resolver_error = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            resolved = await self.models.resolve_lora(model_name)
            if resolved:
                if isinstance(resolved, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(resolved, ErrorResponse)
                    and resolved.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolver_error = resolved

        if resolver_error:
            return resolver_error
        return self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
741

742
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the single default multimodal LoRA implied by the request.

        A registered adapter matches when its ``lora_name`` equals one of the
        content-type prefixes from ``_get_message_types``, e.g. ``audio_url``
        becomes ``audio``. This is a best-effort match. An adapter is returned
        only when exactly one matches; zero or several matches give ``None``.
        """
        # TODO: Currently this is only enabled for chat completions, to stay
        # aligned with it only being enabled for .generate when run offline.
        # Supporting additional task types would be nice in the future.
        present_types = self._get_message_types(request)
        candidates = {
            adapter
            for adapter in self.models.lora_requests.values()
            if adapter.lora_name in present_types
        }

        if len(candidates) != 1:
            return None
        (only_match,) = candidates
        return only_match

765
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Resolve the LoRA adapter, if any, that should serve *request*.

        Candidates are tried in order:

        1. An adapter registered under ``request.model``.
        2. If *supports_default_mm_loras* is set, the single default
           multimodal adapter matched from the message content.
        3. ``None`` when ``request.model`` is a supported base model.

        Raises:
            ValueError: If the model is unknown. This should not happen once
                ``_check_model`` has accepted the request.
        """
        model_name = request.model
        registered = self.models.lora_requests
        if model_name in registered:
            return registered[model_name]

        if supports_default_mm_loras:
            mm_lora = self._get_active_default_mm_loras(request)
            if mm_lora is not None:
                return mm_lora

        if not self._is_model_supported(model_name):
            # Unreachable when _check_model() has already vetted the request.
            raise ValueError(f"The model `{model_name}` does not exist.")
        return None
785

786
787
788
789
790
791
792
793
794
795
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect content-part type prefixes from the request's messages.

        For every dict message whose ``content`` is a list, each content part's
        ``"type"`` value is truncated at the first ``_`` (e.g. ``image_url``
        becomes ``image``). The result is used to match default per-modality
        LoRA adapters.

        Requests without ``messages``, or whose ``messages`` is ``None``,
        ``str`` or ``bytes``, yield an empty set. Content parts that are not
        mappings, or whose ``"type"`` is not a string, are skipped.
        Previously a string part passed the ``"type" in part`` test as a
        substring match and then crashed on ``part["type"]``.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                if not isinstance(part, Mapping):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

811
812
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check a tokenized prompt against the model's context window.

        The check depends on the request kind:

        - Pooling-style requests (embedding, score, rerank, classification)
          generate no tokens, so the prompt may fill the entire context.
        - Tokenize/detokenize requests are not length-checked at all.
        - Every other request must leave room for at least one generated
          token. In addition, ``input length + max_tokens`` must not exceed
          ``max_model_len``. For chat completions, ``max_completion_tokens``
          is preferred over ``max_tokens``.

        Returns:
            A ``TokensPrompt`` holding ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: if one of the limits above is exceeded.
        """
        token_num = len(input_ids)
        max_model_len = self.model_config.max_model_len

        # Note: EmbeddingRequest, ClassificationRequest, ScoreRequest and
        # RerankRequest don't have max_tokens.
        if isinstance(
            request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreDataRequest,
                ScoreTextRequest,
                ScoreQueriesDocumentsRequest,
                RerankRequest,
                ClassificationCompletionRequest,
                ClassificationChatRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                # Keyed on the exact request type (not isinstance) so the
                # error names the endpoint the user actually called.
                operation = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    RerankRequest: "rerank",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # The chat completion endpoint supports max_completion_tokens.
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests, leaving room for one generated token.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {max_model_len}"
                f" - {token_num}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _validate_chat_template(
        self,
898
899
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
900
        trust_request_chat_template: bool,
901
    ) -> ErrorResponse | None:
902
        if not trust_request_chat_template and (
903
904
905
906
907
908
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
909
910
911
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
912
913
                "Refused request with untrusted chat template."
            )
914
915
        return None

916
917
918
919
920
921
922
923
924
925
926
927
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

928
929
930
931
932
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Render completion prompts from embeddings and/or text/token input.

        Embedding prompts, when given, come first and take priority over the
        text/token prompts. Each source is normalized with ``prompt_to_seq``,
        and the combined list is rendered by ``_preprocess_cmpl``.
        """
        prompts: list[SingletonPrompt | bytes] = []
        # Order matters: embeds take higher priority, so they are added first.
        for source in (prompt_embeds, prompt_input):
            if source is not None:
                prompts.extend(prompt_to_seq(source))

        return await self._preprocess_cmpl(request, prompts)

    async def _preprocess_cmpl(
        self,
        request: RendererRequest,
        prompts: Sequence[PromptType | bytes],
    ) -> list[ProcessorInputs]:
        """Parse a batch of completion prompts and render them.

        ``bytes`` entries are passed through unparsed. Every other prompt is
        parsed with ``parse_model_prompt`` against the model config. The
        batch is rendered with the request's tokenization params and with
        whichever of ``mm_processor_kwargs`` / ``cache_salt`` the request
        sets to a non-``None`` value.
        """
        renderer = self.renderer
        config = self.model_config

        parsed: list = []
        for item in prompts:
            if isinstance(item, bytes):
                parsed.append(item)
            else:
                parsed.append(parse_model_prompt(config, item))

        tokenization_params = request.build_tok_params(config)

        # Forward only the optional per-request extras that are actually set.
        extras: dict = {}
        for attr_name in ("mm_processor_kwargs", "cache_salt"):
            attr_value = getattr(request, attr_name, None)
            if attr_value is not None:
                extras[attr_name] = attr_value

        return await renderer.render_cmpl_async(
            parsed,
            tokenization_params,
            prompt_extras=extras,
        )

    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Apply the chat template to ``messages`` and render one engine prompt.

        ``tool_dicts`` and a ``tokenize`` flag (true for Mistral tokenizers)
        are merged into the default template kwargs. These are used as
        defaults for the request's chat params.

        When ``tool_parser`` is given and the request's ``tool_choice`` is
        not ``"none"``, the tool parser is given a chance to adjust the
        request.

        Returns:
            The rendered conversation and a one-element list with the engine
            prompt.

        Raises:
            NotImplementedError: a tool parser is active for a request that
                is neither a Chat Completions nor a Responses request.
        """
        renderer = self.renderer

        merged_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tokenization_params = request.build_tok_params(self.model_config)
        template_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(merged_template_kwargs)

        extras = {
            attr_name: attr_value
            for attr_name in ("mm_processor_kwargs", "cache_salt")
            if (attr_value := getattr(request, attr_name, None)) is not None
        }
        (conversation,), (engine_prompt,) = await renderer.render_chat_async(
            [messages],
            template_params,
            tokenization_params,
            prompt_extras=extras,
        )

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none". If tool_choice is "none" but a
        # tool_parser is set, we still skip it, to avoid parsing a tool call
        # hallucinated by the LLM.
        if tool_parser is not None and (
            getattr(request, "tool_choice", "none") != "none"
        ):
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)

            # TODO: Update adjust_request to accept ResponsesRequest
            # NOTE(review): the returned request is not used below, so this
            # relies on adjust_request updating `request` in place — confirm.
            tool_parser(renderer.get_tokenizer()).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]

    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
        """Break ``prompt`` into its components using this server's model config."""
        config = self.model_config
        return extract_prompt_components(config, prompt)

    def _extract_prompt_text(self, prompt: ProcessorInputs):
1029
1030
        return self._extract_prompt_components(prompt).text

1031
    def _extract_prompt_len(self, prompt: ProcessorInputs):
        """Return the length of ``prompt`` as computed by ``extract_prompt_len``."""
        config = self.model_config
        return extract_prompt_len(config, prompt)

    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the engine prompt(s) for the next turn of a Responses request.

        ``messages`` are converted with ``construct_input_messages`` and then
        rendered by ``_preprocess_chat``. The given template, content format
        and tools are used; no default template kwargs are applied. Only the
        engine prompts are returned; the rendered conversation is discarded.
        """
        chat_messages = construct_input_messages(request_input=messages)

        _conversation, prompts = await self._preprocess_chat(
            request,
            chat_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return prompts

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: ProcessorInputs,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Run generation turns, executing built-in tool calls between turns.

        Each turn goes to the engine as a sub-request with id
        ``f"{request_id}_{turn}"``. Every engine output is appended to
        ``context``, and ``context`` is yielded after each append.

        After a turn, if ``context.need_builtin_tool_call()`` is true:
        1. the tool is called and its output appended;
        2. the next prompt is rendered and ``sampling_params.max_tokens`` is
           recomputed (the object is mutated in place);
        3. the next turn runs.
        Otherwise the generator finishes.
        """
        max_model_len = self.model_config.max_model_len

        turn_priority = priority
        turn = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            turn_request_id = f"{request_id}_{turn}"

            self._log_inputs(
                turn_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            outputs = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                turn_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=turn_priority,
            )
            async for output in outputs:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and record its result in the context.
            context.append_tool_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Render the next prompt and recompute its max_tokens budget.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = token_inputs(next_token_ids)
                sampling_params.max_tokens = max_model_len - len(next_token_ids)
            elif isinstance(context, ParsableContext):
                (engine_prompt,) = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                sampling_params.max_tokens = get_max_tokens(
                    max_model_len,
                    context.request.max_output_tokens,
                    self._extract_prompt_len(engine_prompt),
                    self.default_sampling_params,  # type: ignore
                    self.override_max_tokens,  # type: ignore
                )
            # NOTE(review): any other context type that requests a built-in
            # tool call resubmits the unchanged engine_prompt on the next
            # turn — confirm that combination cannot occur.

            # OPTIMIZATION: every follow-up turn uses the original priority
            # minus one (not decremented further on later turns).
            turn_priority = priority - 1
            turn += 1

    def _log_inputs(
        self,
        request_id: str,
1140
        inputs: PromptType | ProcessorInputs,
1141
1142
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1143
1144
1145
    ) -> None:
        if self.request_logger is None:
            return
1146

1147
        components = self._extract_prompt_components(inputs)
1148
1149
1150

        self.request_logger.log_inputs(
            request_id,
1151
1152
1153
            components.text,
            components.token_ids,
            components.embeds,
1154
1155
1156
            params=params,
            lora_request=lora_request,
        )
1157

1158
1159
1160
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to propagate for this request, if any.

        If the engine has tracing enabled, the trace headers are extracted
        from ``headers``. Otherwise ``None`` is returned. In that case a
        warning is logged if the request nevertheless carried trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

    @staticmethod
1173
    def _base_request_id(
1174
1175
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1176
        """Pulls the request id to use from a header, if provided"""
1177
1178
1179
1180
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1181

1182
        return random_uuid() if default is None else default
1183

1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1199
1200
1201
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn generated ``content`` into function calls according to ``tool_choice``.

        - Named tool choice (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole ``content`` is
          the arguments of a single call to that function.
        - ``"required"``: ``content`` is a JSON list of function definitions.
          Each becomes a call whose arguments are its ``parameters``,
          JSON-encoded.
        - ``"auto"`` or ``None``, with auto tools enabled and a parser class:
          a parser built from ``tool_parser_cls(tokenizer)`` extracts the
          calls. If it finds none, ``(None, content)`` is returned.

        Returns:
            ``(function_calls, remaining_content)``:
            - The content is ``None`` after a forced or required call
              consumed it.
            - Whitespace-only leftover content from the parser becomes
              ``None``.
            - If no mode applies, an empty list and the unchanged content are
              returned.

        Raises:
            ValueError: auto parsing applies but ``tokenizer`` is ``None``.
        """
        tool_choice = request.tool_choice
        calls = list[FunctionCall]()

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            # Forced Function Call: the whole output is the arguments.
            calls.append(FunctionCall(name=tool_choice.name, arguments=content))
            return calls, None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            assert content is not None
            # Forced Function Call: the whole output is the arguments.
            calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            return calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            calls.extend(
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            )
            return calls, None

        if not (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            # No tool-calling mode applies; leave the content untouched.
            return calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        extracted = parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if extracted is None or not extracted.tools_called:
            # No tool calls.
            return None, content

        calls.extend(
            FunctionCall(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
            )
            for call in extracted.tool_calls
        )
        leftover = extracted.content
        if leftover and leftover.strip() == "":
            leftover = None
        return calls, leftover

    @staticmethod
1277
1278
1279
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1280
        tokenizer: TokenizerLike | None,
1281
1282
        return_as_token_id: bool = False,
    ) -> str:
1283
1284
1285
        if return_as_token_id:
            return f"token_id:{token_id}"

1286
1287
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1288
1289
1290
1291
1292
1293

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1294
        return tokenizer.decode([token_id])
1295

1296
    def _is_model_supported(self, model_name: str | None) -> bool:
1297
1298
        if not model_name:
            return True
1299
        return self.models.is_base_model(model_name)
1300

1301
1302

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` position entries are skipped, and ``+inf`` or finite values
    are left unchanged. The same (mutated) object is returned, and ``None``
    passes through unchanged.

    NOTE(review): presumably done so the values can be serialized as
    standard JSON, which has no representation for infinity — confirm.
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    for per_position in (d for d in prompt_logprobs if d is not None):
        for entry in per_position.values():
            if entry.logprob == negative_infinity:
                entry.logprob = -9999.0
    return prompt_logprobs