serving_engine.py 43.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import io
5
import json
6
import sys
7
import time
8
9
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
10
from http import HTTPStatus
11
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
12
                    TypeVar, Union, cast, overload)
13

14
import pybase64
15
import torch
16
from fastapi import Request
17
from pydantic import BaseModel, ConfigDict, Field
18
from starlette.datastructures import Headers
19
20
from typing_extensions import TypeIs

21
22
23
24
25
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

26
import vllm.envs as envs
27
from vllm.config import ModelConfig
28
from vllm.engine.protocol import EngineClient
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
32
                                         ChatTemplateContentFormatOption,
33
34
35
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
36
37
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
38
from vllm.entrypoints.context import ConversationContext
39
from vllm.entrypoints.logger import RequestLogger
40
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
41
42
43
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
44
                                              CompletionRequest,
45
                                              CompletionResponse,
46
                                              DetokenizeRequest,
47
48
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
49
                                              EmbeddingRequest,
50
51
52
53
                                              EmbeddingResponse, ErrorInfo,
                                              ErrorResponse, PoolingResponse,
                                              RerankRequest, ResponsesRequest,
                                              ScoreRequest, ScoreResponse,
54
                                              TokenizeChatRequest,
55
                                              TokenizeCompletionRequest,
56
57
                                              TokenizeResponse,
                                              TranscriptionRequest,
58
59
                                              TranscriptionResponse,
                                              TranslationRequest)
60
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
61
from vllm.entrypoints.openai.tool_parsers import ToolParser
62
# yapf: enable
63
64
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
65
from vllm.inputs.parse import parse_and_batch_prompt
66
from vllm.logger import init_logger
67
from vllm.lora.request import LoRARequest
68
69
70
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
71
from vllm.pooling_params import PoolingParams
72
from vllm.sampling_params import BeamSearchParams, SamplingParams
73
from vllm.sequence import Logprob, PromptLogprobs
74
75
76
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
77
78
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)
79
80
81

logger = init_logger(__name__)

# Requests whose prompt is given as text and/or token IDs.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Requests whose prompt is built from a list of chat messages.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Speech-to-text requests.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Every request type handled by `OpenAIServing`.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
                   ResponsesRequest]

# Every non-error response type an endpoint may build.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]


class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as the token IDs for that text."""
    prompt: str
    prompt_token_ids: list[int]


# A prompt given directly as pre-computed embeddings.
EmbedsPrompt = TypedDict("EmbedsPrompt", {"prompt_embeds": torch.Tensor})

# The prompt forms stored in `RequestProcessingMixin.request_prompts`:
# raw token IDs, raw text, text+tokens, or embeddings.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True if `prompt` is a dict with token IDs and no embeddings."""
    if not isinstance(prompt, dict):
        return False
    has_tokens = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_tokens and not has_embeds


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True if `prompt` is a dict with embeddings and no token IDs."""
    if not isinstance(prompt, dict):
        return False
    has_tokens = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_tokens

# Type parameter for `ServeContext`, bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    # Prompts as normalized from the request (text, token IDs, or
    # embeddings); indexed in parallel with `engine_prompts` when logging.
    request_prompts: Optional[Sequence[RequestPrompt]] = Field(
        default_factory=list)
    # Prompts submitted to the engine, one `encode` call per entry.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = Field(
                                       default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    # Merged stream of `(prompt index, output)` pairs from the engine.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # Outputs reassembled in prompt order once the stream is drained.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    # Per-request state passed through every stage of
    # `OpenAIServing._pipeline`; prompt and result fields come from the two
    # mixins above.

    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Defaults to the current Unix time, truncated to whole seconds.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    # Copied from the request by `OpenAIServing._validate_request` when it
    # does not exceed `max_model_len`.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# Serve context specialized to classification requests (no extra fields).
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    # Embedding requests additionally carry chat-template settings.
    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
# (bases are rebuilt before the models derived from them).
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

192

193
class OpenAIServing:
194
195
196
197
    # NOTE(review): the triple-quoted text below is the runtime *value* of
    # this ClassVar, not a docstring. Subclasses presumably override it with
    # a short tag such as "embd"; verify every subclass that builds request
    # IDs from it does so.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Store the dependencies and options shared by all endpoints.

        Args:
            engine_client: Client used to submit work to the engine.
            model_config: Model configuration; its `max_model_len` is
                cached on the instance.
            models: Served models and LoRA adapters.
            request_logger: Optional `RequestLogger`; may be None.
            return_tokens_as_token_ids: Stored as-is on the instance.
            enable_force_include_usage: Stored as-is on the instance.
        """
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        # One async (micro-batching) wrapper per tokenizer object; filled
        # lazily by `_get_async_tokenizer`.
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}

    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
        given tokenizer.

        The cache is keyed on the tokenizer object itself. This method never
        removes entries.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Pipeline hook run before generators are scheduled.

        The base implementation is a no-op. Endpoint subclasses
        (classification, embedding, ...) override it to populate `ctx`, and
        may return an `ErrorResponse` for the pipeline to yield.
        """
        nothing_to_report: Optional[ErrorResponse] = None
        return nothing_to_report

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Turn a fully processed `ctx` into the endpoint's response object.

        The base class has no response type of its own, so it reports an
        error; concrete endpoints override this.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Run the processing pipeline for `ctx` and return the first response
        it yields, or an error if it yields nothing at all.
        """
        pipeline: AsyncGenerator[Union[AnyResponse, ErrorResponse],
                                 None] = self._pipeline(ctx)

        async for first_response in pipeline:
            return first_response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator scheduling, result collection, response building. As soon
        as a stage reports an `ErrorResponse`, it is yielded and the pipeline
        stops. Later stages never run on a context that failed an earlier
        one, so at most one response is ever yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """
        Copy the request's `truncate_prompt_tokens` (when present) onto `ctx`.

        Returns an error if that value exceeds the model's maximum context
        length; otherwise returns None.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None:
            return None

        if limit <= self.max_model_len:
            ctx.truncate_prompt_tokens = limit
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size.")

    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        """
        Build `PoolingParams` via the request's `to_pooling_params()`
        method, or return an error if the request type has no such method.
        """
        if hasattr(ctx.request, "to_pooling_params"):
            return ctx.request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters")

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator.

        Makes one `engine_client.encode` call per entry of
        `ctx.engine_prompts`, using request IDs `"{ctx.request_id}-{i}"`, and
        stores the merged stream in `ctx.result_generator`.

        Returns:
            None on success. Otherwise an `ErrorResponse` for missing prompts,
            a request type without pooling parameters, or any exception
            raised while scheduling.
        """
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                if ctx.request_prompts is None:
                    return self.create_error_response(
                        "Request prompts not available")

                self._log_inputs(request_id_item,
                                 ctx.request_prompts[i],
                                 params=pooling_params,
                                 lora_request=ctx.lora_request)

                # Mypy has an existing bug related to inferring the variance of
                # TypedDicts with `builtins.enumerate`:
                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
                engine_prompt = cast(
                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
                    engine_prompt)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator.

        Each `(index, output)` pair from `ctx.result_generator` is placed at
        its prompt index, and the ordered outputs are stored in
        `ctx.final_res_batch`. Returns an error if prompts or the generator
        are missing, if any prompt produced no output, or if iteration
        raises.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            slots: list[Optional[Union[RequestOutput, PoolingRequestOutput]]]
            slots = [None for _ in ctx.engine_prompts]

            async for index, output in ctx.result_generator:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [out for out in slots if out is not None]
            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """
        Wrap `message` in an `ErrorResponse`.

        `err_type` becomes the error's `type`, and the integer value of
        `status_code` becomes its `code`.
        """
        info = ErrorInfo(message=message,
                         type=err_type,
                         code=status_code.value)
        return ErrorResponse(error=info)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """
        Build the same error as `create_error_response` and return it as a
        JSON string (its `model_dump()` passed through `json.dumps`).
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        payload = error.model_dump()
        return json.dumps(payload)

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """
        Check that `request.model` names something this server can serve.

        Returns None when the name is a supported base model or a loaded
        LoRA adapter. It also returns None when
        `VLLM_ALLOW_RUNTIME_LORA_UPDATING` is enabled and
        `models.resolve_lora` yields a `LoRARequest`. Otherwise it returns
        the resolver's own 400 `ErrorResponse` if there is one, or a 404
        "does not exist" error.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
                load_result := await self.models.resolve_lora(request.model)):
            if isinstance(load_result, LoRARequest):
                return None
            if (isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value):
                # Surface the resolver's validation error instead of a
                # generic "not found".
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras.

        A registered adapter matches when its `lora_name` equals one of the
        content-type prefixes found in the request's messages (e.g.
        `audio_url` -> `audio`). The adapter is returned only if exactly one
        matches; otherwise None.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match against the prefixes of the message content
        # 'type' fields; there is probably a better way to do this.
        matched = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        if len(matched) != 1:
            return None
        (only_match, ) = matched
        return only_match

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Optional[LoRARequest]:
        """
        Pick the LoRA adapter to use for `request`, if any.

        Checks, in order:
        1. an adapter registered under `request.model`;
        2. if `supports_default_mm_loras` is set, a single default
           multimodal adapter matched from the message content types;
        3. a supported base model, which returns None.

        Raises:
            ValueError: if `request.model` matches none of the above.
        """
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Collect the `type` prefix (text before the first `_`) of every
        content part in list-valued message contents.

        These prefixes are matched against default per-modality LoRA
        adapters. A request without a `messages` attribute yields an empty
        set.
        """
        prefixes: set[str] = set()

        if not hasattr(request, "messages"):
            return prefixes

        for message in request.messages:
            if not isinstance(message, dict) or "content" not in message:
                continue
            parts = message["content"]
            if not isinstance(parts, list):
                continue
            for part in parts:
                if "type" in part:
                    prefixes.add(part["type"].partition("_")[0])

        return prefixes

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """
        Tokenize a text prompt and validate it against the model's limits.

        Args:
            request: Originating request; used for length validation.
            tokenizer: Tokenizer used (via its async wrapper) to encode.
            prompt: Prompt text. It is lower-cased first when the model's
                encoder config sets `do_lower_case`.
            truncate_prompt_tokens: None to keep all tokens, a negative value
                to truncate to `max_model_len`, otherwise the maximum number
                of tokens to keep.
            add_special_tokens: Forwarded to the tokenizer.

        Returns:
            The (possibly lower-cased) text paired with its token IDs.

        Raises:
            ValueError: raised by `_validate_input` when the prompt is too long.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        tokenizer_kwargs: dict[str, Any] = {
            "add_special_tokens": add_special_tokens,
        }
        if truncate_prompt_tokens is not None:
            # Negative means we cap at the model's max length
            tokenizer_kwargs["truncation"] = True
            tokenizer_kwargs["max_length"] = (self.max_model_len
                                              if truncate_prompt_tokens < 0
                                              else truncate_prompt_tokens)

        encoded = await async_tokenizer(prompt, **tokenizer_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
    ) -> TextTokensPrompt:
        """
        Truncate a token-ID prompt, detokenize it, and validate its length.

        Args:
            request: Originating request; used for length validation.
            tokenizer: Tokenizer used (via its async wrapper) to decode.
            prompt_ids: Prompt token IDs.
            truncate_prompt_tokens: None to keep all tokens, a negative value
                to keep the last `max_model_len` tokens, otherwise the number
                of trailing tokens to keep.

        Returns:
            The decoded text paired with the (possibly truncated) token IDs.

        Raises:
            ValueError: raised by `_validate_input` when the prompt is too long.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
        else:
            # NOTE(review): a value of 0 slices `[-0:]`, which keeps *all*
            # tokens rather than none; confirm callers never pass 0.
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """
        Check a tokenized prompt against the model's context length.

        - Embedding, score, rerank and classification requests: the prompt
          alone must not exceed `max_model_len`.
        - Tokenize and detokenize requests: no length check.
        - All other requests: if an output budget is given
          (`max_completion_tokens` or `max_tokens` for chat completions,
          `max_tokens` otherwise), prompt plus budget must fit within
          `max_model_len`. Without a budget, the prompt must be strictly
          shorter than `max_model_len`.

        Returns:
            A `TextTokensPrompt` pairing `input_text` with `input_ids`.

        Raises:
            ValueError: if a length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest, ClassificationRequest)):
            if token_num > self.max_model_len:
                # NOTE(review): RerankRequest falls through to
                # "embedding generation" here; confirm that is intended.
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification"
                }
                operation = operations.get(type(request),
                                           "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        if max_tokens is None:
            # No output budget given: the prompt must leave at least one
            # position free (hence `>=`).
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
        that assumes single input.

        Raises:
            ValueError: if tokenization yields no result, or (via
                `_validate_input`) if the prompt is too long.
        """
        async for result in self._tokenize_prompt_inputs_async(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
        ):
            return result
        raise ValueError("No results yielded from tokenization")

    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
        that assumes multiple inputs.

        Inputs are processed one at a time, in order. Strings are tokenized;
        anything else is treated as token IDs and detokenized.
        `add_special_tokens` only applies to string inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield await self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    async def _tokenize_prompt_input_or_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        All inputs are processed concurrently (`asyncio.gather`), and each
        is length-checked by `_validate_input`. For a `CompletionRequest`,
        any `prompt_embeds` are also loaded.

        Returns:
            `(text_token_prompts, embeds_prompts)`. The first list is empty
            when `input_or_inputs` is None, or when it is "" and embeddings
            were supplied.
        """
        inputs_embeds = list[EmbedsPrompt]()
        inputs_text = list[TextTokensPrompt]()

        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            inputs_embeds.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None or (inputs_embeds
                                       and input_or_inputs == ""):
            return [], inputs_embeds

        # Parse and batch the input prompts
        batch_inputs = parse_and_batch_prompt(input_or_inputs)

        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        tasks = []
        for prompt_input in batch_inputs:
            if prompt_input["is_tokens"] is False:
                task = self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens)
            else:
                task = self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens)
            tasks.append(task)

        # Wait for all tokenization tasks to complete
        results = await asyncio.gather(*tasks)
        inputs_text.extend(results)

        return inputs_text, inputs_embeds

    @overload
    async def _preprocess_completion(
        self,
        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                       RerankRequest, ClassificationRequest, ScoreRequest,
                       TokenizeCompletionRequest],
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        """Non-completion requests: only text/token prompts are produced."""

    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]],
               list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]]:
        """Completion requests: prompts may also include embeddings."""

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
            TextTokensPrompt, EmbedsPrompt]]], Union[
                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                     EngineEmbedsPrompt]]]]:
        """Turn a completion-style request into request and engine prompts.

        Text and token-id inputs become ``TextTokensPrompt`` /
        ``EngineTokensPrompt`` pairs. For a ``CompletionRequest``, any
        ``prompt_embeds`` additionally become ``EmbedsPrompt`` /
        ``EngineEmbedsPrompt`` pairs, which are placed ahead of the text
        prompts in the returned lists. A truthy ``cache_salt`` on the request
        is copied onto every engine prompt.

        Raises:
            ValueError: if a non-completion request has no prompt input.
        """
        is_completion = isinstance(request, CompletionRequest)
        if input_or_inputs is None and not is_completion:
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        text_prompts, embeds_prompts = (
            await self._tokenize_prompt_input_or_inputs_async(
                request,
                tokenizer,
                input_or_inputs,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

        salt = getattr(request, "cache_salt", None)

        engine_text: list[EngineTokensPrompt] = []
        for text_prompt in text_prompts:
            token_prompt = EngineTokensPrompt(
                prompt_token_ids=text_prompt["prompt_token_ids"])
            if salt:
                token_prompt["cache_salt"] = salt
            engine_text.append(token_prompt)

        # Non-completion requests cannot carry embeddings (the guard above
        # already rejected a missing input for them). Returning early keeps
        # the narrower text-only overload true for those callers, which is
        # hard to express through the private helper's types.
        if not is_completion:
            return text_prompts, engine_text

        engine_embeds: list[EngineEmbedsPrompt] = []
        for embeds_prompt in embeds_prompts:
            embeds_engine_prompt = EngineEmbedsPrompt(
                prompt_embeds=embeds_prompt["prompt_embeds"])
            if salt:
                embeds_engine_prompt["cache_salt"] = salt
            engine_embeds.append(embeds_engine_prompt)

        return embeds_prompts + text_prompts, engine_embeds + engine_text

    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[EngineTokensPrompt]]:
        """Render chat messages into a single tokenized engine prompt.

        The steps are:

        - Resolve the chat-template content format.
        - Parse the messages; multimodal data arrives via a future.
        - Apply the chat template (Mistral or HF).
        - Optionally let a tool parser adjust the request.
        - Tokenize the rendered prompt.
        - Attach multimodal data, ``mm_processor_kwargs`` and
          ``cache_salt`` to the engine prompt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: if tool parsing applies to a request that is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        # Multimodal data is produced asynchronously; it is awaited only after
        # the chat template has been rendered.
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied template kwargs override the defaults above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and
        # tool_choice is not "none". (If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.)
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if tokenizer is None:
            # Adjacent literals concatenate into ONE message string; a comma
            # between them would turn the assertion message into a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised")
            # NOTE(review): [1] appears to be a stand-in token id used when
            # no tokenizer is available — confirm downstream expectations.
            prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                             prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """Run generation turns, executing built-in tool calls in between.

        Each turn does the following:

        - Log the prompt.
        - Stream engine outputs into ``context``, yielding ``context`` after
          every output.
        - Finish if the model did not ask for a built-in tool.

        Otherwise the tool is called and its output appended to ``context``.
        The next prompt is then rendered from ``context``.

        Follow-up turns are changed as follows:

        - They use ``priority - 1`` (relative to the caller's value).
        - ``sampling_params.max_tokens`` is set to
          ``self.max_model_len - len(prompt_token_ids)``.

        NOTE: ``sampling_params`` is mutated in place.
        """
        initial_priority = priority
        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            outputs = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                **kwargs,
            )
            async for output in outputs:
                context.append_output(output)
                # Stop conditions are enforced by the engine, not here.
                yield context

            if not context.need_builtin_tool_call():
                # No tool requested: the conversation turn is complete.
                return

            # Execute the requested tool and record its result.
            tool_result = await context.call_tool()
            context.append_output(tool_result)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Prepare the next turn from the updated conversation.
            next_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(
                prompt_token_ids=next_token_ids)
            request_prompt = next_token_ids
            remaining = self.max_model_len - len(next_token_ids)
            sampling_params.max_tokens = remaining
            # OPTIMIZATION: presumably lets follow-up turns be scheduled
            # ahead of fresh requests — confirm the scheduler's ordering.
            priority = initial_priority - 1

1004
1005
1006
1007
1008
1009
1010
    def _load_prompt_embeds(
        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:
        """Decode user-supplied prompt embeddings into ``EmbedsPrompt`` dicts.

        Args:
            prompt_embeds: One payload or a list of payloads. Each payload is
                base64 text of a tensor serialized with ``torch.save``.
                ``None`` or an empty value yields no prompts.
            truncate_prompt_tokens: A positive ``k`` keeps only the last
                ``k`` rows of each tensor (left truncation). A negative value
                (``-1``) caps each tensor at ``self.max_model_len`` rows.

        Returns:
            One ``{"prompt_embeds": tensor}`` entry per payload.

        Raises:
            ValueError: if a payload is not a float32/bfloat16/float16
                tensor, or is not 2-D after dropping a leading size-1 dim.
        """
        allowed_dtypes = (torch.float32, torch.bfloat16, torch.float16)

        # Although annotated ge=1, callers in this file forward values
        # validated with ge=-1, where -1 means "use the model's limit".
        # Slicing with tensor[-(-1):] == tensor[1:] would silently drop the
        # first embedding row, so map negatives to the context length.
        if truncate_prompt_tokens is not None and truncate_prompt_tokens < 0:
            truncate_prompt_tokens = self.max_model_len

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, which matters because the payload is user input.
            tensor = torch.load(io.BytesIO(
                pybase64.b64decode(embed, validate=True)),
                                weights_only=True)
            # Validate with real exceptions (not assert): this is untrusted
            # input and asserts are stripped under `python -O`.
            if not isinstance(tensor, torch.Tensor):
                raise ValueError(
                    "Prompt embeds must decode to a torch.Tensor, got "
                    f"{type(tensor).__name__}.")
            if tensor.dtype not in allowed_dtypes:
                raise ValueError(
                    "Prompt embeds must be float32, bfloat16 or float16, "
                    f"got {tensor.dtype}.")
            if tensor.dim() > 2:
                # Drop a leading size-1 (presumably batch) dimension.
                tensor = tensor.squeeze(0)
            if tensor.dim() != 2:
                raise ValueError(
                    "Prompt embeds must be a 2D tensor, got shape "
                    f"{tuple(tensor.shape)}.")
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if not prompt_embeds:
            return []
        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]
        return [_load_and_validate_embed(prompt_embeds)]

1036
1037
1038
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Pass a request's prompt to the request logger when one is set.

        ``inputs`` may be a raw string, a list of token ids, a mapping
        holding ``"prompt_embeds"``, or a mapping holding ``"prompt"`` and
        ``"prompt_token_ids"``. Fields that do not apply are logged as
        ``None``.
        """
        logger = self.request_logger
        if logger is None:
            return

        if isinstance(inputs, str):
            fields = (inputs, None, None)
        elif isinstance(inputs, list):
            fields = (None, inputs, None)
        elif "prompt_embeds" in inputs:
            fields = (None, None, inputs.get("prompt_embeds"))
        else:
            fields = (inputs["prompt"], inputs["prompt_token_ids"], None)
        prompt, prompt_token_ids, prompt_embeds = fields

        logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
1065

1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers to propagate, or ``None`` if tracing is off.

        If tracing is disabled but the incoming headers still contain trace
        headers, a warning is logged.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1080
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided.

        Falls back to ``default`` when it is truthy, otherwise to a freshly
        generated ``random_uuid()``. The fallback is also returned when there
        is no raw request.
        """
        fallback = default if default else random_uuid()
        if raw_request is not None:
            return raw_request.headers.get("X-Request-Id", fallback)
        return fallback
1089

1090
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the display text for ``token_id``.

        - If ``return_as_token_id`` is set, returns ``"token_id:<id>"``.
        - Otherwise returns the logprob's cached ``decoded_token`` when it is
          not ``None``.
        - Failing that, returns ``tokenizer.decode(token_id)``.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"
        cached_text = logprob.decoded_token
        if cached_text is None:
            return tokenizer.decode(token_id)
        return cached_text
1101

1102
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1103
1104
        if not model_name:
            return True
1105
        return self.models.is_base_model(model_name)
1106
1107
1108
1109
1110
1111

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        """Pick the model name to report.

        The LoRA adapter's name takes precedence, then a non-empty
        ``model_name``, then the name of the first base model path.
        """
        if lora_request:
            return lora_request.lora_name
        if model_name:
            return model_name
        return self.models.base_model_paths[0].name
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs, None],
) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` prompt logprob values with ``-9999.0`` in place.

    ``None`` input, and ``None`` per-position entries, are left untouched.
    The same (mutated) object is returned. NOTE(review): presumably this
    keeps responses JSON-serializable, since ``-inf`` is not valid JSON —
    confirm.
    """
    if prompt_logprobs is not None:
        negative_infinity = float('-inf')
        for position in prompt_logprobs:
            if position is None:
                continue
            for entry in position.values():
                if entry.logprob == negative_infinity:
                    entry.logprob = -9999.0
    return prompt_logprobs