serving_engine.py 39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import base64
import io
6
import json
7
import sys
8
import time
9
10
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
11
from http import HTTPStatus
12
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
13
                    TypeVar, Union, cast, overload)
14

15
import torch
16
from fastapi import Request
17
from pydantic import BaseModel, ConfigDict, Field
18
from starlette.datastructures import Headers
19
20
21
22
23
24
from typing_extensions import TypeIs

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict
25

26
27
28
29
30
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

31
import vllm.envs as envs
32
from vllm.config import ModelConfig
33
from vllm.engine.protocol import EngineClient
34
35
# yapf conflicts with isort for this block
# yapf: disable
36
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
37
                                         ChatTemplateContentFormatOption,
38
39
40
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
41
42
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
43
from vllm.entrypoints.logger import RequestLogger
44
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
45
46
47
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
48
                                              CompletionRequest,
49
                                              CompletionResponse,
50
                                              DetokenizeRequest,
51
52
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
53
54
55
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              PoolingResponse, RerankRequest,
56
57
                                              ResponsesRequest, ScoreRequest,
                                              ScoreResponse,
58
                                              TokenizeChatRequest,
59
                                              TokenizeCompletionRequest,
60
61
                                              TokenizeResponse,
                                              TranscriptionRequest,
62
63
                                              TranscriptionResponse,
                                              TranslationRequest)
64
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
65
from vllm.entrypoints.openai.tool_parsers import ToolParser
66
# yapf: enable
67
68
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
69
from vllm.inputs.parse import parse_and_batch_prompt
70
from vllm.logger import init_logger
71
from vllm.lora.request import LoRARequest
72
73
74
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
75
from vllm.pooling_params import PoolingParams
76
from vllm.prompt_adapter.request import PromptAdapterRequest
77
from vllm.sampling_params import BeamSearchParams, SamplingParams
78
from vllm.sequence import Logprob, PromptLogprobs
79
80
81
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
82
83
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)
84
85
86

logger = init_logger(__name__)

# Requests whose prompt arrives as text or token ids (completion-style).
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]

# Requests whose prompt arrives as a list of chat messages.
ChatLikeRequest = Union[
    ChatCompletionRequest,
    EmbeddingChatRequest,
    TokenizeChatRequest,
]

# Speech-to-text request types (transcription and translation).
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Union of every request type handled by the serving helpers below.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
]

# Union of the successful response types; failures are reported
# separately as ``ErrorResponse``.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

109
110
111

class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as its token ids."""

    prompt: str
    prompt_token_ids: list[int]


class EmbedsPrompt(TypedDict):
    """A prompt supplied directly as a tensor of embeddings."""

    prompt_embeds: torch.Tensor


# Every single-prompt form accepted by the request-processing helpers:
# raw token ids, raw text, or one of the two dict forms above.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when *prompt* is a ``TextTokensPrompt`` dict.

    A dict qualifies when it has a ``prompt_token_ids`` key and no
    ``prompt_embeds`` key; plain strings and token lists never qualify.
    """
    if not isinstance(prompt, dict):
        return False
    has_tokens = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_tokens and not has_embeds


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when *prompt* is an ``EmbedsPrompt`` dict.

    A dict qualifies when it has a ``prompt_embeds`` key and no
    ``prompt_token_ids`` key; plain strings and token lists never qualify.
    """
    if not isinstance(prompt, dict):
        return False
    has_embeds = "prompt_embeds" in prompt
    has_tokens = "prompt_token_ids" in prompt
    return has_embeds and not has_tokens

131

132
133
134
135
136
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    # Prompts as produced by preprocessing (text, token ids or embeddings).
    # `default_factory` gives each instance its own list, matching how
    # `ResponseGenerationMixin.final_res_batch` declares its default.
    request_prompts: Optional[Sequence[RequestPrompt]] = Field(
        default_factory=list)
    # Prompts in the form submitted to the engine; the serving pipeline
    # indexes this and `request_prompts` with the same index.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = Field(
                                       default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin holding response-generation state: the merged result generator
    and the per-prompt outputs collected from it.
    """
    # Merged stream of `(prompt_index, output)` pairs; stays `None` until
    # the engine requests have been scheduled.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # Outputs drained from `result_generator`, placed by prompt index.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    """Per-request state passed through every stage of the serving pipeline.

    Combines the prompt-processing and response-generation mixins with the
    request metadata the stages share.
    """

    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    # `ge=-1` matches the tokenization helpers, which accept -1 and treat a
    # negative value as "truncate to the model's max length"; `ge=1` would
    # reject that value when a context is constructed with it.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# Classification needs no fields beyond the generic serve context.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """Serve context for embedding requests, adding chat-template options."""

    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the models so Pydantic can resolve the forward reference to
# MultiModalDataDict inside TokensPrompt.
for _model in (RequestProcessingMixin, ServeContext,
               ClassificationServeContext, EmbeddingServeContext):
    _model.model_rebuild()
del _model

198

199
class OpenAIServing:
200
201
202
203
    # A short string prepended to every request's ID (e.g. "embd",
    # "classify") so you can easily tell which endpoint an ID came from
    # (Embedding vs Classification, ...). The base class uses no prefix.
    # The description is kept in this comment rather than in the value:
    # as a default value it would end up inside request IDs whenever the
    # prefix is used without being overridden.
    request_id_prefix: ClassVar[str] = ""
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Store the engine handle, model configuration and serving options.

        Args:
            engine_client: Client the pipeline submits engine requests to.
            model_config: Model configuration; its ``max_model_len`` is
                cached on the instance.
            models: Served-model registry (base models, LoRA and prompt
                adapters are looked up on it).
            request_logger: Optional logger for request inputs.
            return_tokens_as_token_ids: Serving option, stored as given.
            enable_force_include_usage: Serving option, stored as given.
        """
        super().__init__()

        # Engine and model handles.
        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len
        self.models = models

        # Per-request options.
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        # Single-worker thread pool; by its name, reserved for tokenizer work.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        # Lazily filled cache mapping a tokenizer to its
        # AsyncMicrobatchTokenizer wrapper (see `_get_async_tokenizer`).
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Look up the `AsyncMicrobatchTokenizer` wrapping *tokenizer*,
        creating and caching it on first use.
        """
        pool = self._async_tokenizer_pool
        if tokenizer not in pool:
            pool[tokenizer] = AsyncMicrobatchTokenizer(tokenizer)
        return pool[tokenizer]
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Preprocessing hook; the base implementation does nothing.

        Subclasses (classification, embedding, ...) override it to fill in
        `ctx` before generation. Returning an `ErrorResponse` makes the
        pipeline report that error; this default never fails.
        """
        return None
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Turn the collected results in `ctx` into a response object.

        Subclasses override this; the base version always reports the
        endpoint as unimplemented.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented
    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """Run the processing pipeline for *ctx* and return its first result.

        The first item yielded by `_pipeline` is either the final response or
        the error that stopped processing. The pipeline generator is closed
        explicitly so that its cleanup runs here, in the current task, rather
        than later when the abandoned generator is garbage-collected.
        """
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # No-op when the generator has already finished.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")
    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, `_preprocess`,
        `_prepare_generators`, `_collect_batch` and finally `_build_response`.
        The first stage that reports an error has that error yielded, and the
        pipeline stops there, so later stages never run on a context that
        failed an earlier one.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)
    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """Check request-level limits and copy accepted values onto *ctx*.

        Only `truncate_prompt_tokens` is checked: a value above
        `max_model_len` is rejected, otherwise it is stored on `ctx`.
        Requests without that attribute (or with `None`) pass unchanged.
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None:
            return None

        if limit > self.max_model_len:
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size.")

        ctx.truncate_prompt_tokens = limit
        return None
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator.

        One `engine_client.encode` call is made per engine prompt, with
        request id `f"{ctx.request_id}-{i}"`, and the resulting generators
        are merged into `ctx.result_generator`. Failures are returned as an
        `ErrorResponse` instead of being raised.
        """
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            if not hasattr(ctx.request, "to_pooling_params"):
                return self.create_error_response(
                    "Request type does not support pooling parameters")

            pooling_params = ctx.request.to_pooling_params()

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            # Validate the request prompts once, before anything is logged
            # or submitted, instead of on every loop iteration. They are
            # only needed when there is at least one engine prompt, so an
            # empty batch is still accepted as before.
            request_prompts = ctx.request_prompts
            if ctx.engine_prompts and request_prompts is None:
                return self.create_error_response(
                    "Request prompts not available")
            request_prompts = request_prompts or []
            if len(request_prompts) < len(ctx.engine_prompts):
                return self.create_error_response(
                    "Fewer request prompts than engine prompts")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    request_prompts[i],
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                    prompt_adapter_request=ctx.prompt_adapter_request)

                # Mypy has an existing bug related to inferring the variance of
                # TypedDicts with `builtins.enumerate`:
                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
                engine_prompt = cast(
                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
                    engine_prompt)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Drain `ctx.result_generator` into `ctx.final_res_batch`.

        The generator yields `(index, output)` pairs; each output is stored
        at its prompt index, so the final batch follows prompt order. An
        `ErrorResponse` is returned when a prerequisite is missing, when some
        prompt produced no output, or when iteration raises.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            slots: list[Optional[Union[RequestOutput,
                                       PoolingRequestOutput]]]
            slots = [None] * len(prompts)

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response(
                    "Result generator not available")

            async for index, output in stream:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [out for out in slots if out is not None]
            return None

        except Exception as e:
            return self.create_error_response(str(e))
    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an `ErrorResponse` from a message, error type and status.

        The numeric HTTP status (`status_code.value`) goes into the
        response's `code` field.
        """
        numeric_code = status_code.value
        return ErrorResponse(message=message, type=err_type, code=numeric_code)
    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Serialize an error as the JSON string `{"error": {...}}`.

        The inner object is the `model_dump()` of the response built by
        `create_error_response` with the same arguments.
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        payload = {"error": error.model_dump()}
        return json.dumps(payload)
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Return `None` if `request.model` can be served, else an error.

        The name is accepted when `_is_model_supported` accepts it, when it
        names a loaded LoRA adapter, when runtime LoRA resolution succeeds
        (attempted only if `VLLM_ALLOW_RUNTIME_LORA_UPDATING` is set), or
        when it names a prompt adapter. If runtime resolution returned a 400
        error, that error is returned instead of the generic 404.
        """
        model_name = request.model

        if self._is_model_supported(model_name):
            return None
        if model_name in self.models.lora_requests:
            return None

        lora_error: Optional[ErrorResponse] = None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model_name:
            load_result = await self.models.resolve_lora(model_name)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                if (isinstance(load_result, ErrorResponse)
                        and load_result.code == HTTPStatus.BAD_REQUEST.value):
                    lora_error = load_result

        adapter_names = (adapter.prompt_adapter_name
                         for adapter in self.models.prompt_adapter_requests)
        if model_name in adapter_names:
            return None

        return lora_error or self.create_error_response(
            message=f"The model `{model_name}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)
    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        """Return `(lora_request, prompt_adapter_request)` for *request*.

        Both entries are `None` when `_is_model_supported` accepts the name;
        otherwise exactly one of them is set.

        Raises:
            ValueError: if `request.model` matches nothing. This is
                unreachable when `_check_model` has accepted the request.
        """
        model_name = request.model
        if self._is_model_supported(model_name):
            return None, None

        loras = self.models.lora_requests
        if model_name in loras:
            return loras[model_name], None

        adapter = next(
            (candidate for candidate in self.models.prompt_adapter_requests
             if model_name == candidate.prompt_adapter_name),
            None,
        )
        if adapter is not None:
            return None, adapter

        raise ValueError(f"The model `{model_name}` does not exist.")
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate it against the model limits.

        The prompt is lower-cased first when the model's encoder config sets
        `do_lower_case`. Truncation: `None` disables it; a negative value
        truncates to `max_model_len`; any other value is passed to the
        tokenizer as `max_length` (always with `truncation=True`).
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get(
                "do_lower_case", False):
            prompt = prompt.lower()

        tokenizer_kwargs: dict[str, Any] = {
            "add_special_tokens": add_special_tokens,
        }
        if truncate_prompt_tokens is not None:
            # Negative means we cap at the model's max length
            tokenizer_kwargs["truncation"] = True
            tokenizer_kwargs["max_length"] = (self.max_model_len
                                              if truncate_prompt_tokens < 0
                                              else truncate_prompt_tokens)

        encoded = await async_tokenizer(prompt, **tokenizer_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
    ) -> TextTokensPrompt:
        """Truncate pre-tokenized input, decode it, and validate it.

        Truncation keeps the *last* tokens:

        - `None`: keep everything;
        - negative: keep the last `max_model_len` tokens;
        - `n >= 0`: keep the last `n` tokens (`0` keeps none).

        The annotation allows -1, matching the sibling helpers and the
        negative branch handled below.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            keep = (self.max_model_len
                    if truncate_prompt_tokens < 0 else truncate_prompt_tokens)
            # `prompt_ids[-0:]` is the *whole* list, so a zero budget must be
            # handled explicitly to actually keep no tokens.
            input_ids = prompt_ids[-keep:] if keep > 0 else []

        input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)
    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model's context window.

        - Embedding, score, rerank and classification requests: the prompt
          alone must not exceed `max_model_len`.
        - Tokenize/detokenize requests: no length check.
        - Everything else: prompt plus the completion budget (`max_tokens`,
          or `max_completion_tokens` for chat) must fit in `max_model_len`.
          Without a budget, the prompt must be strictly shorter than
          `max_model_len`.

        Returns:
            The prompt as a `TextTokensPrompt`.

        Raises:
            ValueError: if the length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest, ClassificationRequest)):

            if token_num > self.max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification"
                }
                operation = operations.get(type(request),
                                           "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs_async`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs_async]
        that assumes single input.

        Delegates to `_tokenize_prompt_inputs_async` with a one-element list
        and returns its first result.

        Raises:
            ValueError: if the delegate yields nothing.
        """
        results = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
        try:
            async for result in results:
                return result
        finally:
            # Close the generator here rather than leaving it suspended for
            # the garbage collector to finalize; no-op once it has finished.
            await results.aclose()
        raise ValueError("No results yielded from tokenization")
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs_async`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs_async]
        that assumes multiple inputs.

        Inputs are processed one at a time, in order: strings are tokenized,
        token-id lists are truncated and decoded. `add_special_tokens` only
        applies to string inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield await self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
    async def _tokenize_prompt_input_or_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

        According to the
        [OpenAI API](https://platform.openai.com/docs/api-reference/embeddings/create),
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        Returns:
            `(text_prompts, embeds_prompts)`. Embedding prompts come only from
            a `CompletionRequest`'s `prompt_embeds`; text prompts are
            tokenized concurrently and returned in input order.
        """
        embeds_prompts: list[EmbedsPrompt] = []
        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            embeds_prompts.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None:
            return [], embeds_prompts
        if embeds_prompts and input_or_inputs == "":
            return [], embeds_prompts

        # `is False` (rather than `not ...`) is what lets the VSCode Pyright
        # extension narrow the parsed prompt type; mypy accepts either form.
        # See: https://github.com/microsoft/pyright/issues/7672
        pending = []
        for parsed in parse_and_batch_prompt(input_or_inputs):
            if parsed["is_tokens"] is False:
                pending.append(
                    self._normalize_prompt_text_to_input(
                        request,
                        tokenizer,
                        parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens,
                        add_special_tokens=add_special_tokens))
            else:
                pending.append(
                    self._normalize_prompt_tokens_to_input(
                        request,
                        tokenizer,
                        parsed["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens))

        # Run every tokenization concurrently; gather preserves input order.
        text_prompts = list(await asyncio.gather(*pending))
        return text_prompts, embeds_prompts
    @overload
    async def _preprocess_completion(
        self,
        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                       RerankRequest, ClassificationRequest, ScoreRequest,
                       TokenizeCompletionRequest],
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        # Non-completion requests: input is required and only text/token
        # prompts are produced (no prompt embeddings).
        ...

    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
            EngineTokensPrompt, EngineEmbedsPrompt]]]:
        # Completion requests may also carry prompt embeddings, so the input
        # may be None and the results may mix text and embedding prompts.
        ...

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
            TextTokensPrompt, EmbedsPrompt]]], Union[
                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                     EngineEmbedsPrompt]]]]:
        """Tokenize completion-style input into request and engine prompts.

        Args:
            request: The incoming request. ``input_or_inputs`` may only be
                ``None`` when this is a ``CompletionRequest``.
            tokenizer: Tokenizer forwarded to the tokenization helper.
            input_or_inputs: Text, token ids, or batches of either.
            truncate_prompt_tokens: Forwarded to the tokenization helper.
            add_special_tokens: Forwarded to the tokenization helper.

        Returns:
            ``(request_prompts, engine_prompts)`` as two parallel lists. For
            a ``CompletionRequest`` the embedding prompts come first,
            followed by the token prompts.

        Raises:
            ValueError: If a non-completion request has no
                ``input_or_inputs``.
        """
        is_completion = isinstance(request, CompletionRequest)
        if input_or_inputs is None and not is_completion:
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        text_prompts, embeds_prompts = (
            await self._tokenize_prompt_input_or_inputs_async(
                request,
                tokenizer,
                input_or_inputs,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

        engine_text_prompts = [
            EngineTokensPrompt(prompt_token_ids=prompt["prompt_token_ids"])
            for prompt in text_prompts
        ]

        # A non-completion request reaching this point had non-None input, so
        # it produced no embedding prompts. Returning early here (rather than
        # testing ``embeds_prompts`` for emptiness) lets the overloads
        # promise text-only results to every caller outside completions, so
        # they need no assertions of their own.
        if not is_completion:
            return text_prompts, engine_text_prompts

        engine_embeds_prompts = [
            EngineEmbedsPrompt(prompt_embeds=prompt["prompt_embeds"])
            for prompt in embeds_prompts
        ]

        return (embeds_prompts + text_prompts,
                engine_embeds_prompts + engine_text_prompts)

    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[EngineTokensPrompt]]:
        """Render chat ``messages`` with the chat template and tokenize them.

        The content format is resolved, the messages are parsed (multimodal
        data is awaited only after the template has been applied), and the
        template is applied via ``apply_mistral_chat_template`` for a
        ``MistralTokenizer`` or ``apply_hf_chat_template`` otherwise.
        Entries in ``chat_template_kwargs`` override the template arguments
        derived from the other parameters.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``. Here
            ``request_prompt`` is the rendered string, or a list of token ids
            on the Mistral path.

        Raises:
            NotImplementedError: If tool parsing is active for a request that
                is not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        # Defaults derived from the arguments above; caller-provided
        # chat_template_kwargs take precedence over them.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            # The Mistral template is applied to the raw messages rather
            # than to the parsed conversation.
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            # The parser may hand back an adjusted request; everything below
            # reads from the adjusted one.
            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        # Both are optional extras on a Union-typed request; presumably not
        # every member type defines them. Read them the same way with
        # getattr, so a type lacking either field is skipped instead of
        # raising AttributeError.
        mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None)
        if mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

        cache_salt = getattr(request, "cache_salt", None)
        if cache_salt is not None:
            engine_prompt["cache_salt"] = cache_salt

        return conversation, [request_prompt], [engine_prompt]

888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
    def _load_prompt_embeds(
        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:
        """Decode client-supplied prompt embeddings.

        Each payload is a base64-encoded ``torch.save`` serialization of a
        single tensor.

        Args:
            prompt_embeds: One payload, a list of payloads, or a falsy value
                (``None``, empty bytes, empty list), which yields ``[]``.
            truncate_prompt_tokens: If set, keep only the last
                ``truncate_prompt_tokens`` rows of each tensor.

        Returns:
            One ``{"prompt_embeds": tensor}`` dict per payload, in order.

        Raises:
            ValueError: If a payload does not deserialize to a dense
                float32/bfloat16/float16 tensor, or if the tensor is not 2-D
                after one leading dimension has been squeezed.
        """

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            # The payload is untrusted. ``weights_only=True`` limits
            # unpickling to an allowlist of tensor-related types, and
            # ``map_location`` pins storage to the CPU so the payload cannot
            # choose which device memory gets allocated on.
            tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
                                weights_only=True,
                                map_location=torch.device("cpu"))
            # Use explicit raises rather than ``assert`` so the checks
            # survive ``python -O``. The legacy tensor-type classes only
            # match dense CPU tensors of the corresponding dtype.
            allowed_types = (torch.FloatTensor, torch.BFloat16Tensor,
                             torch.HalfTensor)
            if not isinstance(tensor, allowed_types):
                raise ValueError(
                    "prompt_embeds must be a dense tensor of dtype float32, "
                    "bfloat16 or float16.")
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            if tensor.dim() != 2:
                raise ValueError(
                    "prompt_embeds must be a 2-D tensor, optionally with an "
                    "extra leading dimension of size 1.")
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if not prompt_embeds:
            return []
        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]
        return [_load_and_validate_embed(prompt_embeds)]

917
918
919
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward a request's prompt to the request logger, if one is set.

        ``inputs`` may be one of:
        - a plain string, logged as the prompt;
        - a list of token ids;
        - a dict carrying ``prompt_embeds``;
        - a dict with ``prompt`` and ``prompt_token_ids``.

        Whichever of prompt, token ids, and embeddings are absent are passed
        on as ``None``.
        """
        req_logger = self.request_logger
        if req_logger is None:
            return

        prompt = None
        prompt_token_ids = None
        prompt_embeds = None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        elif "prompt_embeds" in inputs:
            prompt_embeds = inputs.get("prompt_embeds")
        else:
            prompt, prompt_token_ids = (inputs["prompt"],
                                        inputs["prompt_token_ids"])

        req_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
948

949
950
951
952
953
954
955
956
957
958
959
960
961
962
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers to propagate, or ``None``.

        When the engine has tracing enabled, the result of
        ``extract_trace_headers(headers)`` is returned. Otherwise ``None`` is
        returned, and ``log_tracing_disabled_warning()`` is called if the
        incoming headers contain trace headers.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

963
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Return the request id from the ``X-Request-Id`` header.

        Without that header (or without a request at all), fall back to
        ``default``. If ``default`` is empty or ``None``, fall back to a
        freshly generated ``random_uuid()``.
        """
        fallback = default if default else random_uuid()
        if raw_request is not None:
            return raw_request.headers.get("X-Request-Id", fallback)
        return fallback
972

973
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the text to report for ``token_id``.

        Resolution order:
        1. ``"token_id:<id>"`` when ``return_as_token_id`` is set.
        2. Otherwise, ``logprob.decoded_token`` if it is already populated.
        3. Otherwise, ``tokenizer.decode(token_id)``.
        """
        if return_as_token_id:
            return f"token_id:{token_id}"
        cached = logprob.decoded_token
        return cached if cached is not None else tokenizer.decode(token_id)
984

985
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
        """Return True for an empty/None name, else whether it is a base model.

        An unspecified model name is always accepted. Any other name is
        checked with ``self.models.is_base_model``.
        """
        return (not model_name) or self.models.is_base_model(model_name)
989
990
991
992
993
994

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        """Return the model name to report for a request.

        Resolution order:
        1. The LoRA adapter's name, if a ``lora_request`` is given.
        2. Otherwise ``model_name``, if non-empty.
        3. Otherwise the name of the first entry in
           ``self.models.base_model_paths``.
        """
        if lora_request:
            return lora_request.lora_name
        if model_name:
            return model_name
        return self.models.base_model_paths[0].name
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` logprob values with ``-9999.0``, in place.

    ``None`` input and ``None`` per-position entries are passed through
    untouched. The same (mutated) object is returned. Presumably the clamp
    keeps the values finite for serialization (e.g. JSON, which has no
    infinity) — confirm against callers.
    """
    if prompt_logprobs is not None:
        negative_infinity = float('-inf')
        for position_logprobs in prompt_logprobs:
            if position_logprobs is not None:
                for entry in position_logprobs.values():
                    if entry.logprob == negative_infinity:
                        entry.logprob = -9999.0
    return prompt_logprobs