serving.py 49.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Mapping
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.config import ModelConfig
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
30
from vllm.entrypoints.logger import RequestLogger
31
from vllm.entrypoints.openai.chat_completion.protocol import (
32
    ChatCompletionNamedToolChoiceParam,
33
34
    ChatCompletionRequest,
    ChatCompletionResponse,
35
)
36
from vllm.entrypoints.openai.completion.protocol import (
37
38
    CompletionRequest,
    CompletionResponse,
39
40
)
from vllm.entrypoints.openai.engine.protocol import (
41
42
    ErrorInfo,
    ErrorResponse,
43
    FunctionCall,
44
    FunctionDefinition,
45
)
46
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
47
48
49
50
51
52
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
53
54
55
56
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
57
58
59
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
60
61
62
63
64
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
66
67
68
69
70
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
71
    EmbeddingBytesResponse,
72
73
74
75
76
77
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
78
79
    PoolingChatRequest,
    PoolingCompletionRequest,
80
81
82
83
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
84
85
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
86
87
    ScoreRequest,
    ScoreResponse,
88
    ScoreTextRequest,
89
)
90
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
91
92
93
94
95
96
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
97
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
98
from vllm.exceptions import VLLMValidationError
99
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
100
101
102
103
from vllm.inputs.parse import (
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
104
from vllm.logger import init_logger
105
from vllm.logprobs import Logprob, PromptLogprobs
106
from vllm.lora.request import LoRARequest
107
from vllm.multimodal import MultiModalDataDict
108
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
109
from vllm.pooling_params import PoolingParams
110
from vllm.reasoning import ReasoningParser, ReasoningParserManager
111
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
112
from vllm.sampling_params import BeamSearchParams, SamplingParams
113
from vllm.tokenizers import TokenizerLike
114
from vllm.tool_parsers import ToolParser, ToolParserManager
115
116
117
118
119
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
120
from vllm.utils import random_uuid
121
from vllm.utils.async_utils import (
122
    collect_from_async_generator,
123
124
    merge_async_iterators,
)
125

126
127
128
129
130
131
132
133
134

class GenerationError(Exception):
    """Signals that generation finished with an ``"error"`` finish reason.

    The instance carries ``status_code`` (HTTP 500) so handlers can turn it
    straight into an internal-server-error response.
    """

    def __init__(self, message: str = "Internal server error"):
        # Status is fixed for this error kind; record it before delegating.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


135
136
# Module-level logger from vLLM's logging helper, named after this module.
logger = init_logger(__name__)

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151

class RendererRequest(Protocol):
    """Structural type for requests that can build tokenization parameters."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the ``TokenizeParams`` for this request under *model_config*."""
        raise NotImplementedError


class RendererChatRequest(RendererRequest, Protocol):
    """Structural type for chat requests that can also build chat parameters."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Build the ``ChatParams`` for this request.

        Args:
            default_template: Chat template to fall back on, if any.
            default_template_content_format: Content format to fall back on.
        """
        raise NotImplementedError


152
153
# Union of the request models grouped as "completion-like" (prompt-based,
# as opposed to the message-based ChatLikeRequest group).
CompletionLikeRequest: TypeAlias = (
    CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
    | EmbeddingCompletionRequest | ClassificationCompletionRequest
    | RerankRequest | ScoreRequest | PoolingCompletionRequest
)
162

163
# Union of the request models grouped as "chat-like".
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest | TokenizeChatRequest | EmbeddingChatRequest
    | ClassificationChatRequest | PoolingChatRequest
)
170

171
# Audio request models: transcription and translation.
SpeechToTextRequest: TypeAlias = (
    TranscriptionRequest | TranslationRequest
)
172

173
174
175
176
177
178
# Every request model the serving layer accepts; also the bound of RequestT.
AnyRequest: TypeAlias = (
    CompletionLikeRequest | ChatLikeRequest | SpeechToTextRequest
    | ResponsesRequest | IOProcessorRequest | GenerateRequest
)

# Every successful response model a handler may return; failures are
# expressed separately as ErrorResponse.
AnyResponse: TypeAlias = (
    CompletionResponse | ChatCompletionResponse
    | EmbeddingResponse | EmbeddingBytesResponse
    | TranscriptionResponse | TokenizeResponse
    | PoolingResponse | ClassificationResponse
    | ScoreResponse | GenerateResponse
)
194

195

196
197
198
# Type variable for the concrete request type held by ServeContext; bounded
# by the AnyRequest union.
RequestT = TypeVar("RequestT", bound=AnyRequest)


199
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Mutable per-request state passed between serving pipeline stages.

    All fields are keyword-only. The base ``OpenAIServing`` stages read
    ``engine_prompts`` (expected to be filled by a ``_preprocess`` override),
    store the merged engine stream in ``result_generator`` and the collected
    outputs in ``final_res_batch``.
    """

    # The parsed API request and, when available, the raw HTTP request.
    request: RequestT
    raw_request: Request | None = None

    model_name: str
    request_id: str
    # Creation timestamp, whole seconds since the epoch.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Pipeline state populated as the request moves through the stages.
    engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): pydantic-style config on a stdlib dataclass. It is an
    # unannotated class attribute, so the dataclass machinery ignores it.
    model_config = ConfigDict(arbitrary_types_allowed=True)
215
216


217
class OpenAIServing:
218
219
220
221
    # Short tag prepended to request IDs (e.g. "embd", "classify") so the
    # originating endpoint is recognizable from the ID alone. Subclasses that
    # build IDs from it set their own value; the base default is empty rather
    # than the previous multi-line documentation text, which leaked into IDs.
    request_id_prefix: ClassVar[str] = ""
222

223
224
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Bind the handler to an engine client and a model registry.

        Args:
            engine_client: Client used to submit generate/encode requests
                and multimodal cache resets.
            models: Registry providing the input/IO processors, renderer and
                model config that are cached on this handler.
            request_logger: Optional request logger, stored as-is.
            return_tokens_as_token_ids: Flag stored as-is on the instance.
            log_error_stack: When true, ``create_error_response`` also prints
                a traceback (or the current stack).
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Shortcuts to members of the model registry used by the handlers.
        self.input_processor = models.input_processor
        self.io_processor = models.io_processor
        self.renderer = models.renderer

        config = models.model_config
        self.model_config = config
        self.max_model_len = config.max_model_len

249
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Look up the tool parser registered as *tool_parser_name*.

        Returns ``None`` unless auto tool choice is enabled and a parser name
        is given. Logs a warning for the ``pythonic`` parser on Llama 3.2.

        Raises:
            TypeError: if ``ToolParserManager`` cannot provide the parser.
        """
        auto_tools_active = enable_auto_tools and tool_parser_name is not None
        if not auto_tools_active:
            return None

        logger.info('"auto" tool choice has been enabled.')

        try:
            pythonic_on_llama32 = (
                tool_parser_name == "pythonic"
                and self.model_config.model.startswith("meta-llama/Llama-3.2")
            )
            if pythonic_on_llama32:
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Look up the reasoning parser registered as *reasoning_parser_name*.

        Returns ``None`` when the name is empty/falsy (no reasoning parser).

        Raises:
            TypeError: if ``ReasoningParserManager`` raises for the name or
                resolves it to ``None``.
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently return ``None`` here.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

289
    async def reset_mm_cache(self) -> None:
        """Clear the input processor's multimodal cache, then the engine's."""
        processor = self.input_processor
        processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

293
294
295
296
297
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search one token at a time on top of the engine client.

        Each step submits every live beam as its own single-token request
        (``max_tokens=1``) asking for ``2 * beam_width`` logprobs, pools the
        returned candidates across all beams, moves EOS candidates into the
        completed set (unless ``ignore_eos``) and keeps the ``beam_width``
        candidates with the highest cumulative logprob as the next beams.
        After ``max_tokens`` steps, completed and live beams are ranked with
        ``create_sort_beams_key_function`` and the best ``beam_width`` kept.

        Yields:
            Exactly one ``RequestOutput``: the ranked final beams, or a single
            output with ``finish_reason="error"`` as soon as any step reports
            an error.

        Raises:
            VLLMValidationError: if no tokenizer is available
                (``skip_tokenizer_init=True``).
            NotImplementedError: for explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        # Number of prompt tokens, stripped from each beam before decoding.
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # Ask for 2x beam_width candidates per beam (presumably so that enough
        # non-EOS continuations remain after EOS candidates are set aside).
        logprobs_num = 2 * beam_width
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed: list[BeamSearchSequence] = []

        for _ in range(max_tokens):
            # Submit one single-token request per live beam, concurrently.
            request_id_batch = f"{request_id}-{random_uuid()}"
            tasks = []
            for i, beam in enumerate(all_beams):
                beam_prompt = TokensPrompt(
                    prompt_token_ids=beam.tokens,
                    multi_modal_data=beam.multi_modal_data,
                    mm_processor_kwargs=beam.mm_processor_kwargs,
                )
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            beam_prompt,
                            beam_search_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            # Flattened candidate pool over all beams. ``candidate_parent[j]``
            # is the index (into ``all_beams``/``output``) of the beam that
            # candidate ``j`` extends. The parent is recorded explicitly
            # because a beam may return more than ``logprobs_num`` entries
            # (the sampled token is always included) or none at all, so
            # ``j // logprobs_num`` would attribute candidates to wrong beams.
            candidate_token_ids: list[int] = []
            candidate_logprobs: list[float] = []
            candidate_parent: list[int] = []
            # Logprob entry of this step for each beam that returned one.
            step_entries: dict[int, dict[int, Logprob]] = {}

            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                entries = step_logprobs[0]
                step_entries[i] = entries
                for token_id, entry in entries.items():
                    candidate_token_ids.append(token_id)
                    candidate_logprobs.append(current_beam.cum_logprob + entry.logprob)
                    candidate_parent.append(i)

            token_id_arr = np.array(candidate_token_ids)
            logprob_arr = np.array(candidate_logprobs)

            if not ignore_eos:
                # Every EOS candidate finishes its parent beam.
                eos_idx = np.where(token_id_arr == eos_token_id)[0]
                for idx in eos_idx:
                    parent = candidate_parent[idx]
                    current_beam = all_beams[parent]
                    finished_tokens = (
                        current_beam.tokens + [eos_token_id]
                        if include_stop_str_in_output
                        else current_beam.tokens
                    )
                    completed.append(
                        BeamSearchSequence(
                            tokens=finished_tokens,
                            logprobs=current_beam.logprobs + [step_entries[parent]],
                            cum_logprob=float(logprob_arr[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # Exclude EOS candidates from the live-beam selection below.
                logprob_arr[eos_idx] = -np.inf

            # Keep the beam_width best candidates (unordered).
            num_candidates = len(logprob_arr)
            if num_candidates > beam_width:
                topn_idx = np.argpartition(np.negative(logprob_arr), beam_width)[
                    :beam_width
                ]
            else:
                # argpartition requires kth < len; keep every candidate.
                topn_idx = np.arange(num_candidates)

            new_beams = []
            for idx in topn_idx:
                parent = candidate_parent[idx]
                current_beam = all_beams[parent]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [int(token_id_arr[idx])],
                        logprobs=current_beam.logprobs + [step_entries[parent]],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(logprob_arr[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams
            if not all_beams:
                # No candidates were returned; nothing left to extend.
                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
531

532
533
534
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook for preparing ``ctx`` before engine requests are scheduled.

        The base implementation does nothing and reports no error; endpoint
        subclasses (classification, embedding, ...) override it.
        """
        # Nothing to prepare in the base class.
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Hook for turning a fully processed ``ctx`` into a response.

        The base class has no response type of its own, so it reports the
        endpoint as unimplemented; subclasses override this.
        """
        unimplemented = "unimplemented endpoint"
        return self.create_error_response(unimplemented)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for ``ctx`` and return its first response.

        Only the first response yielded by ``_pipeline`` is used. The pipeline
        generator is then closed explicitly (``aclose``) so it is finalized
        deterministically, instead of being left suspended for the garbage
        collector's async-generator finalizer.

        Returns an error response if the pipeline yields nothing.
        """
        pipeline = self._pipeline(ctx)
        try:
            async for response in pipeline:
                return response
        finally:
            await pipeline.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline, yielding one response.

        Stages run in order: model check, request validation, preprocessing,
        engine scheduling, result collection, response building. The first
        stage that reports an ``ErrorResponse`` ends the pipeline. That error
        is yielded and no later stage runs, so no engine work is scheduled
        for a request that already failed.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

585
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``.

        Requests without that attribute, or with it unset, always pass.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None or limit <= self.max_model_len:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size."
        )

599
600
601
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Obtain ``PoolingParams`` from the request's ``to_pooling_params``.

        Returns an error response when the request type has no such method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

610
611
612
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule one engine ``encode`` call per engine prompt.

        Each prompt is logged and submitted under the ID
        ``"{ctx.request_id}-{index}"``. The per-prompt streams are merged into
        ``ctx.result_generator``. Any exception raised while scheduling is
        returned as an error response instead of propagating.
        """
        try:
            if ctx.raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(
                    ctx.raw_request.headers
                )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []
            for index, engine_prompt in enumerate(prompts):
                item_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
        except Exception as e:
            return self.create_error_response(e)

        return None
658
659
660
661

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Drain ``ctx.result_generator`` into ``ctx.final_res_batch``.

        Results arrive as ``(prompt_index, output)`` pairs and are stored in
        prompt order. An error response is returned if prompts or the result
        generator are missing, if any prompt produced no output, or if
        iteration raises.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response("Engine prompts not available")

            slots: list[PoolingRequestOutput | None] = [None] * len(prompts)

            results = ctx.result_generator
            if results is None:
                return self.create_error_response("Result generator not available")

            async for index, output in results:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
        except Exception as e:
            return self.create_error_response(e)

        return None
689

690
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an OpenAI-style ``ErrorResponse``.

        When *message* is an exception, ``err_type``, ``status_code`` and
        ``param`` are derived from its type (overriding the arguments), and
        the message becomes ``str(exc)``:

        * ``VLLMValidationError`` -> 400, ``param`` taken from the exception
        * ``NotImplementedError`` -> 501
        * ``ValueError`` / ``TypeError`` / ``RuntimeError`` / ``OverflowError``
          -> 400
        * ``jinja2.TemplateError`` (matched by class name) -> 400
        * anything else -> 500

        If ``log_error_stack`` is set, the traceback of the exception being
        handled (or the current stack, if none) is printed. The message is
        passed through ``sanitize_message``.
        """
        if isinstance(message, Exception):
            exc = message

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            # NotImplementedError is a subclass of RuntimeError, so it must be
            # tested before the generic tuple below; otherwise this branch is
            # unreachable and unimplemented features are reported as 400.
            elif isinstance(exc, NotImplementedError):
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
744

745
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Like ``create_error_response``, but return the error as a JSON string."""
        error = self.create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
        )
        payload = error.model_dump()
        return json.dumps(payload)

762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Raise ``GenerationError`` when *finish_reason* is ``"error"``.

        The failing request ID is logged first. Any other finish reason,
        including ``None``, is a no-op.
        """
        if finish_reason != "error":
            return

        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Map a ``GenerationError`` to an ``InternalServerError`` response."""
        text = str(e)
        return self.create_error_response(
            text,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Map a ``GenerationError`` to a JSON ``InternalServerError`` string."""
        text = str(e)
        return self.create_streaming_error_response(
            text,
            err_type="InternalServerError",
            status_code=e.status_code,
        )

791
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Check that ``request.model`` names something this server can serve.

        Returns ``None`` for a supported base model, an already-registered
        LoRA, or (when ``VLLM_ALLOW_RUNTIME_LORA_UPDATING`` is set) a LoRA that
        ``resolve_lora`` loads successfully. If the resolver returns a 400
        ``ErrorResponse``, that error is returned; otherwise a 404
        ``NotFoundError`` for the ``model`` parameter.
        """
        model_name = request.model
        if (
            self._is_model_supported(model_name)
            or model_name in self.models.lora_requests
        ):
            return None

        resolver_error: ErrorResponse | None = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                )
                if is_bad_request:
                    resolver_error = load_result

        if resolver_error is not None:
            return resolver_error

        return self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
820

821
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the one registered LoRA whose name matches a message type.

        A LoRA matches when its ``lora_name`` appears in the request's message
        types (from ``_get_message_types``). If zero or several LoRAs match,
        ``None`` is returned.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best-effort match of default multimodal adapters by modality name,
        # e.g. a content type "audio_url" contributes "audio".
        matches = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Only an unambiguous (exactly one) match is applied.
        return next(iter(matches)) if len(matches) == 1 else None

844
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Pick the LoRA adapter, if any, that should serve ``request``.

        Resolution order:

        1. An adapter registered under ``request.model``.
        2. If ``supports_default_mm_loras`` is set, the single default
           multimodal adapter that matches the request's content types.
        3. ``None`` when ``request.model`` is the base model (or empty).

        Raises:
            ValueError: ``request.model`` matches neither a registered
                adapter nor the base model. This is unreachable if
                ``_check_model`` already accepted the request.
        """
        registered = self.models.lora_requests
        if request.model in registered:
            return registered[request.model]

        mm_default = (
            self._get_active_default_mm_loras(request)
            if supports_default_mm_loras
            else None
        )
        if mm_default is not None:
            return mm_default

        if not self._is_model_supported(request.model):
            raise ValueError(f"The model `{request.model}` does not exist.")
        return None
864

865
866
867
868
869
870
871
872
873
874
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-part type prefixes used in ``request.messages``.

        Each content part's ``type`` is cut at the first ``_`` (for example
        ``audio_url`` becomes ``audio``). The result is matched against the
        names of default per-modality LoRA adapters.

        Returns an empty set when the request has no ``messages`` attribute,
        or when ``messages`` is ``None``, a ``str`` or ``bytes``. Only dict
        messages whose ``content`` is a list are inspected. Inside those
        lists, only dict parts with a string ``type`` contribute.
        """
        message_types: set[str] = set()

        messages = getattr(request, "messages", None)
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                # Content lists may hold plain strings as well as dicts.
                # Testing `"type" in part` on a str is a substring check,
                # which could match and then fail on str indexing, so only
                # dict parts are considered.
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

890
891
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Check the prompt length of ``request`` against ``max_model_len``.

        * Embedding, score, rerank and classification requests generate no
          tokens, so their input may fill the entire context window.
        * Tokenize and detokenize requests are not length-checked.
        * Any other request must leave room for at least one generated
          token. ``input tokens + max_tokens`` must also fit in the context;
          for chat completions ``max_completion_tokens`` takes precedence
          over ``max_tokens``.

        Returns:
            A ``TokensPrompt`` holding ``input_text`` and ``input_ids``.

        Raises:
            VLLMValidationError: one of the limits above is exceeded.
        """
        num_input_tokens = len(input_ids)
        limit = self.max_model_len

        # Note: these request types don't have max_tokens.
        non_generating_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, non_generating_types):
            # No tokens are generated, so the input may use the full context.
            if num_input_tokens > limit:
                # Exact-type lookup; anything unlisted is reported as
                # embedding generation.
                operation = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{limit} tokens. However, you requested "
                    f"{num_input_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=num_input_tokens,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Tokenizer utilities don't need to fit in the context window.
        tokenizer_only_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenizer_only_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            requested_max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            requested_max_tokens = getattr(request, "max_tokens", None)

        # Completion-like requests need at least one free slot to generate.
        if num_input_tokens >= limit:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{limit} tokens. However, your request has "
                f"{num_input_tokens} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=num_input_tokens,
            )

        if (
            requested_max_tokens is not None
            and num_input_tokens + requested_max_tokens > limit
        ):
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{requested_max_tokens}. This model's maximum context length is "
                f"{limit} tokens and your request has "
                f"{num_input_tokens} input tokens ({requested_max_tokens} > {limit}"
                f" - {num_input_tokens}).",
                parameter="max_tokens",
                value=requested_max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
973

974
975
    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless the server trusts them.

        A template counts as request-supplied when ``request_chat_template``
        is set, or when ``chat_template_kwargs`` carries a non-``None``
        ``"chat_template"`` entry. Returns ``None`` if the request is
        acceptable, otherwise an error response.
        """
        if trust_request_chat_template:
            return None

        template_from_kwargs = (
            chat_template_kwargs.get("chat_template")
            if chat_template_kwargs
            else None
        )
        if request_chat_template is None and template_from_kwargs is None:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        """Render and tokenize completion-style prompts for the engine.

        Prompts come from ``prompt_input`` and/or ``prompt_embeds`` through
        the renderer, using the request's tokenization parameters. When the
        request sets ``mm_processor_kwargs`` or ``cache_salt`` (non-``None``),
        that value is copied onto every resulting engine prompt.
        """
        renderer = self.renderer
        tok_params = request.build_tok_params(self.model_config)

        rendered = await renderer.render_completions_async(prompt_input, prompt_embeds)
        engine_prompts = await renderer.tokenize_prompts_async(rendered, tok_params)

        passthrough: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                passthrough[key] = value

        for engine_prompt in engine_prompts:
            engine_prompt.update(passthrough)  # type: ignore

        return engine_prompts

1030
1031
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
        """Render chat ``messages`` into one tokenized engine prompt.

        ``default_template_kwargs`` is merged with ``tools=tool_dicts`` and
        with a ``tokenize`` flag. The flag is true only for Mistral
        tokenizers. The merged kwargs become defaults for the request's chat
        params. Non-``None`` ``mm_processor_kwargs`` / ``cache_salt`` from the
        request are copied onto the prompt.

        When ``tool_parser`` is given and ``tool_choice`` is not ``"none"``,
        the parser's ``adjust_request`` is invoked on ``request``.

        Returns:
            ``(conversation, [engine_prompt])``.

        Raises:
            NotImplementedError: tool parsing applies but ``request`` is
                neither a Chat Completions nor a Responses request.
        """
        from vllm.tokenizers.mistral import MistralTokenizer

        renderer = self.renderer

        template_kwargs = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": isinstance(renderer.tokenizer, MistralTokenizer),
            },
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(template_kwargs)

        conversation, rendered = await renderer.render_messages_async(
            messages, chat_params
        )
        engine_prompt = await renderer.tokenize_prompt_async(rendered, tok_params)

        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                engine_prompt[key] = value  # type: ignore

        # Tool parsing applies only when a tool_parser is set and tool_choice
        # is not "none". With tool_choice="none", a tool call hallucinated by
        # the LLM must not be parsed, even if a tool_parser is configured.
        if tool_parser is not None and getattr(request, "tool_choice", "none") != "none":
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )

            # TODO: Update adjust_request to accept ResponsesRequest
            # NOTE(review): the returned request is not used by this method,
            # so only in-place changes made by adjust_request take effect.
            tokenizer = renderer.get_tokenizer()
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
1087

1088
1089
1090
1091
1092
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Render the next-turn prompt(s) for a multi-turn Responses request.

        ``messages`` are converted with ``construct_input_messages`` and then
        passed through ``_preprocess_chat``. The given template settings are
        used and no extra default template kwargs are added. The rendered
        conversation is discarded; only the engine prompts are returned.
        """
        chat_messages = construct_input_messages(request_input=messages)

        _conversation, engine_prompts = await self._preprocess_chat(
            request,
            chat_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return engine_prompts
1111

1112
1113
1114
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt | EmbedsPrompt,
        sampling_params: SamplingParams,
        tok_params: TokenizeParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Generate for ``context``, running built-in tool calls between turns.

        Each turn is submitted as a sub-request with id
        ``f"{request_id}_{n}"``. Every engine output is appended to
        ``context``, and ``context`` is yielded after each one.

        When a turn finishes and the context asks for a built-in tool call:

        1. The tool is called and its output is appended to the context.
        2. The next prompt is rendered from the context.
        3. ``sampling_params.max_tokens`` is overwritten in place.
        4. Another turn starts.

        Otherwise generation stops.
        """
        orig_priority = priority
        sub_request = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{sub_request}"

            # Derive the prompt text from the prompt submitted in *this* turn.
            # Follow-up turns replace engine_prompt, and the text passed to
            # the engine must match the token ids being sent.
            prompt_text, _, _ = get_prompt_components(engine_prompt)

            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            tokenization_kwargs = tok_params.get_encode_kwargs()
            engine_request = self.input_processor.process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn: render the next prompt and
            # update sampling_params.max_tokens.
            # NOTE(review): a context that is neither Harmony nor Parsable
            # falls through and re-submits the previous engine_prompt.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
                # Give the next turn the remainder of the context window.
                sampling_params.max_tokens = self.max_model_len - len(token_ids)
            elif isinstance(context, ParsableContext):
                engine_prompts = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = engine_prompts[0]
                sampling_params.max_tokens = get_max_tokens(
                    self.max_model_len,
                    context.request,
                    engine_prompt,
                    self.default_sampling_params,  # type: ignore
                )

            # OPTIMIZATION: follow-up turns are submitted with
            # ``orig_priority - 1`` (presumably lower values are scheduled
            # earlier; confirm against the scheduler).
            priority = orig_priority - 1
            sub_request += 1
1205

1206
1207
1208
    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward a request's prompt components to the request logger.

        Does nothing when no request logger is configured. Otherwise the
        prompt text, token ids and embeddings are split out of ``inputs``
        with ``get_prompt_components`` and passed to ``log_inputs``.
        """
        request_logger = self.request_logger
        if request_logger is None:
            return

        text, token_ids, embeds = get_prompt_components(inputs)
        request_logger.log_inputs(
            request_id,
            text,
            token_ids,
            embeds,
            params=params,
            lora_request=lora_request,
        )
1226

1227
1228
1229
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return trace headers to propagate, or ``None``.

        When the engine reports tracing as enabled, the trace headers are
        extracted from ``headers`` and returned. Otherwise ``None`` is
        returned, and a warning is logged via
        ``log_tracing_disabled_warning`` if the client sent trace headers
        anyway.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1241
    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Return the request id to use for ``raw_request``.

        The client's ``X-Request-Id`` header wins when present. Otherwise
        ``default`` is used, and if that is ``None`` a fresh
        ``random_uuid()`` is generated.
        """
        header_id = (
            None if raw_request is None else raw_request.headers.get("X-Request-Id")
        )
        if header_id is not None:
            return header_id
        return default if default is not None else random_uuid()
1252

1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Return the integer ``X-data-parallel-rank`` header, if provided.

        Returns ``None`` when there is no request, no such header, or the
        header value cannot be parsed by ``int()``.
        """
        rank_header = (
            None
            if raw_request is None
            else raw_request.headers.get("X-data-parallel-rank")
        )
        if rank_header is None:
            return None

        try:
            rank = int(rank_header)
        except ValueError:
            rank = None
        return rank

1268
1269
1270
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn model output ``content`` into function calls per ``tool_choice``.

        * Named tool choice (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): all of ``content`` becomes
          the arguments of a single call to the named function.
        * ``"required"``: ``content`` is validated as a JSON list of
          ``FunctionDefinition`` objects, giving one call per entry. Each
          entry's ``parameters`` are re-serialized with ``json.dumps``.
        * ``"auto"`` or ``None``, with ``enable_auto_tools`` and a parser
          class: calls are extracted by ``tool_parser_cls(tokenizer)``.

        Returns:
            ``(calls, remaining_content)``.

            * ``remaining_content`` is ``None`` once ``content`` was consumed
              as a tool call.
            * In the auto branch, ``(None, content)`` means the parser found
              no tool call. Leftover content that is only whitespace becomes
              ``None``.
            * When no branch applies, ``([], content)`` is returned.

        Raises:
            ValueError: auto parsing applies but ``tokenizer`` is ``None``.
            RuntimeError: constructing the tool parser failed. The error is
                logged, then re-raised.
        """
        tool_choice = request.tool_choice
        calls: list[FunctionCall] = []

        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            # Forced function call: the whole output is the argument string.
            assert content is not None
            calls.append(FunctionCall(name=tool_choice.name, arguments=content))
            return calls, None

        if tool_choice and isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            # Forced function call (Chat Completions flavour).
            assert content is not None
            calls.append(
                FunctionCall(name=tool_choice.function.name, arguments=content)
            )
            return calls, None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            for definition in definitions:
                calls.append(
                    FunctionCall(
                        name=definition.name,
                        arguments=json.dumps(definition.parameters, ensure_ascii=False),
                    )
                )
            return calls, None

        if not (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        ):
            # Nothing to parse; hand the content back untouched.
            return calls, content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        extracted = parser.extract_tool_calls(
            "" if content is None else content,
            request=request,  # type: ignore
        )
        if extracted is None or not extracted.tools_called:
            # No tool calls.
            return None, content

        calls.extend(
            FunctionCall(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
            )
            for call in extracted.tool_calls
        )
        leftover = extracted.content
        if leftover and not leftover.strip():
            leftover = None
        return calls, leftover

1345
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: TokenizerLike | None,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the display string for ``token_id``.

        Resolution order:

        1. ``"token_id:<id>"`` when ``return_as_token_id`` is set.
        2. The ``decoded_token`` already cached on ``logprob``.
        3. ``tokenizer.decode([token_id])``.

        Raises:
            ValueError: decoding is needed but ``tokenizer`` is ``None``.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        cached = logprob.decoded_token
        if cached is not None:
            return cached

        if tokenizer is not None:
            return tokenizer.decode([token_id])

        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
1364

1365
    def _is_model_supported(self, model_name: str | None) -> bool:
1366
1367
        if not model_name:
            return True
1368
        return self.models.is_base_model(model_name)
1369

1370
1371

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``None`` input is returned as-is, and ``None`` positions in the list are
    skipped. The same (mutated) object is returned. Presumably this keeps
    the values serializable as standard JSON, which has no ``-inf``
    (confirm).
    """
    if prompt_logprobs is None:
        return None

    neg_inf = float("-inf")
    for position in (p for p in prompt_logprobs if p is not None):
        for entry in position.values():
            if entry.logprob == neg_inf:
                entry.logprob = -9999.0
    return prompt_logprobs