serving_engine.py 38.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import base64
import io
5
import json
6
import sys
7
8
9
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
                             Sequence)
10
from concurrent.futures.thread import ThreadPoolExecutor
11
from http import HTTPStatus
12
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
13
                    TypeVar, Union, cast, overload)
14

15
import torch
16
from fastapi import Request
17
from pydantic import BaseModel, ConfigDict, Field
18
from starlette.datastructures import Headers
19
20
21
22
23
24
from typing_extensions import TypeIs

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict
25

26
27
28
29
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict
30

31
import vllm.envs as envs
32
from vllm.config import ModelConfig
33
from vllm.engine.protocol import EngineClient
34
35
# yapf conflicts with isort for this block
# yapf: disable
36
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
37
                                         ChatTemplateContentFormatOption,
38
39
40
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
41
42
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
43
from vllm.entrypoints.logger import RequestLogger
44
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
45
46
47
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
48
                                              CompletionRequest,
49
                                              CompletionResponse,
50
                                              DetokenizeRequest,
51
52
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
53
54
55
56
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              PoolingResponse, RerankRequest,
                                              ScoreRequest, ScoreResponse,
57
                                              TokenizeChatRequest,
58
                                              TokenizeCompletionRequest,
59
60
                                              TokenizeResponse,
                                              TranscriptionRequest,
61
62
                                              TranscriptionResponse,
                                              TranslationRequest)
63
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
64
from vllm.entrypoints.openai.tool_parsers import ToolParser
65
# yapf: enable
66
67
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
68
from vllm.inputs.parse import parse_and_batch_prompt
69
from vllm.logger import init_logger
70
from vllm.lora.request import LoRARequest
71
72
73
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
74
from vllm.pooling_params import PoolingParams
75
from vllm.prompt_adapter.request import PromptAdapterRequest
76
from vllm.sampling_params import BeamSearchParams, SamplingParams
77
from vllm.sequence import Logprob, PromptLogprobs
78
79
80
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
zhuwenwen's avatar
zhuwenwen committed
81

82
from vllm.transformers_utils.tokenizers import CPM9GTokenizer
83
84
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
                        random_uuid)
85
86
87

logger = init_logger(__name__)

88
# Request types whose prompt arrives as completion-style input (text or
# token IDs) rather than as a list of chat messages.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Request types whose prompt arrives as a list of chat messages.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Speech-to-text requests (transcription and translation).
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Union of every request type accepted by the serving helpers in this module.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest]

# Union of every non-error response type the serving helpers can build.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

109
110
111

class TextTokensPrompt(TypedDict):
    """A prompt carried both as its text and as its token IDs."""
    prompt: str
    prompt_token_ids: list[int]
113
114


115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class EmbedsPrompt(TypedDict):
    """A prompt given as an embeddings tensor instead of token IDs."""
    # NOTE(review): shape/dtype are not established in this module;
    # presumably (num_tokens, hidden_size) — confirm against the loader
    # that produces these prompts.
    prompt_embeds: torch.Tensor


# One prompt as seen by the serving layer: raw token IDs, raw text, a
# text + token-IDs pair, or an embeddings prompt.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when `prompt` is a `TextTokensPrompt`.

    That is a dict that has `prompt_token_ids` and has no `prompt_embeds`.
    The `TypeIs` return type lets type checkers narrow `prompt`.
    """
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" in prompt and "prompt_embeds" not in prompt


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when `prompt` is an `EmbedsPrompt`.

    That is a dict that has `prompt_embeds` and has no `prompt_token_ids`.
    The `TypeIs` return type lets type checkers narrow `prompt`.
    """
    if not isinstance(prompt, dict):
        return False
    has_embeds = "prompt_embeds" in prompt
    has_token_ids = "prompt_token_ids" in prompt
    return has_embeds and not has_token_ids

131

132
133
134
135
136
# Type variable for the concrete request type held by a `ServeContext`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    # Prompts as derived from the request; index-aligned with
    # `engine_prompts` (the pipeline reads `request_prompts[i]` for the
    # i-th engine prompt).
    request_prompts: Optional[Sequence[RequestPrompt]] = Field(
        default_factory=list)
    # Prompts in the form handed to the engine.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = Field(
                                       default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    # Async stream of `(prompt_index, output)` pairs; assigned by
    # `OpenAIServing._prepare_generators` and drained by `_collect_batch`.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # Outputs ordered by prompt index; filled in by `_collect_batch`.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    """Per-request state threaded through the `OpenAIServing` pipeline.

    Combines the prompt fields of `RequestProcessingMixin` and the result
    fields of `ResponseGenerationMixin` with request metadata.
    """
    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Defaults to the wall-clock time, in whole seconds, at construction.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    # NOTE(review): the annotation only admits values >= 1, yet
    # `_validate_request` may assign negative values (e.g. -1 meaning
    # "truncate to max_model_len"). That is not rejected because
    # `validate_assignment` is not enabled — confirm the intended bound.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# Serve context specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """Serve context for embedding requests, with chat-template settings."""
    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt.
# `model_rebuild()` re-resolves each model's annotations now that all the
# names they reference (including the `MultiModalDataDict` import) are in
# scope.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
197
198


199
class OpenAIServing:
200
201
202
203
    # Prefix prepended to request IDs so their origin (e.g. "embd",
    # "classify") is recognisable.
    # NOTE(review): the triple-quoted text below is the attribute's actual
    # default *value*, not documentation. Presumably every subclass that
    # uses the prefix overrides it with a short string — confirm nothing
    # relies on this default.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
    ):
        """Store shared serving state and set up async tokenization.

        Args:
            engine_client: Client whose `encode` is used to submit work.
            model_config: Served model's config; provides `max_model_len`,
                `tokenizer_mode` and, in "cpm" mode, the model path for
                the CPM tokenizer.
            models: Registry of served models and adapters, consulted by
                `_check_model` and `_maybe_get_adapters`.
            request_logger: Optional request logger, stored on the instance.
            return_tokens_as_token_ids: Stored on the instance.
            enable_force_include_usage: Stored on the instance.
        """
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.tokenizer_mode = model_config.tokenizer_mode
        if self.tokenizer_mode == "cpm":
            # In "cpm" mode the prompt normalizers use this tokenizer
            # instead of the per-call `tokenizer` argument. Note that
            # `self.tokenizer` is only defined in this mode.
            self.tokenizer = CPM9GTokenizer(model_config.model,
                                            trust_remote_code=True)

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        # One worker thread, so tokenization submitted through the
        # `make_async` wrappers below runs one call at a time.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Pre-generation hook; the base implementation does nothing.

        Subclasses override this to populate `ctx` (e.g. prompts for
        classification or embedding). An `ErrorResponse` return value is
        yielded by `_pipeline` as the request's result.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Turn a fully processed `ctx` into the endpoint's response.

        The base class serves no concrete endpoint, so it always returns
        an error; subclasses override this with their response builder.
        """
        unimplemented_message = "unimplemented endpoint"
        return self.create_error_response(unimplemented_message)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """Run the request pipeline and return the first item it yields.

        If the pipeline finishes without yielding anything, an error
        response is returned instead.
        """
        pipeline: AsyncGenerator[Union[AnyResponse, ErrorResponse],
                                 None] = self._pipeline(ctx)
        try:
            first_response = await pipeline.__anext__()
        except StopAsyncIteration:
            return self.create_error_response(
                "No response yielded from pipeline")
        return first_response

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: `_check_model`, `_validate_request`,
        `_preprocess`, `_prepare_generators`, `_collect_batch`,
        `_build_response`. The first stage that reports an `ErrorResponse`
        has it yielded, and the generator then stops. Exactly one item is
        produced whether or not the caller keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """Validate request-level options and record them on `ctx`.

        If the request has a `truncate_prompt_tokens` value, it is copied
        onto `ctx` provided it does not exceed `max_model_len`. Otherwise
        an error response is returned. Requests without the attribute (or
        with `None`) pass unchanged.
        """
        requested = getattr(ctx.request, "truncate_prompt_tokens", None)
        if requested is None:
            return None

        if requested > self.max_model_len:
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size.")

        ctx.truncate_prompt_tokens = requested
        return None

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator.

        For every entry of `ctx.engine_prompts` this logs the input and
        calls `engine_client.encode`, using sub-request ID
        `f"{ctx.request_id}-{i}"`. The per-prompt generators are merged into
        `ctx.result_generator`.

        Returns:
            `None` on success, or an `ErrorResponse` if the request has no
            `to_pooling_params`, the prompts are missing, or any step
            raises.
        """
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            if not hasattr(ctx.request, "to_pooling_params"):
                return self.create_error_response(
                    "Request type does not support pooling parameters")

            pooling_params = ctx.request.to_pooling_params()

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            # Loop-invariant: the request priority is the same for every
            # sub-request.
            priority = getattr(ctx.request, "priority", 0)

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                if ctx.request_prompts is None:
                    return self.create_error_response(
                        "Request prompts not available")

                self._log_inputs(
                    request_id_item,
                    ctx.request_prompts[i],
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                    prompt_adapter_request=ctx.prompt_adapter_request)

                # Mypy has an existing bug related to inferring the variance of
                # TypedDicts with `builtins.enumerate`:
                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
                engine_prompt = cast(
                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
                    engine_prompt)

                # NOTE(review): `ctx.prompt_adapter_request` is logged above
                # but not forwarded to `encode` — confirm this is intended.
                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator` and places each output at the prompt
        index it is tagged with (a later output for the same index replaces
        an earlier one). The ordered list is stored in
        `ctx.final_res_batch`.

        Returns:
            `None` on success, or an `ErrorResponse` if the prompts or the
            generator are missing, some prompt produced no output, or
            draining raises.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[Optional[Union[RequestOutput,
                                                 PoolingRequestOutput]]]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Any remaining None slot means that prompt yielded no output.
            if None in final_res_batch:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [
                res for res in final_res_batch if res is not None
            ]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        """Build an `ErrorResponse`.

        Args:
            message: Human-readable error description.
            err_type: Stored in the response's `type` field.
            status_code: Its integer value is stored in the `code` field.
        """
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        """Serialize an error as the JSON string `{"error": {...}}`.

        Takes the same arguments as `create_error_response`. The resulting
        model is dumped to a dict and wrapped under an `"error"` key.
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        return json.dumps({"error": error.model_dump()})

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that `request.model` names something this server can serve.

        A name is accepted if any of these holds:
        - it is accepted by `_is_model_supported`;
        - it matches a loaded LoRA adapter;
        - it resolves to a LoRA at runtime (only when
          `VLLM_ALLOW_RUNTIME_LORA_UPDATING` is set);
        - it matches a loaded prompt adapter.

        Returns:
            `None` if the model is available. Otherwise, a BAD_REQUEST error
            from runtime LoRA resolution if there was one, or else a generic
            NOT_FOUND error.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None

        if any(lora.lora_name == request.model
               for lora in self.models.lora_requests):
            return None

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
                load_result := await self.models.resolve_lora(request.model)):
            if isinstance(load_result, LoRARequest):
                return None
            # Keep client-side (400) resolution errors so they reach the
            # caller; anything else falls through to NOT_FOUND below.
            if (isinstance(load_result, ErrorResponse)
                    and load_result.code == HTTPStatus.BAD_REQUEST.value):
                error_response = load_result

        if any(prompt_adapter.prompt_adapter_name == request.model
               for prompt_adapter in self.models.prompt_adapter_requests):
            return None

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        """Return `(lora_request, prompt_adapter_request)` for `request`.

        Returns `(None, None)` when `_is_model_supported` accepts the name.
        Otherwise it returns `(lora, None)` for the first matching LoRA
        adapter, or `(None, adapter)` for the first matching prompt adapter.

        Raises:
            ValueError: if `request.model` matches none of the above.
        """
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        # (assuming runtime-resolved LoRAs end up in `models.lora_requests`
        # — TODO confirm).
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate it against the context length.

        Args:
            request: Originating request, passed to `_validate_input`.
            tokenizer: Tokenizer applied to `prompt`. It is not used in
                "cpm" tokenizer mode, where `self.tokenizer` is used instead.
            prompt: Text to tokenize. It is lower-cased first when the
                encoder config sets `do_lower_case`.
            truncate_prompt_tokens: Controls truncation:
                - `None`: no truncation;
                - negative: cap at `max_model_len`;
                - otherwise: the maximum number of tokens to keep.
            add_special_tokens: Forwarded to `tokenizer`.

        Returns:
            The (possibly lower-cased) text together with its token IDs.

        Raises:
            ValueError: from `_validate_input` if the prompt is too long.
        """
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if self.tokenizer_mode == "cpm":
            # Only the CPM tokenizer is needed here, so the per-call
            # `tokenizer` is not run at all.
            # NOTE(review): this path ignores `truncate_prompt_tokens` and
            # `add_special_tokens` and always prepends `bos_id`.
            input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(prompt)
        else:
            if truncate_prompt_tokens is None:
                encoded = tokenizer(prompt,
                                    add_special_tokens=add_special_tokens)
            else:
                # Negative means we cap at the model's max length
                max_length = (self.max_model_len
                              if truncate_prompt_tokens < 0 else
                              truncate_prompt_tokens)
                encoded = tokenizer(prompt,
                                    add_special_tokens=add_special_tokens,
                                    truncation=True,
                                    max_length=max_length)
            input_ids = encoded.input_ids

        return self._validate_input(request, input_ids, prompt)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
    ) -> TextTokensPrompt:
        """Truncate a token-ID prompt, detokenize it and validate its length.

        Args:
            request: Originating request, passed to `_validate_input`.
            tokenizer: Used to decode the IDs. In "cpm" tokenizer mode,
                `self.tokenizer.decode_all` is used instead.
            prompt_ids: Token IDs supplied by the client.
            truncate_prompt_tokens: Controls truncation:
                - `None`: keep every token;
                - negative: keep the last `max_model_len` tokens;
                - N >= 0: keep the last N tokens (none for 0).

        Returns:
            The decoded text together with the kept token IDs.

        Raises:
            ValueError: from `_validate_input` if the prompt is too long.
        """
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` would be the *whole* list, not an empty one.
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if self.tokenizer_mode == "cpm":
            input_text = self.tokenizer.decode_all(input_ids)
        else:
            input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check a tokenized prompt against the model's context length.

        - Embedding, score, rerank and classification requests: the prompt
          alone must not exceed `max_model_len`.
        - Tokenize/detokenize requests: not checked.
        - All other requests with an output budget
          (`max_completion_tokens` / `max_tokens`): prompt plus budget must
          fit in `max_model_len`.
        - All other requests without a budget: the prompt must be strictly
          shorter than `max_model_len`.

        Returns:
            A `TextTokensPrompt` built from `input_text` and `input_ids`.

        Raises:
            ValueError: if the applicable length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest, ClassificationRequest)):

            if token_num > self.max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification"
                }
                operation = operations.get(type(request),
                                           "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
        that assumes single input.

        `prompt_input` is either text or a list of token IDs. The result of
        `_tokenize_prompt_inputs` for that single input is returned.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of
        [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
        that assumes multiple inputs.

        Lazily yields one normalized prompt per input, in order:
        - strings are tokenized (`_normalize_prompt_text_to_input`);
        - token-ID lists are truncated and detokenized
          (`_normalize_prompt_tokens_to_input`, which takes no
          `add_special_tokens`).
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.

        For a `CompletionRequest` carrying `prompt_embeds`, those embeddings
        are loaded as well. In that case an empty-string prompt is allowed
        and yields no text prompts.

        Returns:
            `(text_prompts, embeds_prompts)`.
        """
        inputs_embeds: list[EmbedsPrompt] = []
        inputs_text: list[TextTokensPrompt] = []

        if (isinstance(request, CompletionRequest)
                and request.prompt_embeds is not None):
            inputs_embeds.extend(
                self._load_prompt_embeds(request.prompt_embeds,
                                         truncate_prompt_tokens))

        # Empty prompts are okay as long as there are prompt embeddings
        if input_or_inputs is None or (inputs_embeds
                                       and input_or_inputs == ""):
            return [], inputs_embeds

        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            if prompt_input["is_tokens"] is False:
                inputs_text.append(
                    self._normalize_prompt_text_to_input(
                        request,
                        tokenizer,
                        prompt=prompt_input["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens,
                        add_special_tokens=add_special_tokens))
            else:
                inputs_text.append(
                    self._normalize_prompt_tokens_to_input(
                        request,
                        tokenizer,
                        prompt_ids=prompt_input["content"],
                        truncate_prompt_tokens=truncate_prompt_tokens))

        return inputs_text, inputs_embeds

    # Non-completion requests: input is required and only text/token
    # prompts are produced.
    @overload
    async def _preprocess_completion(
        self,
        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                       RerankRequest, ClassificationRequest, ScoreRequest,
                       TokenizeCompletionRequest],
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
        ...

    # Completion requests: input may be None (prompt embeddings only), and
    # both text/token and embedding prompts may be produced.
    @overload
    async def _preprocess_completion(
        self,
        request: CompletionRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
        add_special_tokens: bool = ...,
    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
            EngineTokensPrompt, EngineEmbedsPrompt]]]:
        ...

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Optional[Union[str, list[str], list[int],
                                        list[list[int]]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
            TextTokensPrompt, EmbedsPrompt]]], Union[
                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                     EngineEmbedsPrompt]]]]:
        """Turn completion-style input into request and engine prompts.

        Tokenization runs through `_tokenize_prompt_input_or_inputs_async`,
        which is bound to the tokenizer executor in `__init__`.

        Returns:
            `(request_prompts, engine_prompts)`, two index-aligned lists.
            Non-completion requests get only text/token prompts. For a
            `CompletionRequest`, the embedding prompts come first, followed
            by the text/token prompts.

        Raises:
            ValueError: if a non-completion request has no input.
        """
        if not isinstance(request,
                          CompletionRequest) and input_or_inputs is None:
            raise ValueError(
                "Prompt embeds with non-completion requests is not"
                " currently supported.")

        (request_prompts_text, request_prompts_embeds
         ) = await self._tokenize_prompt_input_or_inputs_async(
             request,
             tokenizer,
             input_or_inputs,
             truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
         )

        engine_prompts_text = [
            EngineTokensPrompt(
                prompt_token_ids=request_prompt_text["prompt_token_ids"])
            for request_prompt_text in request_prompts_text
        ]

        # This check is equivalent to simply checking if
        # `request_prompts_embeds` is empty, but it's difficult to propagate
        # overloads to the private helper functions to enable this check.
        # This overload is needed because only TextPrompts are allowed for
        # non-completion requests and if we don't add the overload here,
        # everywhere this function is used outside of serving_completion will
        # need logic asserting that only text prompts are in the request.
        if not isinstance(request,
                          CompletionRequest) and input_or_inputs is not None:
            return request_prompts_text, engine_prompts_text

        engine_prompts_embeds = [
            EngineEmbedsPrompt(
                prompt_embeds=request_prompt_embeds["prompt_embeds"])
            for request_prompt_embeds in request_prompts_embeds
        ]

        request_prompts = request_prompts_embeds + request_prompts_text
        engine_prompts = engine_prompts_embeds + engine_prompts_text
        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[EngineTokensPrompt]]:
        """Render chat messages through the chat template and tokenize them.

        Steps, in order:

        1. Resolve the template content format.
        2. Parse the messages; multimodal data is produced via a future.
        3. Apply the Mistral or HF chat template.
        4. Await the multimodal data.
        5. Optionally let ``tool_parser`` adjust the request.
        6. Tokenize the rendered prompt.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``. Here
            ``request_prompt`` is the rendered template output (a string, or
            token ids for Mistral). ``engine_prompt`` carries the token ids
            plus any multimodal data, ``mm_processor_kwargs`` and
            ``cache_salt``.

        Raises:
            NotImplementedError: If tool parsing is requested for something
                other than a ``ChatCompletionRequest``.
        """
        cfg = self.model_config

        content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=cfg,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            cfg,
            tokenizer,
            content_format=content_format,
        )

        # Caller-supplied chat_template_kwargs override the defaults below.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tool_dicts,
            "documents": documents,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=cfg,
                **template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool calls are parsed only when a tool_parser is configured AND the
        # request did not set tool_choice="none". With "none", a tool call
        # hallucinated by the model must not be parsed even though a parser
        # exists.
        wants_tool_parsing = (tool_parser is not None
                              and hasattr(request, "tool_choice")
                              and request.tool_choice != "none")
        if wants_tool_parsing:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)
            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if not isinstance(request_prompt, str):
            # The Mistral template path yields token ids instead of text.
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)
        else:
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
        # Not every chat-like request type defines cache_salt.
        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
    def _load_prompt_embeds(
        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:
        """Decode base64-encoded, ``torch.save``-serialized prompt embeddings.

        Args:
            prompt_embeds: A single base64 payload, a list of payloads, or a
                falsy value meaning "no embeddings".
            truncate_prompt_tokens: When positive, keep only the last
                ``truncate_prompt_tokens`` rows (along dim 0) of each tensor.
                Non-positive values are outside the declared ``ge=1``
                contract. They are ignored rather than applied as a negative
                slice, which would silently drop leading rows.

        Returns:
            One ``{"prompt_embeds": tensor}`` entry per payload, or an empty
            list when ``prompt_embeds`` is falsy.

        Raises:
            ValueError: If a payload does not decode to a dense float32,
                bfloat16 or float16 tensor that is 2-D, either directly or
                after dropping a leading size-1 dimension.
        """

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            # The payload is client-controlled:
            # - weights_only restricts unpickling to tensor data;
            # - map_location keeps the client from choosing the target device;
            # - check_sparse_tensor_invariants validates any sparse tensor
            #   while it is being rebuilt, so malformed indices are rejected.
            with torch.sparse.check_sparse_tensor_invariants():
                tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
                                    weights_only=True,
                                    map_location=torch.device("cpu"))

            # Explicit checks instead of `assert`, so validation survives
            # `python -O` and surfaces as a ValueError like other bad input.
            if not isinstance(tensor, torch.Tensor):
                raise ValueError(
                    "prompt_embeds must deserialize to a torch.Tensor, got "
                    f"{type(tensor).__name__}")
            if tensor.layout != torch.strided:
                raise ValueError(
                    "prompt_embeds must be a dense (strided) tensor, got "
                    f"layout {tensor.layout}")
            if tensor.dtype not in (torch.float32, torch.bfloat16,
                                    torch.float16):
                raise ValueError(
                    "prompt_embeds must be float32, bfloat16 or float16, got "
                    f"{tensor.dtype}")

            if tensor.dim() > 2:
                # Tolerate one extra leading dimension of size 1.
                tensor = tensor.squeeze(0)
            if tensor.dim() != 2:
                raise ValueError(
                    "prompt_embeds must be a 2-D tensor (optionally with a "
                    "leading size-1 dimension), got shape "
                    f"{tuple(tensor.shape)}")

            if truncate_prompt_tokens is not None and truncate_prompt_tokens > 0:
                tensor = tensor[-truncate_prompt_tokens:]
            return {"prompt_embeds": tensor}

        if not prompt_embeds:
            return []
        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]
        return [_load_and_validate_embed(prompt_embeds)]

902
903
904
    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Forward a request's prompt to ``self.request_logger``, if set.

        ``inputs`` is unpacked into the ``prompt`` / ``prompt_token_ids`` /
        ``prompt_embeds`` slots of the logger call, depending on its form:

        - a plain string is the prompt text;
        - a list holds token ids;
        - a dict with ``"prompt_embeds"`` carries embeddings;
        - any other dict provides both ``"prompt"`` and
          ``"prompt_token_ids"``.

        Slots that do not apply are passed as None.
        """
        logger = self.request_logger
        if logger is None:
            return

        text = None
        token_ids = None
        embeds = None
        if isinstance(inputs, str):
            text = inputs
        elif isinstance(inputs, list):
            token_ids = inputs
        elif "prompt_embeds" in inputs:
            embeds = inputs.get("prompt_embeds")
        else:
            text = inputs["prompt"]
            token_ids = inputs["prompt_token_ids"]

        logger.log_inputs(
            request_id,
            text,
            token_ids,
            embeds,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
933

934
935
936
937
938
939
940
941
942
943
944
945
946
947
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return trace headers to propagate, or None when tracing is off.

        When the engine reports tracing as disabled but ``headers`` still
        contain trace headers, ``log_tracing_disabled_warning()`` is called
        before returning None.
        """
        if not await self.engine_client.is_tracing_enabled():
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None
        return extract_trace_headers(headers)

948
    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pick the request id: the ``X-Request-Id`` header when present,
        otherwise ``default``, otherwise a freshly generated ``random_uuid()``.
        """
        if raw_request is not None:
            header_id = raw_request.headers.get("X-Request-Id")
            if header_id is not None:
                return header_id
        return default or random_uuid()
957

958
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        """Return the display string for ``token_id``.

        With ``return_as_token_id`` the string is ``"token_id:<id>"``.
        Otherwise the string is the logprob's cached ``decoded_token``, or
        ``tokenizer.decode(token_id)`` when nothing is cached.
        """
        if not return_as_token_id:
            cached = logprob.decoded_token
            return cached if cached is not None else tokenizer.decode(token_id)
        return f"token_id:{token_id}"
969

970
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
        """True when no model name is given or it names a base model."""
        return not model_name or self.models.is_base_model(model_name)
974
975
976
977
978
979

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        """Resolve the model name to report for a request.

        A LoRA request's adapter name takes precedence. Otherwise
        ``model_name`` is used if non-empty, else the first base model path's
        name.
        """
        if lora_request:
            return lora_request.lora_name
        return model_name if model_name else self.models.base_model_paths[0].name
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` logprob values with ``-9999.0``, in place.

    ``None`` positions, and ``None`` as the whole argument, are left
    untouched. The same object is returned. The finite sentinel is
    presumably there to keep the values serializable in API responses;
    confirm with the response encoder.
    """
    if prompt_logprobs is not None:
        neg_inf = float('-inf')
        present = (entry for entry in prompt_logprobs if entry is not None)
        for position_logprobs in present:
            for candidate in position_logprobs.values():
                if candidate.logprob == neg_inf:
                    candidate.logprob = -9999.0
    return prompt_logprobs