serving_engine.py 44.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import io
5
import json
6
import sys
7
import time
8
import traceback
9
10
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
11
from http import HTTPStatus
12
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
13
                    TypeVar, Union, cast, overload)
14

15
import pybase64
16
import torch
17
from fastapi import Request
18
from pydantic import BaseModel, ConfigDict, Field
19
from starlette.datastructures import Headers
20
21
from typing_extensions import TypeIs

22
23
24
25
26
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

27
import vllm.envs as envs
28
from vllm.config import ModelConfig
29
from vllm.engine.protocol import EngineClient
30
31
# yapf conflicts with isort for this block
# yapf: disable
32
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
33
                                         ChatTemplateContentFormatOption,
34
35
36
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
37
38
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
39
from vllm.entrypoints.context import ConversationContext
40
from vllm.entrypoints.logger import RequestLogger
41
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
42
43
44
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
45
                                              CompletionRequest,
46
                                              CompletionResponse,
47
                                              DetokenizeRequest,
48
49
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
50
                                              EmbeddingRequest,
51
                                              EmbeddingResponse, ErrorInfo,
52
53
54
55
56
                                              ErrorResponse,
                                              IOProcessorRequest,
                                              PoolingResponse, RerankRequest,
                                              ResponsesRequest, ScoreRequest,
                                              ScoreResponse,
57
                                              TokenizeChatRequest,
58
                                              TokenizeCompletionRequest,
59
60
                                              TokenizeResponse,
                                              TranscriptionRequest,
61
62
                                              TranscriptionResponse,
                                              TranslationRequest)
63
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
64
from vllm.entrypoints.openai.tool_parsers import ToolParser
65
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
66
# yapf: enable
67
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
68
from vllm.inputs.data import PromptType
69
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
70
from vllm.inputs.parse import parse_and_batch_prompt
71
from vllm.logger import init_logger
72
from vllm.logprobs import Logprob, PromptLogprobs
73
from vllm.lora.request import LoRARequest
74
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
75
    MultiModalDataDict, MultiModalUUIDDict)
76
from vllm.outputs import PoolingRequestOutput, RequestOutput
77
from vllm.pooling_params import PoolingParams
78
from vllm.sampling_params import BeamSearchParams, SamplingParams
79
80
81
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
82
83
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)
84
85
86

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

87
88
89
90
91
92
93
94
95
# Request types grouped as "completion-like" by this module.
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]

# Request types grouped as "chat-like" by this module.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Transcription and translation requests.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Every request type accepted by the helpers in this module.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
    IOProcessorRequest,
]

# Every non-error response type returned by the generic pipeline.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

119
120
121

class TextTokensPrompt(TypedDict):
    """Prompt that carries both its text and its token IDs."""

    # Prompt text.
    prompt: str
    # Token IDs for the prompt.
    prompt_token_ids: list[int]
123
124


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class EmbedsPrompt(TypedDict):
    """Prompt supplied as a tensor of embeddings instead of text/tokens."""

    # Embeddings tensor for the prompt.
    prompt_embeds: torch.Tensor


# A request prompt is one of: a list of token IDs, a text string, a
# text+tokens dict, or an embeddings dict.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when `prompt` is a dict with token IDs and no embeddings."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when `prompt` is a dict with embeddings and no token IDs."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_token_ids

141

142
143
144
145
146
# Type variable for the concrete request type carried by a `ServeContext`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Prompts as prepared for the request.
    request_prompts: Optional[Sequence[RequestPrompt]] = []
    # Prompts in the form passed to the engine; `_prepare_generators`
    # iterates over these.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = []

    # Allow field types that Pydantic has no built-in validator for.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Async stream of `(prompt_index, output)` pairs; set by
    # `_prepare_generators` and drained by `_collect_batch`.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # One output per engine prompt; filled in by `_collect_batch`.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    # Allow field types that Pydantic has no built-in validator for.
    model_config = ConfigDict(arbitrary_types_allowed=True)


172
173
174
175
176
177
class ServeContext(
        RequestProcessingMixin,
        ResponseGenerationMixin,
        BaseModel,
        Generic[RequestT],
):
    """
    Per-request state passed through the `OpenAIServing` pipeline.

    Combines the request-side fields of `RequestProcessingMixin` and the
    response-side fields of `ResponseGenerationMixin` with the request
    object and its metadata.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Unix timestamp (whole seconds) taken when the context is created.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# `ServeContext` specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, plus chat-template settings."""

    # Optional chat template; defaults to None.
    chat_template: Optional[str] = None
    # Required: no default is provided for the content format option.
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the Pydantic models now that every referenced type is defined.
# This resolves the Pydantic error caused by the forward reference to
# MultiModalDataDict inside TokensPrompt.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

212

213
class OpenAIServing:
214
215
216
217
    # Class-level prefix for request IDs. The value assigned here is a
    # descriptive placeholder string, not a usable prefix.
    # NOTE(review): presumably each endpoint subclass overrides this with a
    # short tag such as "embd" or "classify" — confirm in the subclasses.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
218

219
220
    def __init__(
        self,
221
        engine_client: EngineClient,
222
        model_config: ModelConfig,
223
        models: OpenAIServingModels,
224
225
        *,
        request_logger: Optional[RequestLogger],
226
        return_tokens_as_token_ids: bool = False,
227
        enable_force_include_usage: bool = False,
228
        log_error_stack: bool = False,
229
    ):
230
231
        super().__init__()

232
        self.engine_client = engine_client
233
        self.model_config = model_config
234
235
        self.max_model_len = model_config.max_model_len

236
        self.models = models
237

238
        self.request_logger = request_logger
239
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
240
        self.enable_force_include_usage = enable_force_include_usage
241

242
243
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

244
245
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}
246
        self.log_error_stack = log_error_stack
247

248
249
250
251
252
253
254
255
256
257
    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
        """
        Build a `CompletionRenderer` for `tokenizer`.

        The renderer receives this server's model config and the shared
        async tokenizer pool, so tokenizer wrappers are reused.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer

258
259
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
260
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
261
262
263
264
265
266
267
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
268

269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                         None)

329
330
        if (truncate_prompt_tokens is not None
                and truncate_prompt_tokens > self.max_model_len):
331
332
333
334
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size.")
335
336
        return None

337
338
339
340
341
342
343
344
345
346
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
                "Request type does not support pooling parameters")

        return ctx.request.to_pooling_params()

347
348
349
350
351
352
353
354
355
356
357
358
359
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Schedule the request and get the result generator.

        Starts one `engine_client.encode` generator per engine prompt
        (request IDs are `"<request_id>-<index>"`) and stores the merged
        stream in `ctx.result_generator`.

        Returns None on success, or an error response when pooling
        parameters or engine prompts are unavailable, or when any step
        raises.
        """
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                # Mypy has an existing bug related to inferring the variance of
                # TypedDicts with `builtins.enumerate`:
                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
                engine_prompt = cast(
                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
                    engine_prompt)

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    # Requests without a `priority` attribute default to 0.
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[Optional[Union[RequestOutput,
                                                 PoolingRequestOutput]]]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [
                res for res in final_res_batch if res is not None
            ]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

439
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an `ErrorResponse` from a message, error type and HTTP status.

        When `log_error_stack` is enabled, the traceback of the exception
        currently being handled is printed; if no exception is active,
        the current call stack is printed instead.
        """
        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(error=ErrorInfo(
            message=message, type=err_type, code=status_code.value))
453

454
    def create_streaming_error_response(
455
456
457
458
459
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
460
        json_str = json.dumps(
461
462
            self.create_error_response(message=message,
                                       err_type=err_type,
463
                                       status_code=status_code).model_dump())
464
465
        return json_str

466
    async def _check_model(
467
468
        self,
        request: AnyRequest,
469
    ) -> Optional[ErrorResponse]:
470
471
        error_response = None

472
        if self._is_model_supported(request.model):
473
            return None
474
        if request.model in self.models.lora_requests:
475
            return None
476
477
        if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and
            (load_result := await self.models.resolve_lora(request.model))):
478
479
            if isinstance(load_result, LoRARequest):
                return None
480
481
            if (isinstance(load_result, ErrorResponse) and
                    load_result.error.code == HTTPStatus.BAD_REQUEST.value):
482
483
484
                error_response = load_result

        return error_response or self.create_error_response(
485
486
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
487
488
            status_code=HTTPStatus.NOT_FOUND,
        )
489

490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

514
    def _maybe_get_adapters(
515
516
517
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
518
    ) -> Optional[LoRARequest]:
519
        if request.model in self.models.lora_requests:
520
            return self.models.lora_requests[request.model]
521
522
523
524
525
526

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
527
                return default_mm_lora
528
529

        if self._is_model_supported(request.model):
530
            return None
531

532
        # if _check_model has been called earlier, this will be unreachable
533
        raise ValueError(f"The model `{request.model}` does not exist.")
534

535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
            if (isinstance(message, dict) and "content" in message
                    and isinstance(message["content"], list)):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

553
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """
        Tokenize a text prompt and validate the resulting length.

        The prompt is lower-cased first when the model's encoder config
        sets `do_lower_case`. Truncation follows the request's
        `truncate_prompt_tokens`: None disables it, a negative value caps
        at `max_model_len`, and any other value is used as the limit.

        Raises:
            ValueError: Propagated from `_validate_input` when the
                tokenized prompt does not fit the model context.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                         None)

        tokenizer_kwargs: dict[str, Any] = {
            "add_special_tokens": add_special_tokens,
        }
        if truncate_prompt_tokens is not None:
            tokenizer_kwargs["truncation"] = True
            # Negative means we cap at the model's max length
            tokenizer_kwargs["max_length"] = (self.max_model_len
                                              if truncate_prompt_tokens < 0
                                              else truncate_prompt_tokens)

        encoded = await async_tokenizer(prompt, **tokenizer_kwargs)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

594
    async def _normalize_prompt_tokens_to_input(
595
596
        self,
        request: AnyRequest,
597
        prompt_ids: list[int],
598
        tokenizer: Optional[AnyTokenizer],
599
    ) -> TextTokensPrompt:
600
601
602
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                         None)

603
        if truncate_prompt_tokens is None:
604
            input_ids = prompt_ids
605
606
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
607
608
609
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

610
611
612
613
614
        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
615

616
617
618
619
620
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """
        Check the token count against the model context and wrap the
        prompt in a `TextTokensPrompt`.

        Rules by request type:
        - Embedding, score, rerank and classification requests may use up
          to `max_model_len` input tokens.
        - Tokenize/detokenize requests are not length-checked.
        - All other requests need fewer than `max_model_len` input tokens,
          and input tokens plus the requested `max_tokens` must not
          exceed `max_model_len`.

        Raises:
            ValueError: If any of the length limits above is violated.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(
                request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreRequest,
                RerankRequest,
                ClassificationRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > self.max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification",
                }
                # Any request type not listed above is described as
                # "embedding generation" in the message.
                operation = operations.get(type(request),
                                           "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(
                request,
            (TokenizeCompletionRequest, TokenizeChatRequest,
             DetokenizeRequest),
        ):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.")

        if (max_tokens is not None
                and token_num + max_tokens > self.max_model_len):
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num}).")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

692
    async def _tokenize_prompt_input_async(
693
694
695
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
696
        prompt_input: Union[str, list[int]],
697
698
699
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
700
701
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
702
703
        that assumes single input.
        """
704
        async for result in self._tokenize_prompt_inputs_async(
705
706
                request,
                tokenizer,
707
            [prompt_input],
708
                add_special_tokens=add_special_tokens,
709
710
711
        ):
            return result
        raise ValueError("No results yielded from tokenization")
712

713
    async def _tokenize_prompt_inputs_async(
714
715
716
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
717
        prompt_inputs: Iterable[Union[str, list[int]]],
718
        add_special_tokens: bool = True,
719
    ) -> AsyncGenerator[TextTokensPrompt, None]:
720
        """
721
722
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
723
724
        that assumes multiple inputs.
        """
725
726
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
727
                yield await self._normalize_prompt_text_to_input(
728
                    request,
729
730
                    prompt=prompt,
                    tokenizer=tokenizer,
731
732
733
                    add_special_tokens=add_special_tokens,
                )
            else:
734
                yield await self._normalize_prompt_tokens_to_input(
735
                    request,
736
737
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
738
739
                )

740
    async def _tokenize_prompt_input_or_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: Optional[AnyTokenizer],
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

        According to the `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        Returns:
            A `(text_prompts, embeds_prompts)` pair. Embedding prompts are
            loaded only for a `CompletionRequest` that carries
            `prompt_embeds`.
        """
        inputs_embeds = list[EmbedsPrompt]()
        inputs_text = list[TextTokensPrompt]()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                         None)

        # A negative value means "cap at the model's max length".
        if (truncate_prompt_tokens or 0) < 0:
            truncate_prompt_tokens = self.max_model_len

        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            inputs_embeds.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None or (inputs_embeds
                                       and input_or_inputs == ""):
            return [], inputs_embeds

        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(input_or_inputs)

        # Process each input in the batch concurrently
        tasks = []
        for prompt_input in batch_inputs:
            if prompt_input["is_tokens"] is False:
                assert tokenizer is not None, (
                    "Tokenizer is required for text prompts")
                task = self._normalize_prompt_text_to_input(
                    request,
                    prompt_input["content"],
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            else:
                task = self._normalize_prompt_tokens_to_input(
                    request, prompt_input["content"], tokenizer=tokenizer)
            tasks.append(task)

        # Wait for all tokenization tasks to complete
        results = await asyncio.gather(*tasks)
        inputs_text.extend(results)

        return inputs_text, inputs_embeds
805

806
    @overload
    async def _preprocess_completion(
        self,
        request: Union[
            DetokenizeRequest,
            EmbeddingCompletionRequest,
            RerankRequest,
            ClassificationRequest,
            ScoreRequest,
            TokenizeCompletionRequest,
        ],
        tokenizer: Optional[AnyTokenizer],
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        ...

    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: Optional[AnyTokenizer],
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        add_special_tokens: bool = ...,
    ) -> tuple[
            list[Union[TextTokensPrompt, EmbedsPrompt]],
            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
    ]:
        ...

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: Optional[AnyTokenizer],
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        add_special_tokens: bool = True,
    ) -> tuple[
            Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
                                                     EmbedsPrompt]]],
            Union[
                list[EngineTokensPrompt],
                list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
            ],
    ]:
        """
        Build the request-level and engine-level prompts for a
        completion-style request.

        The inputs are normalized by
        ``_tokenize_prompt_input_or_inputs_async`` and every resulting prompt
        is wrapped in its engine counterpart. When the request carries a
        truthy ``cache_salt``, it is copied onto every engine prompt.

        Returns ``(request_prompts, engine_prompts)``. For anything other
        than a ``CompletionRequest`` only text/token prompts are returned;
        for a ``CompletionRequest`` the embeds prompts come first, followed
        by the text/token prompts.

        Raises:
            ValueError: if ``input_or_inputs`` is ``None`` for a request that
                is not a ``CompletionRequest``.
        """
        if (input_or_inputs is None
                and not isinstance(request, CompletionRequest)):
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        text_prompts, embeds_prompts = (
            await self._tokenize_prompt_input_or_inputs_async(
                request,
                tokenizer,
                input_or_inputs,
                add_special_tokens=add_special_tokens,
            ))

        # Requests without a ``cache_salt`` attribute behave the same as
        # requests whose ``cache_salt`` is None.
        cache_salt = getattr(request, "cache_salt", None)

        engine_text_prompts = []
        for text_prompt in text_prompts:
            engine_text_prompt = EngineTokensPrompt(
                prompt_token_ids=text_prompt["prompt_token_ids"])
            if cache_salt:
                engine_text_prompt["cache_salt"] = cache_salt
            engine_text_prompts.append(engine_text_prompt)

        # This check is equivalent to simply checking if
        # `embeds_prompts` is empty, but it's difficult to propagate
        # overloads to the private helper functions to enable this check.
        # This overload is needed because only TextPrompts are allowed for
        # non-completion requests and if we don't add the overload here,
        # everywhere this function is used outside of serving_completion will
        # need logic asserting that only text prompts are in the request.
        if (not isinstance(request, CompletionRequest)
                and input_or_inputs is not None):
            return text_prompts, engine_text_prompts

        engine_embeds_prompts = []
        for embeds_prompt in embeds_prompts:
            engine_embeds_prompt = EngineEmbedsPrompt(
                prompt_embeds=embeds_prompt["prompt_embeds"])
            if cache_salt:
                engine_embeds_prompt["cache_salt"] = cache_salt
            engine_embeds_prompts.append(engine_embeds_prompt)

        # Embeds prompts are placed ahead of the text/token prompts in both
        # returned lists.
        return (
            embeds_prompts + text_prompts,
            engine_embeds_prompts + engine_text_prompts,
        )

    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[
            list[ConversationMessage],
            Sequence[RequestPrompt],
            list[EngineTokensPrompt],
    ]:
        """
        Render chat ``messages`` into a single prompt and its engine input.

        The messages are parsed into a conversation, the chat template is
        applied (Mistral or HF flavour depending on the tokenizer type), the
        rendered prompt is tokenized if it is a string, and the result is
        wrapped in an ``EngineTokensPrompt`` together with any multi-modal
        data, ``mm_processor_kwargs`` and ``cache_salt`` from the request.

        Entries of ``chat_template_kwargs`` override the template arguments
        derived from the explicit parameters (``chat_template``,
        ``add_generation_prompt``, ``continue_final_message``, ``tools``,
        ``documents``).

        NOTE(review): ``tokenizer`` is annotated as non-optional, yet the
        body handles ``tokenizer is None`` by using a placeholder prompt;
        confirm with callers whether ``None`` is actually passed.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``; both lists
            always contain exactly one element.

        Raises:
            NotImplementedError: if tool parsing is requested for a request
                that is not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM).
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if tokenizer is None:
            # The message must be ONE string: a parenthesised, comma-separated
            # pair would make the assertion message a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised")
            # Without a tokenizer a dummy single-token id list is used.
            prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                             prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """
        Run generation, transparently executing built-in tool calls.

        Each turn logs the inputs, streams the engine outputs into
        ``context`` (yielding ``context`` after every appended output) and,
        if ``context`` reports that a built-in tool call is needed, invokes
        the tool, appends its output and starts another turn with a prompt
        re-rendered from ``context``. The generator finishes after the first
        turn that needs no tool call.

        Note that ``sampling_params.max_tokens`` is mutated in place before
        every follow-up turn.
        """
        initial_priority = priority
        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            output_stream = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                **kwargs,
            )
            async for engine_output in output_stream:
                context.append_output(engine_output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                return

            # Call the tool and update the context with the result.
            context.append_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn from the freshly rendered
            # prompt token ids.
            next_token_ids = context.render_for_completion()
            request_prompt = next_token_ids
            engine_prompt = EngineTokensPrompt(
                prompt_token_ids=next_token_ids)
            # Update the sampling params: the next turn may use whatever is
            # left of the model length after the new prompt.
            sampling_params.max_tokens = (self.max_model_len -
                                          len(next_token_ids))
            # OPTIMIZATION: every follow-up turn runs at the original
            # priority minus one.
            # NOTE(review): presumably this schedules continuation turns
            # ahead of new requests - confirm against the engine's priority
            # ordering.
            priority = initial_priority - 1

1072
    @staticmethod
    def _load_prompt_embeds(
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    ) -> list[EmbedsPrompt]:
        """
        Decode base64-encoded, ``torch.save``-serialized prompt embeddings.

        ``prompt_embeds`` may be a single payload or a list of payloads; a
        falsy value (``None``, ``b""``, ``[]``) yields an empty list. Every
        payload must deserialize to a float32/bfloat16/float16 tensor. A
        tensor with more than two dimensions has its leading dimension
        squeezed and must be 2-D afterwards. When ``truncate_prompt_tokens``
        is given, only the last ``truncate_prompt_tokens`` rows are kept.

        Returns one ``{"prompt_embeds": tensor}`` entry per payload, in input
        order.
        """

        def _decode_payload(payload: bytes) -> EmbedsPrompt:
            serialized = pybase64.b64decode(payload, validate=True)
            # NOTE(review): the payload originates from the request, i.e. it
            # is deserialized client data. ``weights_only=True`` restricts
            # what torch.load will unpickle - keep it that way.
            tensor = torch.load(
                io.BytesIO(serialized),
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            allowed_dtypes = (torch.float32, torch.bfloat16, torch.float16)
            assert (isinstance(tensor, torch.Tensor)
                    and tensor.dtype in allowed_dtypes)
            tensor = tensor.to_dense()
            if tensor.dim() > 2:
                # Drop a leading dimension (if it has size 1) and require a
                # 2-D result.
                tensor = tensor.squeeze(0)
                assert tensor.dim() == 2
            if truncate_prompt_tokens is not None:
                # Keep only the trailing rows.
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if not prompt_embeds:
            return []
        if not isinstance(prompt_embeds, list):
            return [_decode_payload(prompt_embeds)]
        return [_decode_payload(payload) for payload in prompt_embeds]

1107
1108
1109
    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[RequestPrompt, PromptType],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """
        Forward a request's prompt to the request logger, if one is set.

        ``inputs`` is broken down into the three fields the logger accepts:

        - a ``str`` is logged as the prompt text;
        - a ``list`` is logged as the prompt token ids;
        - a mapping (the dict-style prompts built elsewhere in this class)
          contributes its ``"prompt"``, ``"prompt_token_ids"`` and
          ``"prompt_embeds"`` entries, each ``None`` when absent;
        - any other object is probed for ``prompt`` / ``prompt_token_ids``
          attributes.

        Does nothing when ``self.request_logger`` is ``None``.
        """
        if self.request_logger is None:
            return

        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        elif isinstance(inputs, Mapping):
            # Dict-style prompts expose their fields as keys, not attributes;
            # getattr() on them would always fall back to the default.
            prompt = inputs.get("prompt")
            prompt_token_ids = inputs.get("prompt_token_ids")
            prompt_embeds = inputs.get("prompt_embeds")
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
1134

1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """
        Return the trace headers of a request when tracing is enabled.

        If the engine client reports that tracing is disabled, ``None`` is
        returned; in that case a warning is emitted when the request
        nevertheless carries trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: let the operator know that incoming trace headers
        # are being ignored.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1149
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """
        Pick the request id: the ``X-Request-Id`` header if present,
        otherwise ``default``, otherwise a freshly generated random UUID.
        """
        # An empty or missing default is replaced by a random id up front, so
        # the fallback is always a usable value.
        fallback = default if default else random_uuid()
        if raw_request is not None:
            return raw_request.headers.get("X-Request-Id", fallback)
        return fallback
1158

1159
    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
        """
        Return the textual form of a token for logprob reporting.

        With ``return_as_token_id`` set, the result is the literal
        ``"token_id:<id>"``. Otherwise the token text already stored on
        ``logprob`` is used, and the tokenizer is only consulted when that
        text is ``None``.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"

        decoded = logprob.decoded_token
        if decoded is None:
            decoded = tokenizer.decode(token_id)
        return decoded
1172

1173
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1174
1175
        if not model_name:
            return True
1176
        return self.models.is_base_model(model_name)
1177

1178
1179
1180
1181
1182
    def _get_model_name(
        self,
        model_name: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> str:
        """
        Resolve the model name to report for a request.

        Precedence: the LoRA adapter's name when a LoRA request is given,
        then a non-empty ``model_name``, and finally the name of the first
        entry in ``self.models.base_model_paths``.
        """
        if lora_request:
            return lora_request.lora_name
        return model_name or self.models.base_model_paths[0].name
1188
1189
1190
1191


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None], ) -> Union[PromptLogprobs, None]:
    """
    Replace every ``-inf`` prompt logprob with ``-9999.0``, in place.

    ``None`` entries in the sequence are skipped, and a ``None`` argument is
    returned unchanged. The (mutated) input object itself is returned.

    NOTE(review): presumably the clamp exists because ``-inf`` cannot be
    represented in the serialized response - confirm with the callers.
    """
    if prompt_logprobs is None:
        return None

    negative_infinity = float("-inf")
    # Flatten the per-position dicts into a single stream of logprob
    # objects. Only attributes of the values are modified, never the dicts.
    entries = (entry for per_position in prompt_logprobs
               if per_position is not None
               for entry in per_position.values())
    for entry in entries:
        if entry.logprob == negative_infinity:
            entry.logprob = -9999.0
    return prompt_logprobs