serving_engine.py 41.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import base64
import io
6
import json
7
import sys
8
import time
9
10
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
11
from http import HTTPStatus
12
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
13
                    TypeVar, Union, cast, overload)
14

15
import torch
16
from fastapi import Request
17
from pydantic import BaseModel, ConfigDict, Field
18
from starlette.datastructures import Headers
19
20
from typing_extensions import TypeIs

21
22
23
24
25
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

26
import vllm.envs as envs
27
from vllm.config import ModelConfig
28
from vllm.engine.protocol import EngineClient
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
32
                                         ChatTemplateContentFormatOption,
33
34
35
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
36
37
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
38
from vllm.entrypoints.logger import RequestLogger
39
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
40
41
42
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
43
                                              CompletionRequest,
44
                                              CompletionResponse,
45
                                              DetokenizeRequest,
46
47
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
48
49
50
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              PoolingResponse, RerankRequest,
51
52
                                              ResponsesRequest, ScoreRequest,
                                              ScoreResponse,
53
                                              TokenizeChatRequest,
54
                                              TokenizeCompletionRequest,
55
56
                                              TokenizeResponse,
                                              TranscriptionRequest,
57
58
                                              TranscriptionResponse,
                                              TranslationRequest)
59
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
60
from vllm.entrypoints.openai.tool_parsers import ToolParser
61
# yapf: enable
62
63
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
64
from vllm.inputs.parse import parse_and_batch_prompt
65
from vllm.logger import init_logger
66
from vllm.lora.request import LoRARequest
67
68
69
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
70
from vllm.pooling_params import PoolingParams
71
from vllm.prompt_adapter.request import PromptAdapterRequest
72
from vllm.sampling_params import BeamSearchParams, SamplingParams
73
from vllm.sequence import Logprob, PromptLogprobs
74
75
76
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
77
78
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)
79
80
81

logger = init_logger(__name__)

# Union of the completion-style request models imported above.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Union of the chat-style request models imported above.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]
# Union of the transcription and translation request models.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
# Every request type accepted by the serving helpers in this module.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
                   ResponsesRequest]

# Every non-error response type a serving endpoint in this module may return.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

104
105
106

class TextTokensPrompt(TypedDict):
    """A prompt carried as both its text and its token IDs."""
    # The prompt text.
    prompt: str
    # The token IDs that accompany `prompt`.
    prompt_token_ids: list[int]
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class EmbedsPrompt(TypedDict):
    """A prompt represented by an embeddings tensor (`prompt_embeds`)."""
    prompt_embeds: torch.Tensor


# Any prompt form handled on the request side: raw token IDs, raw text,
# or one of the TypedDict prompt shapes defined in this module.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when `prompt` is a dict with token IDs and no embeddings."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when `prompt` is a dict with embeddings and no token IDs."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_token_ids

126

127
128
129
130
131
# Type variable for the request model a `ServeContext` is parameterized with;
# bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    # Prompts as seen on the request side (text, token IDs, or prompt dicts).
    request_prompts: Optional[Sequence[RequestPrompt]] = []
    # Prompts in the form that is handed to the engine client.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = []

    # The prompt fields use types Pydantic has no built-in schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    # Async stream of `(prompt_index, output)` pairs; None until one is set.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # Collected outputs; starts as a fresh empty list per instance.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    # The fields above use types Pydantic has no built-in schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    """
    Per-request state bundle, generic over the request type.

    Combines the prompt fields of `RequestProcessingMixin` and the result
    fields of `ResponseGenerationMixin` with the request and its metadata.
    """
    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Unix time in whole seconds, taken when the context is created.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# `ServeContext` specialized for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, extended with chat-template settings."""
    # Optional chat template; None when not provided.
    chat_template: Optional[str] = None
    # Required: has no default value.
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the Pydantic models now that every referenced name is importable.
# This resolves the Pydantic error caused by the forward reference to
# MultiModalDataDict in TokensPrompt.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

193

194
class OpenAIServing:
195
196
197
198
    # NOTE(review): the default value below is descriptive text rather than an
    # actual prefix; presumably each serving subclass overrides it — confirm.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
199

200
201
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """
        Store the shared serving dependencies and set up tokenizer helpers.

        Args:
            engine_client: Client used to submit work to the engine.
            model_config: Model configuration; its `max_model_len` is
                cached on the instance as `self.max_model_len`.
            models: Registry consulted for LoRA and prompt-adapter requests.
            request_logger: Optional request logger; may be None.
            return_tokens_as_token_ids: Stored as-is on the instance.
            enable_force_include_usage: Stored as-is on the instance.
        """
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        # Single-worker executor reserved for tokenizer work.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        # Cache of async tokenizer wrappers, keyed by the wrapped tokenizer.
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}

    def _get_async_tokenizer(
            self, tokenizer: AnyTokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
        given tokenizer.

        The wrapper is created on first use and stored in
        `self._async_tokenizer_pool`, keyed by the tokenizer object, so
        later calls with the same tokenizer reuse it.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
237

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        No-op preprocessing hook.

        Subclasses (classification, embedding, etc.) override this to
        populate `ctx` before generation. Returning an `ErrorResponse`
        signals failure; this base version always returns None.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Fallback response builder.

        Subclasses override this to turn `ctx` into their endpoint's
        response object; the base version reports an unimplemented endpoint.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Run the pipeline for `ctx` and return the first response it yields.

        If the pipeline yields nothing, an error response is returned.
        """
        responses: AsyncGenerator[Union[AnyResponse, ErrorResponse],
                                  None] = self._pipeline(ctx)

        # Only the first yielded item is consumed.
        async for first_response in responses:
            return first_response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """
        Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation,
        preprocessing, generator preparation, batch collection, and
        response building. The first stage that produces an
        `ErrorResponse` yields it and ends the generator, so later stages
        never run on a failed context even if the caller keeps iterating.
        On success exactly one item (the built response) is yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """
        Check the request's `truncate_prompt_tokens` against `max_model_len`.

        When the request has the attribute set and it fits, the value is
        copied onto `ctx`. Returns an `ErrorResponse` when it exceeds the
        model length, otherwise None.
        """
        requested = getattr(ctx.request, "truncate_prompt_tokens", None)

        # Nothing to validate when the request does not set the field.
        if requested is None:
            return None

        if requested <= self.max_model_len:
            ctx.truncate_prompt_tokens = requested
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size.")

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Schedule the request and get the result generator.

        One `engine_client.encode` generator is created per engine prompt
        (request ID `"<request_id>-<index>"`), and the merged stream is
        stored in `ctx.result_generator`.

        Returns:
            None on success, or an `ErrorResponse` when the request has no
            `to_pooling_params`, prompts are missing, or any exception is
            raised while scheduling.
        """
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            if not hasattr(ctx.request, "to_pooling_params"):
                return self.create_error_response(
                    "Request type does not support pooling parameters")

            pooling_params = ctx.request.to_pooling_params()

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                if ctx.request_prompts is None:
                    return self.create_error_response(
                        "Request prompts not available")

                self._log_inputs(
                    request_id_item,
                    ctx.request_prompts[i],
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                    prompt_adapter_request=ctx.prompt_adapter_request)

                # Mypy has an existing bug related to inferring the variance of
                # TypedDicts with `builtins.enumerate`:
                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
                engine_prompt = cast(
                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
                    engine_prompt)
                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Drain `ctx.result_generator` into `ctx.final_res_batch`.

        Each `(index, output)` pair from the generator is placed at its
        index. Returns an `ErrorResponse` when prompts or the generator are
        missing, when any prompt ends up without an output, or when an
        exception is raised; otherwise None.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            # One slot per prompt, filled in as results arrive.
            slots: list[Optional[Union[RequestOutput,
                                       PoolingRequestOutput]]] = (
                                           [None] * len(prompts))

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response(
                    "Result generator not available")

            async for idx, output in stream:
                slots[idx] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [out for out in slots if out is not None]
            return None

        except Exception as e:
            return self.create_error_response(str(e))

405
406
407
408
409
410
411
412
413
    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an `ErrorResponse` from a message, error type, and HTTP status."""
        numeric_code = status_code.value
        return ErrorResponse(
            message=message,
            type=err_type,
            code=numeric_code,
        )

414
415
416
417
418
419
420
421
422
423
424
425
426
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Return the error as a JSON string of the form `{"error": {...}}`."""
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        payload = {"error": error.model_dump()}
        return json.dumps(payload)

427
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """
        Check that `request.model` names something this server can serve.

        The name is accepted when it is a supported model, a registered
        LoRA, a LoRA that can be resolved at runtime (only when
        `VLLM_ALLOW_RUNTIME_LORA_UPDATING` is set), or a registered prompt
        adapter.

        Returns:
            None when the model is accepted. Otherwise the bad-request error
            from LoRA resolution if there was one, else a 404
            `NotFoundError` response.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
                load_result := await self.models.resolve_lora(request.model)):
            if isinstance(load_result, LoRARequest):
                return None
            # Keep a bad-request failure so it can be reported instead of
            # the generic not-found error below.
            if isinstance(load_result, ErrorResponse) and \
                load_result.code == HTTPStatus.BAD_REQUEST.value:
                error_response = load_result
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """
        Return the single registered LoRA whose name matches a message
        content type in the request, or None if zero or several match.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        matched = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(matched) != 1:
            return None
        return matched.pop()

480
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        """
        Resolve the adapter, if any, selected by `request.model`.

        Lookup order: a registered LoRA with that name, a default
        multimodal LoRA (only when `supports_default_mm_loras` is True), a
        supported model with no adapter, then a prompt adapter with that
        name.

        Returns:
            A `(lora_request, prompt_adapter_request)` pair in which at most
            one element is not None.

        Raises:
            ValueError: If the model name matches nothing.
        """
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model], None

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora, None

        if self._is_model_supported(request.model):
            return None, None

        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")
505

506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose "content" is a list are inspected, and
        within them only dict parts with a string "type". Everything else
        is skipped. Returns an empty set when the request has no
        `messages` attribute.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                # A part may be a plain string. `"type" in part` would then
                # be a substring test and `part["type"]` would raise, so
                # only dict parts are examined.
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    # e.g. "audio_url" -> "audio"
                    message_types.add(part_type.split("_")[0])
        return message_types

524
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """
        Tokenize a text prompt and validate its length for `request`.

        Args:
            request: The request the prompt belongs to.
            tokenizer: Tokenizer to use, via its cached async wrapper.
            prompt: Prompt text. It is lower-cased first when the model's
                encoder config sets `do_lower_case`.
            truncate_prompt_tokens: None for no truncation, a negative
                value to truncate at `max_model_len`, otherwise the
                maximum number of tokens to keep.
            add_special_tokens: Forwarded to the tokenizer.

        Returns:
            The validated prompt with its (possibly lower-cased) text and
            token IDs.

        Raises:
            ValueError: Propagated from `_validate_input` when the prompt
                is too long.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(
                prompt, add_special_tokens=add_special_tokens)
        else:
            # Negative means we cap at the model's max length
            max_length = (self.max_model_len if truncate_prompt_tokens < 0
                          else truncate_prompt_tokens)
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_length)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

561
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
    ) -> TextTokensPrompt:
        """
        Truncate a token-ID prompt, decode it to text, and validate it.

        Args:
            request: The request the prompt belongs to.
            tokenizer: Tokenizer used to decode the IDs back to text.
            prompt_ids: Prompt token IDs.
            truncate_prompt_tokens: None for no truncation, a negative
                value to keep the last `max_model_len` tokens, otherwise
                the number of trailing tokens to keep.

        Returns:
            The validated prompt with the decoded text and the kept IDs.

        Raises:
            ValueError: Propagated from `_validate_input` when the prompt
                is too long.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        # Truncation keeps the tail of the prompt.
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """
        Check the prompt length against `max_model_len` for this request type.

        - Embedding, score, rerank, and classification requests: the prompt
          may be at most `max_model_len` tokens.
        - Tokenize and detokenize requests: no length check.
        - All other requests: the prompt plus the requested completion
          tokens must fit. With no completion limit set, the prompt must be
          strictly shorter than `max_model_len`.

        Returns:
            A `TextTokensPrompt` built from `input_text` and `input_ids`.

        Raises:
            ValueError: If the applicable length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest, ClassificationRequest)):

            if token_num > self.max_model_len:
                # Operation name for the error message; anything not
                # listed is reported as embedding generation.
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification"
                }
                operation = operations.get(type(request),
                                           "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

641
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs_async`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs_async]
        that assumes single input.

        Wraps `prompt_input` in a one-element list, delegates to
        `_tokenize_prompt_inputs_async`, and returns the first result.

        Raises:
            ValueError: If no result is produced, or propagated from input
                validation.
        """
        async for result in self._tokenize_prompt_inputs_async(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
        ):
            return result
        raise ValueError("No results yielded from tokenization")
663

664
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs_async`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs_async]
        that assumes multiple inputs.

        Yields one `TextTokensPrompt` per input, in input order. String
        inputs are tokenized; any other input is treated as token IDs and
        decoded. `add_special_tokens` only applies to string inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield await self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

694
    async def _tokenize_prompt_input_or_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        Returns:
            A `(text_prompts, embeds_prompts)` pair. Embedding prompts are
            only loaded for a `CompletionRequest` that sets
            `prompt_embeds`. Text prompts are empty when `input_or_inputs`
            is None, or is "" while embeddings are present.
        """
        inputs_embeds = list[EmbedsPrompt]()
        inputs_text = list[TextTokensPrompt]()

        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            inputs_embeds.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None or (inputs_embeds
                                       and input_or_inputs == ""):
            return [], inputs_embeds

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(input_or_inputs)

        # Process each input in the batch concurrently
        tasks = []
        for prompt_input in batch_inputs:
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # "is False" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                task = self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens)
            else:
                task = self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens)
            tasks.append(task)

        # Wait for all tokenization tasks to complete
        results = await asyncio.gather(*tasks)
        inputs_text.extend(results)

        return inputs_text, inputs_embeds
755

756
    @overload
    async def _preprocess_completion(
        self,
        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                       RerankRequest, ClassificationRequest, ScoreRequest,
                       TokenizeCompletionRequest],
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        ...

    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
            EngineTokensPrompt, EngineEmbedsPrompt]]]:
        ...

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
            TextTokensPrompt, EmbedsPrompt]]], Union[
                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                     EngineEmbedsPrompt]]]]:
        """
        Build request-level and engine-level prompts for a completion-like
        request.

        The input is normalized by ``_tokenize_prompt_input_or_inputs_async``;
        each resulting prompt is wrapped in the matching engine prompt type
        and, when the request has a truthy ``cache_salt``, tagged with it.

        Returns:
            ``(request_prompts, engine_prompts)``. For anything other than a
            ``CompletionRequest`` both lists contain text/token prompts only.
            For a ``CompletionRequest`` the embeds prompts come first,
            followed by the text/token prompts.

        Raises:
            ValueError: ``input_or_inputs`` is ``None`` for a request that is
                not a ``CompletionRequest``.
        """
        is_completion_request = isinstance(request, CompletionRequest)

        if input_or_inputs is None and not is_completion_request:
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        text_prompts, embeds_prompts = (
            await self._tokenize_prompt_input_or_inputs_async(
                request,
                tokenizer,
                input_or_inputs,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

        # Not every request type defines ``cache_salt``.
        cache_salt = getattr(request, "cache_salt", None)

        engine_text_prompts: list[EngineTokensPrompt] = []
        for text_prompt in text_prompts:
            engine_text_prompt = EngineTokensPrompt(
                prompt_token_ids=text_prompt["prompt_token_ids"])
            if cache_salt:
                engine_text_prompt["cache_salt"] = cache_salt
            engine_text_prompts.append(engine_text_prompt)

        # Only text prompts are allowed for non-completion requests. Returning
        # here (instead of checking whether ``embeds_prompts`` is empty, which
        # would be equivalent) is what lets the overloads above promise a
        # text-only result, so callers outside serving_completion do not have
        # to assert that themselves. ``input_or_inputs`` cannot be ``None`` at
        # this point because that case raised above.
        if not is_completion_request:
            return text_prompts, engine_text_prompts

        engine_embeds_prompts: list[EngineEmbedsPrompt] = []
        for embeds_prompt in embeds_prompts:
            engine_embeds_prompt = EngineEmbedsPrompt(
                prompt_embeds=embeds_prompt["prompt_embeds"])
            if cache_salt:
                engine_embeds_prompt["cache_salt"] = cache_salt
            engine_embeds_prompts.append(engine_embeds_prompt)

        # Embeds prompts are placed ahead of the text prompts in both lists.
        return (embeds_prompts + text_prompts,
                engine_embeds_prompts + engine_text_prompts)

    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[EngineTokensPrompt]]:
        """
        Render chat messages through the chat template and build the engine
        prompt for them.

        Steps, in order: resolve the template content format, parse the
        messages (multimodal data is delivered through an awaitable), apply
        the Mistral or HF chat template, await the multimodal data, let the
        tool parser adjust the request when tool parsing applies, tokenize
        the rendered prompt if it is a string, and assemble the engine prompt.

        Entries in ``chat_template_kwargs`` override the explicitly passed
        template arguments of the same name.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])`` where
            ``request_prompt`` is the rendered prompt (a string, or a list of
            token ids).

        Raises:
            NotImplementedError: tool parsing applies but ``request`` is not
                a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=content_format,
        )

        # Explicit arguments first; caller-provided kwargs win on conflicts.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing happens only when a tool_parser is configured and the
        # request's tool_choice is not "none". With tool_choice "none" we must
        # not parse, even if a parser is set, so that a tool call hallucinated
        # by the LLM is not picked up. Requests without a tool_choice
        # attribute are treated like "none".
        wants_tool_parsing = tool_parser is not None and (getattr(
            request, "tool_choice", "none") != "none")

        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest):
                raise NotImplementedError(
                    "Tool usage is only supported for Chat Completions API")

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if not isinstance(request_prompt, str):
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)
        else:
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])

        # Optional fields are attached only when they carry a value.
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        mm_processor_kwargs = request.mm_processor_kwargs
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

        # Not every request type defines ``cache_salt``.
        cache_salt = getattr(request, "cache_salt", None)
        if cache_salt is not None:
            engine_prompt["cache_salt"] = cache_salt

        return conversation, [request_prompt], [engine_prompt]

946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
    def _load_prompt_embeds(
        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
                                weights_only=True)
            assert isinstance(
                tensor,
                (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
                assert tensor.dim() == 2
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if prompt_embeds:
            if isinstance(prompt_embeds, list):
                return [
                    _load_and_validate_embed(embed) for embed in prompt_embeds
                ]
            else:
                return [_load_and_validate_embed(prompt_embeds)]
        else:
            return []

975
976
977
    def _log_inputs(
        self,
        request_id: str,
978
        inputs: RequestPrompt,
979
980
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
981
982
983
984
985
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return
986
        prompt, prompt_token_ids, prompt_embeds = None, None, None
987
988
989
990
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
991
992
        elif 'prompt_embeds' in inputs:
            prompt_embeds = inputs.get("prompt_embeds")
993
        else:
994
995
996
997
998
999
1000
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1001
            prompt_embeds,
1002
1003
1004
1005
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
1006

1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """
        Return the trace headers of a request when the engine has tracing on.

        When tracing is disabled the result is ``None``; if the request
        nevertheless carries trace headers, the tracing-disabled warning is
        logged first.
        """
        tracing_on = await self.engine_client.is_tracing_enabled()

        if not tracing_on:
            # The client sent trace context that will be ignored.
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None

        return extract_trace_headers(headers)

1021
    @staticmethod
1022
    def _base_request_id(raw_request: Optional[Request],
1023
1024
1025
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
1026
1027
1028
1029
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
1030

1031
    @staticmethod
1032
1033
1034
1035
1036
1037
1038
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

1039
1040
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1041
        return tokenizer.decode(token_id)
1042

1043
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1044
1045
        if not model_name:
            return True
1046
        return self.models.is_base_model(model_name)
1047
1048
1049
1050
1051
1052

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        if lora_request:
            return lora_request.lora_name
1053
        if not model_name:
1054
1055
            return self.models.base_model_paths[0].name
        return model_name
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """
    Replace ``-inf`` log-probabilities with ``-9999.0``, in place.

    ``prompt_logprobs`` is a sequence whose items are either ``None``
    (skipped) or mappings whose values expose a mutable ``logprob``
    attribute. The same object that was passed in is returned; ``None`` is
    passed through unchanged.
    """
    if prompt_logprobs is not None:
        negative_infinity = float('-inf')
        for position_logprobs in prompt_logprobs:
            if position_logprobs is None:
                continue
            for entry in position_logprobs.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs