serving_engine.py 45.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import json
5
import sys
6
import time
7
import traceback
8
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
9
from concurrent.futures import ThreadPoolExecutor
10
from http import HTTPStatus
11
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
12

13
import torch
14
from fastapi import Request
15
from pydantic import BaseModel, ConfigDict, Field
16
from starlette.datastructures import Headers
17
18
from typing_extensions import TypeIs

19
20
21
22
23
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

24
import vllm.envs as envs
25
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
26
from vllm.engine.protocol import EngineClient
27
28
29
30
31
32
33
34
35
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
36
from vllm.entrypoints.context import ConversationContext
37
from vllm.entrypoints.logger import RequestLogger
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
    ResponsesRequest,
    ScoreRequest,
    ScoreResponse,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
65
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
66
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
67
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
68
from vllm.entrypoints.utils import _validate_truncation_size
69
from vllm.inputs.data import PromptType
70
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
71
72
73
74
75
from vllm.inputs.parse import (
    PromptComponents,
    get_prompt_components,
    is_explicit_encoder_decoder_prompt,
)
76
from vllm.logger import init_logger
77
from vllm.logprobs import Logprob, PromptLogprobs
78
from vllm.lora.request import LoRARequest
79
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
80
81
82
    MultiModalDataDict,
    MultiModalUUIDDict,
)
83
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
84
from vllm.pooling_params import PoolingParams
85
from vllm.reasoning import ReasoningParser, ReasoningParserManager
86
from vllm.sampling_params import BeamSearchParams, SamplingParams
87
88
89
90
91
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
92
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
93
94
from vllm.utils import (
    AsyncMicrobatchTokenizer,
95
    collect_from_async_generator,
96
97
98
99
100
    is_list_of,
    make_async,
    merge_async_iterators,
    random_uuid,
)
101
from vllm.v1.engine import EngineCoreRequest
102
103
104

logger = init_logger(__name__)

# Completion-style requests (grouping inferred from the type names).
CompletionLikeRequest: TypeAlias = (
    CompletionRequest
    | DetokenizeRequest
    | EmbeddingCompletionRequest
    | RerankRequest
    | ClassificationRequest
    | ScoreRequest
    | TokenizeCompletionRequest
)

# Chat-message-based requests.
ChatLikeRequest: TypeAlias = (
    ChatCompletionRequest
    | EmbeddingChatRequest
    | TokenizeChatRequest
)

# Audio transcription / translation requests.
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest

# Every request type handled by OpenAIServing; also the bound of `RequestT`
# used to parametrize ServeContext.
AnyRequest: TypeAlias = (
    CompletionLikeRequest
    | ChatLikeRequest
    | SpeechToTextRequest
    | ResponsesRequest
    | IOProcessorRequest
)

# Every successful response type; errors are typed separately as
# ErrorResponse throughout this module.
AnyResponse: TypeAlias = (
    CompletionResponse
    | ChatCompletionResponse
    | EmbeddingResponse
    | TranscriptionResponse
    | TokenizeResponse
    | PoolingResponse
    | ClassificationResponse
    | ScoreResponse
)


class TextTokensPrompt(TypedDict):
    prompt: str
141
    prompt_token_ids: list[int]
142
143


144
145
146
147
class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor


148
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
149
150
151


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True if `prompt` is a dict with token IDs and no embeddings."""
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" in prompt and "prompt_embeds" not in prompt


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True if `prompt` is a dict with embeddings and no token IDs."""
    if not isinstance(prompt, dict):
        return False
    return "prompt_embeds" in prompt and "prompt_token_ids" not in prompt


RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Use default factories (as ResponseGenerationMixin does for
    # `final_res_batch`) instead of bare mutable `[]` literals, so every
    # instance explicitly starts with its own fresh list.
    request_prompts: Sequence[RequestPrompt] | None = Field(default_factory=list)
    engine_prompts: list[EngineTokensPrompt] | None = Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation: holds the merged result stream and the
    collected per-prompt outputs.
    """

    # Merged stream of (prompt index, output) pairs; stays None until the
    # generators have been scheduled.
    result_generator: (
        AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
    ) = Field(default=None)
    # Outputs ordered by prompt index once the stream has been drained.
    final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
        default_factory=list,
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(
    RequestProcessingMixin,
    ResponseGenerationMixin,
    BaseModel,
    Generic[RequestT],
):
    # Per-request state threaded through the OpenAIServing pipeline
    # (_preprocess -> _prepare_generators -> _collect_batch -> _build_response).

    # Shared across all requests
    request: RequestT
    raw_request: Request | None = Field(default=None)
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds, captured when the context is built.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = Field(default=None)

    # Shared across most requests
    tokenizer: AnyTokenizer | None = Field(default=None)

    # An empty `protected_namespaces` silences Pydantic v2's warning about
    # fields (such as `model_name`) in the protected "model_" namespace.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )


ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    # Chat-template options carried alongside embedding requests.
    chat_template: str | None = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the models now that every referenced name is defined, so Pydantic
# can resolve forward references (e.g. MultiModalDataDict in TokensPrompt).
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()


class OpenAIServing:
240
241
242
243
    # A short string prepended to every request's ID (e.g. "embd", "classify")
    # so you can easily tell "this ID came from Embedding vs Classification."
    # The description used to be assigned as the attribute's *value*, which
    # would leak a multi-line string into request IDs for any subclass that
    # does not override it. NOTE(review): endpoint subclasses are presumably
    # expected to override this; confirm against the subclasses.
    request_id_prefix: ClassVar[str] = ""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the shared engine/model handles and endpoint options.

        Args:
            engine_client: Client used to submit generate/encode requests.
            models: Served-model registry; the processor, IO processor and
                model config are read from it.
            request_logger: Optional request logger (may be None).
            return_tokens_as_token_ids: Stored as an attribute.
            enable_force_include_usage: Stored as an attribute.
            log_error_stack: When True, `create_error_response` prints the
                current traceback or call stack.
        """
        super().__init__()

        # Engine and served-model registry handles.
        self.engine_client = engine_client
        self.models = models

        # Endpoint options.
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        # Dedicated single-worker pool for applying Mistral chat templates via
        # `make_async` (presumably to keep that work off the event loop).
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # AsyncMicrobatchTokenizer wrappers cached per tokenizer.
        self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        # Frequently used registry attributes, cached on the instance.
        self.processor = models.processor
        self.io_processor = models.io_processor
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len

    def _get_tool_parser(
        self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
    ) -> Callable[[AnyTokenizer], ToolParser] | None:
        """Get the tool parser based on the name.

        Returns None unless auto tool choice is enabled and a parser name is
        given. Otherwise logs that `parallel_tool_calls` is ignored and looks
        the parser up in `ToolParserManager`.

        Raises:
            TypeError: if the lookup fails (the parser is not registered).
        """
        if tool_parser_name is None or not enable_auto_tools:
            return None

        logger.info(
            '"auto" tool choice has been enabled please note that while'
            " the parallel_tool_calls client option is preset for "
            "compatibility reasons, it will be ignored."
        )

        try:
            # Warn about a known weakness before resolving the parser.
            if tool_parser_name == "pythonic" and self.model_config.model.startswith(
                "meta-llama/Llama-3.2"
            ):
                logger.warning(
                    "Llama3.2 models may struggle to emit valid pythonic tool calls"
                )
            return ToolParserManager.get_tool_parser(tool_parser_name)
        except Exception as e:
            raise TypeError(
                "Error: --enable-auto-tool-choice requires "
                f"tool_parser:'{tool_parser_name}' which has not "
                "been registered"
            ) from e

    def _get_reasoning_parser(
        self,
        reasoning_parser_name: str,
    ) -> Callable[[AnyTokenizer], ReasoningParser] | None:
        """Get the reasoning parser based on the name.

        Returns None when `reasoning_parser_name` is empty.

        Raises:
            TypeError: if the registry lookup fails or yields None.
        """
        if not reasoning_parser_name:
            return None
        try:
            parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        except Exception as e:
            raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
        # Explicit check rather than `assert`, which is stripped under
        # `python -O` and would let a None parser through silently.
        if parser is None:
            raise TypeError(f"{reasoning_parser_name=} has not been registered")
        return parser

    async def reset_mm_cache(self) -> None:
        """Clear the processor's multimodal cache, then the engine's."""
        processor = self.processor
        processor.clear_mm_cache()
        engine = self.engine_client
        await engine.reset_mm_cache()

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for `prompt` and yield one finished `RequestOutput`.

        Each step submits one single-token generation request per live beam,
        requesting `2 * beam_width` logprobs. Each beam is extended by every
        returned candidate token, and the `beam_width` best extensions are
        kept according to
        `create_sort_beams_key_function(eos_token_id, length_penalty)`.

        A candidate equal to the EOS token (unless `ignore_eos`) finishes its
        beam with `finish_reason="stop"`. Search ends after `max_tokens`
        steps, or earlier once no live beams remain. The best `beam_width`
        sequences are then decoded and returned as the outputs of a single
        `RequestOutput`.

        Args:
            prompt: A text prompt, or a prompt dict whose `prompt`,
                `prompt_token_ids` and `multi_modal_data` keys are read.
            request_id: Base ID; per-beam engine request IDs derive from it.
            params: Beam search parameters.
            lora_request: Optional LoRA adapter applied to the initial beam.

        Raises:
            ValueError: if the processor has no tokenizer.
            NotImplementedError: for explicit encoder/decoder prompts and
                embedding prompts.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        processor = self.processor
        tokenizer = processor.tokenizer
        if tokenizer is None:
            raise ValueError(
                "You cannot use beam search when `skip_tokenizer_init` is True"
            )

        eos_token_id: int = tokenizer.eos_token_id  # type: ignore

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError
        processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(prompt)

        if processed_inputs["type"] == "embeds":
            raise NotImplementedError

        # This is a workaround to fix multimodal beam search; this is a
        # bandaid fix for 2 small problems:
        # 1. Multi_modal_data on the processed_inputs currently resolves to
        #    `None`.
        # 2. preprocessing above expands the multimodal placeholders. However,
        #    this happens again in generation, so the double expansion causes
        #    a mismatch.
        # TODO - would be ideal to handle this more gracefully.
        prompt_text: str | None
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalDataDict | None
        if isinstance(prompt, str):
            prompt_text = prompt
            # A bare string carries no multimodal data, so the token IDs from
            # preprocessing are safe to reuse (the double-expansion problem
            # above only concerns multimodal placeholders). Starting from an
            # empty token list would run beam search without the prompt.
            prompt_token_ids = list(
                processed_inputs.get("prompt_token_ids", [])  # type: ignore
            )
            multi_modal_data = None
        else:
            prompt_text = prompt.get("prompt")  # type: ignore
            prompt_token_ids = prompt.get("prompt_token_ids", [])  # type: ignore
            multi_modal_data = prompt.get("multi_modal_data")  # type: ignore

        mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get(
            "mm_processor_kwargs"
        )  # type: ignore

        # Prompt tokens are sliced off every beam before decoding.
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        # One token per step, with 2 * beam_width candidate logprobs.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                multi_modal_data=multi_modal_data,
                mm_processor_kwargs=mm_processor_kwargs,
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            # Every candidate of the previous step hit EOS: nothing is left to
            # extend, and unpacking `zip(*[])` below would raise ValueError.
            if not all_beams:
                break

            prompts_batch, lora_req_batch = zip(
                *[
                    (
                        EngineTokensPrompt(
                            prompt_token_ids=beam.tokens,
                            multi_modal_data=beam.multi_modal_data,
                            mm_processor_kwargs=beam.mm_processor_kwargs,
                        ),
                        beam.lora_request,
                    )
                    for beam in all_beams
                ]
            )

            tasks = []
            request_id_batch = f"{request_id}-{random_uuid()}"

            for i, (individual_prompt, lora_req) in enumerate(
                zip(prompts_batch, lora_req_batch)
            ):
                request_id_item = f"{request_id_batch}-beam-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            individual_prompt,
                            beam_search_params,
                            request_id_item,
                            lora_request=lora_req,
                        )
                    )
                )
                tasks.append(task)

            # First (and only) output of each single-token request.
            output = [x[0] for x in await asyncio.gather(*tasks)]

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        if token_id == eos_token_id and not ignore_eos:
                            completed.append(
                                BeamSearchSequence(
                                    tokens=(
                                        current_beam.tokens + [token_id]
                                        if include_stop_str_in_output
                                        else current_beam.tokens
                                    ),
                                    logprobs=current_beam.logprobs + [logprobs],
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    finish_reason="stop",
                                    stop_reason=eos_token_id,
                                )
                            )
                        else:
                            new_beams.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )
                            )

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            # `beam.tokens` may be empty (empty prompt, nothing generated), so
            # guard before peeking at the last token.
            if beam.tokens and beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )

    def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
        """Build a `CompletionRenderer` for `tokenizer`.

        The renderer receives this server's model config and the shared
        async tokenizer pool (the same dict is passed on every call).
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer

    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the `RenderConfig` the renderer should use for `request`.

        The config controls how prompts are prepared (for example,
        tokenization and length handling). Each endpoint must override this
        with logic suited to its own request type.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Return the `AsyncMicrobatchTokenizer` wrapping `tokenizer`.

        The wrapper is created on first use and cached in
        `self._async_tokenizer_pool` for subsequent calls.
        """
        pool = self._async_tokenizer_pool
        if (cached := pool.get(tokenizer)) is not None:
            return cached
        created = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = created
        return created

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Hook that prepares `ctx` before generators are scheduled.

        This base implementation does nothing and reports no error. Endpoints
        such as classification or embedding override it to fill in `ctx`.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Turn the collected results in `ctx` into the endpoint's response.

        This base implementation has no endpoint-specific logic, so it returns
        an "unimplemented endpoint" error. Subclasses override it.
        """
        error = self.create_error_response("unimplemented endpoint")
        return error

    async def handle(
        self,
        ctx: ServeContext,
    ) -> AnyResponse | ErrorResponse:
        """Run the processing pipeline for `ctx` and return its first response.

        The pipeline generator is closed deterministically with
        `contextlib.aclosing`. Previously it was abandoned mid-iteration and
        left to asyncio's async-generator finalizer.

        Returns:
            The first response yielded by `_pipeline`. If the pipeline yields
            nothing, an error response is returned instead.
        """
        from contextlib import aclosing

        generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
        generation = self._pipeline(ctx)

        async with aclosing(generation):
            async for response in generation:
                return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
        """Execute the request processing pipeline yielding responses.

        The stages run in this order:
        1. model check
        2. request validation
        3. preprocessing
        4. generator preparation
        5. batch collection
        6. response building

        The first stage that reports an `ErrorResponse` yields it and ends the
        generator. Later stages therefore never run on a request that has
        already failed, even if a caller keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
        """Reject requests whose `truncate_prompt_tokens` exceeds `max_model_len`.

        Requests without that attribute, or with it set to None, pass.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None:
            return None
        if limit > self.max_model_len:
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Derive `PoolingParams` from `ctx.request`.

        Returns an error response when the request type has no
        `to_pooling_params()` method.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()
        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Schedule the request and get the result generator.

        Submits one `engine_client.encode` call per entry of
        `ctx.engine_prompts`, using request IDs `"{ctx.request_id}-{i}"`.
        The resulting streams are merged into `ctx.result_generator`, and the
        method returns None. Any exception is converted into an
        `ErrorResponse`.
        """
        try:
            raw_request = ctx.raw_request
            if raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            generators: list[
                AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
            ] = []
            for index, engine_prompt in enumerate(engine_prompts):
                item_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    item_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generators.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*generators)
            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator`, which yields `(index, output)` pairs,
        and places each output at its prompt index. The ordered list is stored
        in `ctx.final_res_batch`.

        Returns:
            None on success. An `ErrorResponse` if the prompts or the
            generator are missing, if any prompt produced no output, or if
            iteration raises.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity check: `None in final_res_batch` would fall back to each
            # output object's `__eq__` for every non-None slot.
            if any(res is None for res in final_res_batch):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            # Every slot is filled at this point; the filter narrows the type.
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from `message`, `err_type` and `status_code`.

        The response carries the numeric value of `status_code`. When
        `log_error_stack` is enabled, the active exception's traceback is
        printed if one is being handled; otherwise the current call stack is
        printed.
        """
        if self.log_error_stack:
            if sys.exc_info()[0] is None:
                traceback.print_stack()
            else:
                traceback.print_exc()
        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)

    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return the JSON text of `create_error_response(...)` for streaming."""
        error = self.create_error_response(
            message=message, err_type=err_type, status_code=status_code
        )
        return json.dumps(error.model_dump())

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> ErrorResponse | None:
        """Return None if `request.model` can be served, else an error response.

        A model can be served when it is one of the following:
        - a supported base model;
        - an already registered LoRA adapter;
        - with `VLLM_ALLOW_RUNTIME_LORA_UPDATING` set, an adapter that
          `resolve_lora` loads.

        A 400 error returned by `resolve_lora` is passed through as-is. Every
        other failure becomes a 404 NotFoundError.
        """
        model = request.model
        if self._is_model_supported(model) or model in self.models.lora_requests:
            return None

        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and model
            and (load_result := await self.models.resolve_lora(model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                # Surface the resolver's own bad-request error.
                return load_result

        return self.create_error_response(
            message=f"The model `{model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Return the default multimodal LoRA implied by `request`, if unique.

        A registered adapter matches when its `lora_name` equals one of the
        content-type prefixes from `_get_message_types` (e.g. `audio_url`
        becomes `audio`). An adapter is returned only when exactly one
        adapter matches; otherwise the result is None.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best-effort match; there is probably a better way to do this.
        matches = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Only a single unambiguous match selects a default adapter.
        if len(matches) != 1:
            return None
        (only_match,) = matches
        return only_match

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        """Resolve the LoRA adapter (if any) to use for `request`.

        Resolution order:
        1. `request.model` names a registered adapter: return that adapter.
        2. `supports_default_mm_loras` is set and exactly one default
           modality-specific adapter matches the request: return it.
        3. `request.model` is a supported base model: return None.

        Raises:
            ValueError: if none of the above apply. This is not expected when
                `_check_model` has already run.
        """
        lora_requests = self.models.lora_requests
        if request.model in lora_requests:
            return lora_requests[request.model]

        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")
        return None

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the content-type prefixes (text before the first `_`) of
        every dict message whose `content` is a list.

        The prefixes are used to match multimodal inputs against default
        per-modality LoRAs. Requests without `messages` yield an empty set.
        """
        seen: set[str] = set()
        if not hasattr(request, "messages"):
            return seen

        for message in request.messages:
            if not isinstance(message, dict) or "content" not in message:
                continue
            parts = message["content"]
            if not isinstance(parts, list):
                continue
            seen.update(
                part["type"].split("_")[0] for part in parts if "type" in part
            )
        return seen

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate it against the request.

        If the model's encoder config sets ``do_lower_case``, the prompt is
        lowercased first. Truncation follows ``request.truncate_prompt_tokens``
        when the request has that attribute:

        * ``None`` -> no truncation;
        * negative -> truncate to ``self.max_model_len`` tokens;
        * otherwise -> truncate to that many tokens.

        The (possibly lowercased) text and the token ids are then passed to
        ``self._validate_input``.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        truncate_to = getattr(request, "truncate_prompt_tokens", None)
        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        if truncate_to is not None:
            tokenizer_kwargs["truncation"] = True
            # A negative value means "cap at the model's max length".
            tokenizer_kwargs["max_length"] = (
                self.max_model_len if truncate_to < 0 else truncate_to
            )

        encoded = await async_tokenizer(prompt, **tokenizer_kwargs)
        return self._validate_input(request, encoded.input_ids, prompt)

    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: AnyTokenizer | None,
    ) -> TextTokensPrompt:
        """Truncate pre-tokenized prompt ids, decode them, and validate.

        Truncation follows ``request.truncate_prompt_tokens`` (if present) and
        always keeps the *last* tokens of the prompt:

        * ``None`` -> keep every token;
        * negative -> keep the last ``self.max_model_len`` tokens;
        * ``0`` -> keep no tokens;
        * positive ``k`` -> keep the last ``k`` tokens (all of them if the
          prompt is shorter).

        The text passed to ``self._validate_input`` is the decoded truncated
        ids, or ``""`` when no tokenizer is available.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            keep = (
                self.max_model_len
                if truncate_prompt_tokens < 0
                else truncate_prompt_tokens
            )
            # Slice from an explicit start index: `prompt_ids[-keep:]` would
            # return the whole list when keep == 0, because -0 == 0.
            input_ids = prompt_ids[max(len(prompt_ids) - keep, 0) :]

        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model context and wrap it.

        * Embedding, score, rerank and classification requests generate no
          tokens, so the prompt may fill the whole context
          (``<= self.max_model_len`` tokens).
        * Tokenize/detokenize requests are not length-checked.
        * All other requests must leave room for at least one generated
          token (``< self.max_model_len``), and the input length plus
          ``max_tokens`` (``max_completion_tokens`` first for chat
          completions) must not exceed the context.

        Raises:
            ValueError: if a length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest, ScoreRequest and
        # RerankRequest don't have max_tokens.
        if isinstance(
            request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreRequest,
                RerankRequest,
                ClassificationRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > self.max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    RerankRequest: "rerank",
                    ClassificationRequest: "classification",
                }
                operation = operations.get(type(request), "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation.
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _tokenize_prompt_input_async(
970
971
972
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
973
        prompt_input: str | list[int],
974
975
976
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
977
        A simpler implementation that tokenizes a single prompt input.
978
        """
979
        async for result in self._tokenize_prompt_inputs_async(
980
981
            request,
            tokenizer,
982
            [prompt_input],
983
            add_special_tokens=add_special_tokens,
984
985
986
        ):
            return result
        raise ValueError("No results yielded from tokenization")
987

988
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[str | list[int]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """Tokenize prompt inputs one at a time, yielding each result.

        Strings go through ``_normalize_prompt_text_to_input``. Every other
        input is treated as pre-tokenized ids and goes through
        ``_normalize_prompt_tokens_to_input``, which does not take
        ``add_special_tokens``.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                normalized = self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt_input,
                    tokenizer=tokenizer,
                )
            else:
                normalized = self._normalize_prompt_text_to_input(
                    request,
                    prompt=prompt_input,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            yield await normalized

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Reject request-supplied chat templates unless they are trusted.

        A template counts as supplied if ``request_chat_template`` is set or
        ``chat_template_kwargs`` has a non-``None`` ``"chat_template"`` entry.

        Returns:
            An error response when a template was supplied and
            ``trust_request_chat_template`` is false; otherwise ``None``.
        """
        if trust_request_chat_template:
            return None

        kwargs_template = (
            chat_template_kwargs.get("chat_template") if chat_template_kwargs else None
        )
        if request_chat_template is None and kwargs_template is None:
            return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest | ResponsesRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: list[dict[str, Any]] | None = None,
        documents: list[dict[str, str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a prompt and build the engine input.

        Steps:

        1. Parse ``messages`` into a conversation. Multimodal data comes back
           as an awaitable and is awaited after template rendering.
        2. Apply the chat template. The Mistral path is used for a
           ``MistralTokenizer`` and the HF path otherwise; without a tokenizer
           the prompt is the literal ``"placeholder"``. Entries in
           ``chat_template_kwargs`` override the defaults built from the
           explicit arguments.
        3. If a tool parser is set and ``request.tool_choice`` is not
           ``"none"``, let the parser adjust the (chat completion) request.
        4. Tokenize the rendered prompt, then attach multimodal data/uuids,
           ``mm_processor_kwargs`` and ``cache_salt`` to the engine prompt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: if tool parsing applies to a request that is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied template kwargs take precedence over the defaults.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: str | list[int]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request
            )

        if tokenizer is None:
            # A single string message (not a tuple of fragments) so the
            # AssertionError reads as intended.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string when the tokenizer is not initialised"
            )
            # Nothing can be tokenized here, so a one-token stand-in ([1]) is
            # used as the prompt token ids.
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        ``params.truncate_prompt_tokens`` is checked against
        ``self.max_model_len`` by ``_validate_truncation_size``. That helper
        receives a fresh tokenization-kwargs dict (which it may fill in). The
        same dict is forwarded to ``self.processor.process_inputs`` and
        returned alongside the resulting engine request.
        """
        tok_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, tok_kwargs
        )

        return (
            self.processor.process_inputs(
                request_id,
                engine_prompt,
                params,
                lora_request=lora_request,
                tokenization_kwargs=tok_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            ),
            tok_kwargs,
        )

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate outputs, running built-in tool calls between turns.

        Each turn logs and processes the current prompt, then streams engine
        outputs into ``context``, yielding ``context`` after every output.
        Generation ends when ``context.need_builtin_tool_call()`` is false.
        Otherwise the tool is called, its output is appended to ``context``,
        and the next prompt is rendered from ``context``.

        Follow-up turns run with priority ``priority - 1``. Before each one,
        ``sampling_params.max_tokens`` is set to the context space left after
        the rendered prompt, so ``sampling_params`` is mutated in place.
        ``kwargs`` are forwarded to ``engine_client.generate``; its
        ``"trace_headers"`` entry, if any, is also passed to
        ``_process_inputs``.
        """
        # NOTE(review): prompt_text is derived once from the initial prompt
        # and reused on follow-up turns, even though the prompt is replaced
        # by token ids rendered from the context. Confirm this is intended.
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        trace_headers = kwargs.get("trace_headers")
        follow_up_priority = priority - 1
        turn_priority = priority

        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=turn_priority,
            )

            async for output in self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=turn_priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            ):
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Run the requested tool and record its output in the context.
            context.append_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Render the next turn's prompt token ids from the context.
            prompt_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
            request_prompt = prompt_token_ids
            sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
            # OPTIMIZATION: follow-up turns use a lower priority value
            # (presumably scheduled sooner; confirm scheduler semantics).
            turn_priority = follow_up_priority

    def _get_prompt_components(
        self,
        prompt: RequestPrompt | PromptType,
    ) -> PromptComponents:
        """Split a prompt into its text / token-id / embedding components.

        A bare list is taken to be token ids. Anything else is delegated to
        ``get_prompt_components``.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]
        return PromptComponents(token_ids=prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt | PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Send a request's prompt components and parameters to the logger.

        This is a no-op when ``self.request_logger`` is ``None``.
        """
        request_logger = self.request_logger
        if request_logger is not None:
            text, token_ids, embeds = self._get_prompt_components(inputs)
            request_logger.log_inputs(
                request_id,
                text,
                token_ids,
                embeds,
                params=params,
                lora_request=lora_request,
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Return the trace headers to propagate for this request, if any.

        When the engine reports tracing as enabled, the result of
        ``extract_trace_headers(headers)`` is returned. Otherwise ``None`` is
        returned. In that case, if ``headers`` still contain trace headers,
        ``log_tracing_disabled_warning()`` is called.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

    @staticmethod
1294
    def _base_request_id(
1295
1296
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
1297
1298
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
1299
1300
1301
1302
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
1303

1304
    @staticmethod
1305
1306
1307
1308
1309
1310
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
1311
1312
1313
        if return_as_token_id:
            return f"token_id:{token_id}"

1314
1315
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1316
        return tokenizer.decode(token_id)
1317

1318
    def _is_model_supported(self, model_name: str | None) -> bool:
1319
1320
        if not model_name:
            return True
1321
        return self.models.is_base_model(model_name)
1322

1323
1324

def clamp_prompt_logprobs(
    prompt_logprobs: PromptLogprobs | None,
) -> PromptLogprobs | None:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    ``prompt_logprobs`` is iterated position by position. ``None`` positions
    are skipped. For every other position, each value's ``.logprob`` equal
    to ``-inf`` is overwritten with ``-9999.0``. The same object is returned,
    and ``None`` is passed through unchanged.
    """
    if prompt_logprobs is not None:
        neg_inf = float("-inf")
        for position in prompt_logprobs:
            if position is None:
                continue
            for entry in position.values():
                if entry.logprob == neg_inf:
                    entry.logprob = -9999.0
    return prompt_logprobs