serving.py 49.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Mapping
9
from dataclasses import dataclass, field
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
12

13
import numpy as np
14
from fastapi import Request
15
16
17
from openai.types.responses import (
    ToolChoiceFunction,
)
18
19
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
20

21
import vllm.envs as envs
22
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
23
from vllm.config import ModelConfig
24
from vllm.engine.protocol import EngineClient
25
26
27
28
29
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
)
30
from vllm.entrypoints.logger import RequestLogger
31
from vllm.entrypoints.openai.chat_completion.protocol import (
32
    ChatCompletionNamedToolChoiceParam,
33
34
    ChatCompletionRequest,
    ChatCompletionResponse,
35
)
36
from vllm.entrypoints.openai.completion.protocol import (
37
38
    CompletionRequest,
    CompletionResponse,
39
40
)
from vllm.entrypoints.openai.engine.protocol import (
41
42
    ErrorInfo,
    ErrorResponse,
43
    FunctionCall,
44
    FunctionDefinition,
45
)
46
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
47
48
49
50
51
52
from vllm.entrypoints.openai.responses.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    StreamingHarmonyContext,
)
53
54
55
56
from vllm.entrypoints.openai.responses.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
57
58
59
from vllm.entrypoints.openai.responses.utils import (
    construct_input_messages,
)
60
61
62
63
64
from vllm.entrypoints.openai.translations.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
66
67
68
69
70
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
71
    EmbeddingBytesResponse,
72
73
74
75
76
77
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
78
79
    PoolingChatRequest,
    PoolingCompletionRequest,
80
81
82
83
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
84
85
    ScoreDataRequest,
    ScoreQueriesDocumentsRequest,
86
87
    ScoreRequest,
    ScoreResponse,
88
    ScoreTextRequest,
89
)
90
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
91
92
93
94
95
96
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
)
97
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
98
from vllm.exceptions import VLLMValidationError
99
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
100
101
102
103
from vllm.inputs.parse import (
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
104
from vllm.logger import init_logger
105
from vllm.logprobs import Logprob, PromptLogprobs
106
from vllm.lora.request import LoRARequest
107
from vllm.multimodal import MultiModalDataDict
108
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
109
from vllm.pooling_params import PoolingParams
110
from vllm.reasoning import ReasoningParser, ReasoningParserManager
111
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
112
from vllm.sampling_params import BeamSearchParams, SamplingParams
113
from vllm.tokenizers import TokenizerLike
114
from vllm.tool_parsers import ToolParser, ToolParserManager
115
116
117
118
119
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
120
from vllm.utils import random_uuid
121
from vllm.utils.async_utils import (
122
    collect_from_async_generator,
123
124
    merge_async_iterators,
)
125

126
127
128
129
130
131
132
133
134

class GenerationError(Exception):
    """Signal that generation ended with an internal (HTTP 500) failure.

    Raised when a ``finish_reason`` reports an error. The instance carries
    ``status_code`` so handlers can build the error response from it.
    """

    def __init__(self, message: str = "Internal server error"):
        # The status is fixed: this exception always maps to a 500.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        super().__init__(message)


135
136
# Module-level logger, created via vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151

class RendererRequest(Protocol):
    """Structural type for requests that can build tokenization parameters."""

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Return the `TokenizeParams` for this request given `model_config`."""
        raise NotImplementedError()


class RendererChatRequest(RendererRequest, Protocol):
    """Structural type for chat requests that can also build chat parameters."""

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        """Return the `ChatParams` for this request.

        Both arguments are defaults supplied by the caller; how they are
        combined with request-level settings is up to the implementation.
        """
        raise NotImplementedError()


152
153
# Request types that are handled like a plain (non-chat) completion.
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | TokenizeCompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | RerankRequest
    | ScoreRequest
    | PoolingCompletionRequest
)

# Request types that are handled like a chat conversation.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | TokenizeChatRequest
    | EmbeddingChatRequest
    | ClassificationChatRequest
    | PoolingChatRequest
)

# Audio transcription / translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type accepted by the serving layer in this module.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
    | GenerateRequest
)

# Every non-error response type the serving layer can return.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
    | GenerateResponse
)


# Type variable used to parameterize `ServeContext` by its request type.
RequestT = TypeVar("RequestT", bound=AnyRequest)


199
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
    """Per-request state passed between the stages of the serving pipeline.

    All fields are keyword-only (`kw_only=True`), which is why required
    fields may follow fields that have defaults.
    """

    # The parsed API request being served.
    request: RequestT
    # The underlying FastAPI request, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Unix timestamp (whole seconds) taken when the context is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    # Prompts to submit to the engine; `None` until a preprocessing step
    # populates them.
    engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None

    # Merged stream of `(prompt_index, output)` pairs, set by
    # `_prepare_generators`.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Final outputs ordered by prompt index, set by `_collect_batch`.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Un-annotated, so this is a plain class attribute rather than a
    # dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
215
216


217
class OpenAIServing:
218
219
220
221
    # NOTE(review): the default value below is the descriptive text itself,
    # not a usable prefix; presumably each subclass overrides it with a short
    # tag such as "embd" or "classify" — confirm against the subclasses.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
222

223
224
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the collaborators and cache frequently used model settings.

        Args:
            engine_client: Client used to submit generate/encode requests.
            models: Model registry; also the source of the input processor,
                IO processor, renderer and model config cached below.
            request_logger: Optional request logger (may be `None`).
            return_tokens_as_token_ids: Stored on the instance as-is.
            log_error_stack: When true, `create_error_response` prints a
                traceback or stack trace for every error it builds.
        """
        super().__init__()

        self.engine_client = engine_client
        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        # Shortcuts to objects owned by `models`.
        self.input_processor = self.models.input_processor
        self.io_processor = self.models.io_processor
        self.renderer = self.models.renderer
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len

249
    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[TokenizerLike], ToolParser] | None:
        """Get the tool parser based on the name.

        Args:
            tool_parser_name: Name the parser is looked up under.
            enable_auto_tools: Whether "auto" tool choice is enabled.

        Returns:
            The result of `ToolParserManager.get_tool_parser`, or `None`
            when auto tool choice is disabled or no name was given.

        Raises:
            TypeError: If looking up `tool_parser_name` fails for any reason.
        """
        if not enable_auto_tools or tool_parser_name is None:
            return None
        logger.info('"auto" tool choice has been enabled.')

        try:
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            parser = ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            # Any lookup failure is reported as an unregistered parser.
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e
        return parser

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[TokenizerLike], ReasoningParser] | None:
        """Get the reasoning parser based on the name.

        Args:
            reasoning_parser_name: Name the parser is looked up under; an
                empty value means "no reasoning parser".

        Returns:
            The result of `ReasoningParserManager.get_reasoning_parser`, or
            `None` when no name was given.

        Raises:
            TypeError: If the lookup fails or yields `None`.
        """
        if not reasoning_parser_name:
            return None

        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a `None` parser slip through.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

289
290
291
292
293
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search by repeatedly sampling one token per live beam.

        Each step submits one single-token generation request per beam with
        `2 * beam_width` logprobs, pools every candidate continuation, moves
        EOS candidates to the completed set (unless `ignore_eos`), and keeps
        the `beam_width` best remaining candidates as the next beams.

        Args:
            prompt: A string prompt or a dict-like prompt providing
                `prompt`, `prompt_token_ids` and `multi_modal_data`.
            request_id: Base ID; per-step and per-beam suffixes are appended.
            params: Beam search settings (width, max tokens, penalties, ...).
            lora_request: Optional LoRA adapter applied to every beam.
            trace_headers: Optional tracing headers forwarded to the engine.

        Yields:
            Exactly one `RequestOutput`: the best `beam_width` sequences, or
            a single `finish_reason="error"` output if any step reports an
            error.

        Raises:
            VLLMValidationError: If no tokenizer is available.
            NotImplementedError: For explicit encoder/decoder prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        input_processor = self.input_processor
        tokenizer = input_processor.tokenizer
        if tokenizer is None:
            raise VLLMValidationError(
                "You cannot use beam search when `skip_tokenizer_init=True`",
                parameter="skip_tokenizer_init",
                value=True,
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError

        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            # NOTE(review): a plain-string prompt starts with no token IDs;
            # presumably callers pass pre-tokenized prompts — confirm.
            prompt_text = prompt
            prompt_token_ids = []
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = None

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.

        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        logprobs_num = 2 * beam_width
        # Every step generates exactly one token per beam.
        beam_search_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            # Launch one single-token generation per live beam.
            for i, beam in enumerate(all_beams):
                individual_prompt = TokensPrompt(
                    prompt_token_ids=beam.tokens,
                    multi_modal_data=beam.multi_modal_data,
                    mm_processor_kwargs=beam.mm_processor_kwargs,
                )
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                tasks.append(task)

            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            # Store all new tokens generated by beam
            all_beams_token_id = []
            # Store the cumulative probability of all tokens
            # generated by beam search
            all_beams_logprob = []
            # Index of the beam each candidate came from. Tracked explicitly
            # because a result may hold a number of logprob entries other
            # than `logprobs_num` (or none at all), which would make
            # `idx // logprobs_num` point at the wrong beam.
            all_beams_owner: list[int] = []
            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend(
                        [
                            current_beam.cum_logprob + obj.logprob
                            for obj in logprobs.values()
                        ]
                    )
                    all_beams_owner.extend([i] * len(logprobs))

            # Handle the token for the end of sentence (EOS)
            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = all_beams_owner[idx]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                # After processing, set the log probability of the eos condition
                # to negative infinity.
                all_beams_logprob[eos_idx] = -np.inf

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities
            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
                :beam_width
            ]

            for idx in topn_idx:
                owner = all_beams_owner[idx]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
                    )
                )

            all_beams = new_beams

        # Beams still alive after the last step compete with completed ones.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
527

528
529
530
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).

        Returns:
            `None` on success, or an `ErrorResponse` to abort the pipeline.
            This default implementation does nothing and returns `None`.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.

        This default implementation always returns an error response with
        the message "unimplemented endpoint".
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the pipeline for `ctx` and return the first response it yields.

        Returns:
            The first item yielded by `_pipeline`, or an error response if
            the pipeline yields nothing.
        """
        async for response in self._pipeline(ctx):
            # Only the first yielded item is used, whether it is an error
            # or the final response.
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The
        first stage that produces an `ErrorResponse` ends the pipeline, so
        exactly one item is yielded: either that error or the final response.
        """
        # Stop after the first error. Without the `return`, a consumer that
        # keeps iterating would run the later stages on a failed request.
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

581
    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Check request-level limits before any work is scheduled.

        Currently this only checks that `truncate_prompt_tokens`, when the
        request has one, does not exceed `max_model_len`.

        Returns:
            `None` if the request is acceptable, otherwise an error response.
        """
        # Not every request type defines this attribute.
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

595
596
597
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Build pooling parameters from the request.

        Returns:
            The result of `ctx.request.to_pooling_params()`, or an error
            response if the request type has no such method.
        """
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
                "Request type does not support pooling parameters"
            )

        return ctx.request.to_pooling_params()

606
607
608
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        Submits one `encode` call per engine prompt and stores the merged
        stream on `ctx.result_generator`.

        Returns:
            `None` on success; an error response if pooling params cannot be
            built, engine prompts are missing, or any exception is raised.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # Same for every prompt of this request; not all request types
            # define `priority`, so fall back to 0.
            priority = getattr(ctx.request, "priority", 0)

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own sub-request ID.
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # Deliberate boundary: any failure becomes an error response.
            return self.create_error_response(e)
654
655
656
657

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator`, storing each result at its prompt
        index, then writes the complete list to `ctx.final_res_batch`.

        Returns:
            `None` on success; an error response if prompts or the generator
            are missing, a prompt produced no result, or an exception is
            raised.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[PoolingRequestOutput | None]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Items arrive as (prompt_index, result); a later item for the
            # same index overwrites the earlier one.
            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity check, so it does not depend on the outputs' `__eq__`.
            if any(res is None for res in final_res_batch):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [res for res in final_res_batch if res is not None]

            return None

        except Exception as e:
            # Deliberate boundary: any failure becomes an error response.
            return self.create_error_response(e)
685

686
    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message or an exception.

        When `message` is an exception, `err_type`, `status_code` and
        `param` are overridden according to the exception's type and the
        response message becomes `str(exc)`.

        Args:
            message: Error text, or the exception to report.
            err_type: Error type name (ignored when `message` is an exception).
            status_code: HTTP status (ignored when `message` is an exception).
            param: Offending parameter name, if any (ignored when `message`
                is an exception).

        Returns:
            The error response, with its message passed through
            `sanitize_message`.
        """
        exc: Exception | None = None

        if isinstance(message, Exception):
            exc = message

            if isinstance(exc, VLLMValidationError):
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = exc.parameter
            elif isinstance(exc, NotImplementedError):
                # Must be tested before RuntimeError: NotImplementedError is
                # a RuntimeError subclass and would otherwise be reported as
                # a 400 instead of a 501.
                err_type = "NotImplementedError"
                status_code = HTTPStatus.NOT_IMPLEMENTED
                param = None
            elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
                # Common validation errors from user input
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            elif exc.__class__.__name__ == "TemplateError":
                # jinja2.TemplateError (avoid importing jinja2)
                err_type = "BadRequestError"
                status_code = HTTPStatus.BAD_REQUEST
                param = None
            else:
                err_type = "InternalServerError"
                status_code = HTTPStatus.INTERNAL_SERVER_ERROR
                param = None

            message = str(exc)

        if self.log_error_stack:
            # Print the active exception's traceback when called from an
            # `except` block; otherwise print the current call stack.
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()

        return ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=err_type,
                code=status_code.value,
                param=param,
            )
        )
740

741
    def create_streaming_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> str:
        """Build an error response and serialize it to a JSON string.

        Takes the same arguments as `create_error_response`; the resulting
        model is dumped with `model_dump()` and encoded with `json.dumps`.
        """
        json_str = json.dumps(
            self.create_error_response(
                message=message,
                err_type=err_type,
                status_code=status_code,
                param=param,
            ).model_dump()
        )
        return json_str

758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
    def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
        """Log and raise `GenerationError` when `finish_reason` is "error".

        Any other finish reason (including `None`) is a no-op.
        """
        if finish_reason != "error":
            return

        logger.error(
            "Request %s failed with an internal error during generation",
            request_id,
        )
        raise GenerationError("Internal server error")

    def _convert_generation_error_to_response(
        self, e: GenerationError
    ) -> ErrorResponse:
        """Turn a `GenerationError` into an "InternalServerError" response.

        The response message is `str(e)` and its status is `e.status_code`.
        """
        error_text = str(e)
        http_status = e.status_code
        return self.create_error_response(
            error_text,
            err_type="InternalServerError",
            status_code=http_status,
        )

    def _convert_generation_error_to_streaming_response(
        self, e: GenerationError
    ) -> str:
        """Turn a `GenerationError` into a JSON-encoded streaming error.

        The message is `str(e)`, the type "InternalServerError" and the
        status `e.status_code`.
        """
        error_text = str(e)
        http_status = e.status_code
        return self.create_streaming_error_response(
            error_text,
            err_type="InternalServerError",
            status_code=http_status,
        )

787
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Verify that `request.model` names a servable model or LoRA adapter.

        Returns:
            `None` if the model is supported, is a registered LoRA adapter,
            or can be resolved as a LoRA adapter at runtime. Otherwise an
            error response: the resolver's own response when it reports a
            400, else a 404 "NotFoundError".
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        # Runtime resolution is attempted only when enabled by the
        # environment flag and a model name was actually given.
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                # Keep the resolver's 400 response; any other failure falls
                # through to the generic 404 below.
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
816

817
    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras.

        Returns:
            The single registered LoRA whose `lora_name` matches one of the
            request's message types, or `None` when zero or several match.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        default_mm_loras = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

840
    def _maybe_get_adapters(
841
842
843
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
844
    ) -> LoRARequest | None:
845
        if request.model in self.models.lora_requests:
846
            return self.models.lora_requests[request.model]
847
848
849
850
851
852

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
853
                return default_mm_lora
854
855

        if self._is_model_supported(request.model):
856
            return None
857

858
        # if _check_model has been called earlier, this will be unreachable
859
        raise ValueError(f"The model `{request.model}` does not exist.")
860

861
862
863
864
865
866
867
868
869
870
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

871
872
873
874
875
        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
876
877
878
879
880
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
881
882
883
884
885
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

886
887
    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Validate the prompt length and wrap it in a ``TokensPrompt``.

        Three request families are handled:

        * embedding / score / rerank / classification requests: the prompt
          may use the whole context window (``len(input_ids)`` must not
          exceed ``self.max_model_len``);
        * tokenize / detokenize requests: no length validation;
        * everything else (completion-like): the prompt must be strictly
          shorter than ``self.max_model_len`` and, when a max-token budget
          is given, prompt plus budget must fit in the window.

        Returns:
            ``TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)``.

        Raises:
            VLLMValidationError: when a length limit is violated.
        """
        token_num = len(input_ids)
        max_model_len = self.max_model_len

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        non_generating_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
            ClassificationCompletionRequest,
            ClassificationChatRequest,
        )
        if isinstance(request, non_generating_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > max_model_len:
                # Lookup is by exact type; anything not listed is reported
                # as "embedding generation".
                # NOTE(review): RerankRequest has no entry here, so it is
                # reported as "embedding generation" too — confirm intended.
                operations: dict[type[AnyRequest], str] = {
                    ScoreDataRequest: "score",
                    ScoreTextRequest: "score",
                    ScoreQueriesDocumentsRequest: "score",
                    ClassificationCompletionRequest: "classification",
                    ClassificationChatRequest: "classification",
                }
                operation = operations.get(type(request), "embedding generation")
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.",
                    parameter="input_tokens",
                    value=token_num,
                )
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenization_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenization_types):
            return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        exceeds_budget = (
            max_tokens is not None and token_num + max_tokens > max_model_len
        )
        if exceeds_budget:
            raise VLLMValidationError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {max_model_len}"
                f" - {token_num}).",
                parameter="max_tokens",
                value=max_tokens,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
969

970
971
    def _validate_chat_template(
        self,
972
973
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
974
        trust_request_chat_template: bool,
975
    ) -> ErrorResponse | None:
976
        if not trust_request_chat_template and (
977
978
979
980
981
982
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
983
984
985
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
986
987
                "Refused request with untrusted chat template."
            )
988
989
        return None

990
991
992
993
994
995
996
997
998
999
1000
1001
    @staticmethod
    def _prepare_extra_chat_template_kwargs(
        request_chat_template_kwargs: dict[str, Any] | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Helper to merge server-default and request-specific chat template kwargs."""
        request_chat_template_kwargs = request_chat_template_kwargs or {}
        if default_chat_template_kwargs is None:
            return request_chat_template_kwargs
        # Apply server defaults first, then request kwargs override.
        return default_chat_template_kwargs | request_chat_template_kwargs

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
    async def _preprocess_completion(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        renderer = self.renderer
        tok_params = request.build_tok_params(self.model_config)

        in_prompts = await renderer.render_completions_async(
            prompt_input, prompt_embeds
        )
        engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)

        extra_items = {
            k: v
            for k in ("mm_processor_kwargs", "cache_salt")
            if (v := getattr(request, k, None)) is not None
        }
        for prompt in engine_prompts:
            prompt.update(extra_items)  # type: ignore

        return engine_prompts

1026
1027
    async def _preprocess_chat(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
        """Render and tokenize chat messages into a single engine prompt.

        Returns:
            A ``(conversation, [engine_prompt])`` pair, where ``conversation``
            is what ``renderer.render_messages_async`` produced and the list
            holds the one tokenized prompt.

        Raises:
            NotImplementedError: if a tool parser is set, the request's
                ``tool_choice`` is not ``"none"``, and the request is neither
                a ``ChatCompletionRequest`` nor a ``ResponsesRequest``.
        """
        from vllm.tokenizers.mistral import MistralTokenizer

        renderer = self.renderer

        # The template kwargs always carry the tool definitions;
        # ``tokenize`` is switched on only for a Mistral tokenizer.
        uses_mistral_tokenizer = isinstance(renderer.tokenizer, MistralTokenizer)
        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            {"tools": tool_dicts, "tokenize": uses_mistral_tokenizer},
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(default_template_kwargs)

        conversation, rendered_prompt = await renderer.render_messages_async(
            messages, chat_params
        )
        engine_prompt = await renderer.tokenize_prompt_async(
            rendered_prompt, tok_params
        )

        # Forward optional per-request fields on the prompt when present.
        extra_items = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extra_items[key] = value
        engine_prompt.update(extra_items)  # type: ignore

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM).
        tools_requested = (
            tool_parser is not None
            and getattr(request, "tool_choice", "none") != "none"
        )
        if tools_requested:
            if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
                msg = (
                    "Tool usage is only supported for Chat Completions API "
                    "or Responses API requests."
                )
                raise NotImplementedError(msg)

            # TODO: Update adjust_request to accept ResponsesRequest
            # NOTE(review): the return value is not used after this point;
            # presumably adjust_request mutates the request in place — confirm.
            tokenizer = renderer.get_tokenizer()
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]
1083

1084
1085
1086
1087
1088
    async def _render_next_turn(
        self,
        request: ResponsesRequest,
        messages: list[ResponseInputOutputItem],
        tool_dicts: list[dict[str, Any]] | None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
    ):
        """Build the engine prompts for the next turn of a Responses request.

        ``messages`` are converted with ``construct_input_messages`` and run
        through ``_preprocess_chat`` with no default template kwargs; only
        the engine prompts from that call are returned.
        """
        next_turn_messages = construct_input_messages(request_input=messages)

        _conversation, next_turn_prompts = await self._preprocess_chat(
            request,
            next_turn_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
        return next_turn_prompts
1107

1108
1109
1110
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        engine_prompt: TokensPrompt | EmbedsPrompt,
        sampling_params: SamplingParams,
        tok_params: TokenizeParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
    ):
        """Generate, looping for as long as the context asks for builtin tools.

        This is an async generator. Each turn is submitted to the engine as a
        sub-request with id ``f"{request_id}_{n}"``; every engine output is
        appended to ``context``, which is yielded after each append. When the
        turn ends and ``context.need_builtin_tool_call()`` is true, the tool
        is called, its output appended to the context, a new prompt is
        rendered, and another turn starts. ``sampling_params.max_tokens`` is
        updated in place before each follow-up turn.
        """
        prompt_text, _, _ = get_prompt_components(engine_prompt)

        base_priority = priority
        turn_index = 0
        while True:
            # Ensure that each sub-request has a unique request id.
            sub_request_id = f"{request_id}_{turn_index}"

            self._log_inputs(
                sub_request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            tokenization_kwargs = tok_params.get_encode_kwargs()
            engine_request = self.input_processor.process_inputs(
                sub_request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            output_stream = self.engine_client.generate(
                engine_request,
                sampling_params,
                sub_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )

            async for engine_output in output_stream:
                context.append_output(engine_output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_tool_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids and update sampling_params.
            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                next_token_ids = context.render_for_completion()
                engine_prompt = TokensPrompt(prompt_token_ids=next_token_ids)

                # NOTE(review): prompt_text is not refreshed in this branch,
                # so the first turn's text is passed to generate() again —
                # confirm this is intended.
                sampling_params.max_tokens = self.max_model_len - len(
                    next_token_ids
                )
            elif isinstance(context, ParsableContext):
                next_turn_prompts = await self._render_next_turn(
                    context.request,
                    context.parser.response_messages,
                    context.tool_dicts,
                    context.tool_parser_cls,
                    context.chat_template,
                    context.chat_template_content_format,
                )
                engine_prompt = next_turn_prompts[0]
                prompt_text, _, _ = get_prompt_components(engine_prompt)

                sampling_params.max_tokens = get_max_tokens(
                    self.max_model_len,
                    context.request,
                    engine_prompt,
                    self.default_sampling_params,  # type: ignore
                )

            # OPTIMIZATION: every follow-up turn runs at the caller's
            # priority minus one (presumably a lower value is scheduled
            # sooner — confirm against the scheduler).
            priority = base_priority - 1
            turn_index += 1
1201

1202
1203
1204
    def _log_inputs(
        self,
        request_id: str,
1205
        inputs: PromptType,
1206
1207
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
1208
1209
1210
    ) -> None:
        if self.request_logger is None:
            return
1211

1212
        prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
1213
1214
1215
1216
1217

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1218
            prompt_embeds,
1219
1220
1221
            params=params,
            lora_request=lora_request,
        )
1222

1223
1224
1225
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to forward, or ``None``.

        When the engine client reports tracing as enabled, the result of
        ``extract_trace_headers(headers)`` is returned. Otherwise ``None`` is
        returned, after logging the tracing-disabled warning if the request
        nevertheless contains trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: warn when the client sent trace headers anyway.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1237
    @staticmethod
1238
    def _base_request_id(
1239
1240
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1241
        """Pulls the request id to use from a header, if provided"""
1242
1243
1244
1245
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id
1246

1247
        return random_uuid() if default is None else default
1248

1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
    @staticmethod
    def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
        """Pulls the data parallel rank from a header, if provided"""
        if raw_request is None:
            return None

        rank_str = raw_request.headers.get("X-data-parallel-rank")
        if rank_str is None:
            return None

        try:
            return int(rank_str)
        except ValueError:
            return None

1264
1265
1266
    @staticmethod
    def _parse_tool_calls_from_content(
        request: ResponsesRequest | ChatCompletionRequest,
        tokenizer: TokenizerLike | None,
        enable_auto_tools: bool,
        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
        content: str | None = None,
    ) -> tuple[list[FunctionCall] | None, str | None]:
        """Turn model output into function calls according to ``tool_choice``.

        Modes, checked in this order:

        * forced function (``ToolChoiceFunction`` or
          ``ChatCompletionNamedToolChoiceParam``): the whole ``content`` is
          the arguments of the named function;
        * ``"required"``: ``content`` is validated as a JSON list of
          ``FunctionDefinition`` and each entry becomes a call;
        * ``"auto"`` / ``None`` with a parser class and auto tools enabled:
          the tool parser extracts the calls from ``content``.

        Returns:
            ``(function_calls, remaining_content)``. In the forced and
            required modes the content is ``None``. In auto mode with no
            tool call found the result is ``(None, content)``. When no mode
            applies the result is ``([], content)``.

        Raises:
            ValueError: in auto mode when ``tokenizer`` is ``None``.
            RuntimeError: re-raised (after logging) if the tool parser cannot
                be constructed.
        """
        tool_choice = request.tool_choice

        # Forced function call (Responses API shape).
        if tool_choice and isinstance(tool_choice, ToolChoiceFunction):
            assert content is not None
            forced = FunctionCall(name=tool_choice.name, arguments=content)
            # Content is cleared since the tool is called.
            return [forced], None

        # Forced function call (Chat Completions API shape).
        if tool_choice and isinstance(
            tool_choice, ChatCompletionNamedToolChoiceParam
        ):
            assert content is not None
            forced = FunctionCall(name=tool_choice.function.name, arguments=content)
            # Content is cleared since the tool is called.
            return [forced], None

        if tool_choice == "required":
            assert content is not None
            definitions = TypeAdapter(list[FunctionDefinition]).validate_json(content)
            required_calls = [
                FunctionCall(
                    name=definition.name,
                    arguments=json.dumps(definition.parameters, ensure_ascii=False),
                )
                for definition in definitions
            ]
            # Content is cleared since the tool is called.
            return required_calls, None

        auto_mode = (
            tool_parser_cls
            and enable_auto_tools
            and (tool_choice == "auto" or tool_choice is None)
        )
        if not auto_mode:
            # No mode applies: no calls, content untouched.
            return [], content

        if tokenizer is None:
            raise ValueError(
                "Tokenizer not available when `skip_tokenizer_init=True`"
            )

        # Automatic Tool Call Parsing
        try:
            tool_parser = tool_parser_cls(tokenizer)
        except RuntimeError:
            logger.exception("Error in tool parser creation.")
            raise

        tool_call_info = tool_parser.extract_tool_calls(
            content if content is not None else "",
            request=request,  # type: ignore
        )
        if tool_call_info is None or not tool_call_info.tools_called:
            # No tool calls.
            return None, content

        # extract_tool_calls() returns a list of tool calls.
        parsed_calls = [
            FunctionCall(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments,
            )
            for tool_call in tool_call_info.tool_calls
        ]
        remaining = tool_call_info.content
        # Whitespace-only leftover text is reported as no content.
        if remaining and remaining.strip() == "":
            remaining = None

        return parsed_calls, remaining

1341
    @staticmethod
1342
1343
1344
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
1345
        tokenizer: TokenizerLike | None,
1346
1347
        return_as_token_id: bool = False,
    ) -> str:
1348
1349
1350
        if return_as_token_id:
            return f"token_id:{token_id}"

1351
1352
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1353
1354
1355
1356
1357
1358

        if tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )

1359
        return tokenizer.decode([token_id])
1360

1361
    def _is_model_supported(self, model_name: str | None) -> bool:
1362
1363
        if not model_name:
            return True
1364
        return self.models.is_base_model(model_name)
1365

1366
1367

def clamp_prompt_logprobs(
1368
1369
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
1370
1371
1372
1373
1374
1375
1376
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
1377
            if logprob_values.logprob == float("-inf"):
1378
1379
                logprob_values.logprob = -9999.0
    return prompt_logprobs