serving_engine.py 38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import sys
5
import time
6
import traceback
7
8
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
9
from http import HTTPStatus
10
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
11

12
import torch
13
from fastapi import Request
14
from pydantic import BaseModel, ConfigDict, Field
15
from starlette.datastructures import Headers
16
17
from typing_extensions import TypeIs

18
19
20
21
22
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor

23
24
25
26
27
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

28
import vllm.envs as envs
29
from vllm.config import ModelConfig
30
from vllm.engine.protocol import EngineClient
31
32
# yapf conflicts with isort for this block
# yapf: disable
33
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
34
                                         ChatTemplateContentFormatOption,
35
36
37
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
38
39
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
40
from vllm.entrypoints.context import ConversationContext
41
from vllm.entrypoints.logger import RequestLogger
42
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
43
44
45
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
46
                                              CompletionRequest,
47
                                              CompletionResponse,
48
                                              DetokenizeRequest,
49
50
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
51
                                              EmbeddingRequest,
52
                                              EmbeddingResponse, ErrorInfo,
53
54
55
56
57
                                              ErrorResponse,
                                              IOProcessorRequest,
                                              PoolingResponse, RerankRequest,
                                              ResponsesRequest, ScoreRequest,
                                              ScoreResponse,
58
                                              TokenizeChatRequest,
59
                                              TokenizeCompletionRequest,
60
61
                                              TokenizeResponse,
                                              TranscriptionRequest,
62
63
                                              TranscriptionResponse,
                                              TranslationRequest)
64
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
65
from vllm.entrypoints.openai.tool_parsers import ToolParser
66
67
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
                                       RenderConfig)
68
# yapf: enable
69
from vllm.inputs.data import PromptType
70
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
71
from vllm.inputs.parse import PromptComponents, get_prompt_components
72
from vllm.logger import init_logger
73
from vllm.logprobs import Logprob, PromptLogprobs
74
from vllm.lora.request import LoRARequest
75
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
76
    MultiModalDataDict, MultiModalUUIDDict)
77
from vllm.outputs import PoolingRequestOutput, RequestOutput
78
from vllm.pooling_params import PoolingParams
79
from vllm.sampling_params import BeamSearchParams, SamplingParams
80
81
82
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
83
84
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
                        merge_async_iterators, random_uuid)
85
86
87

logger = init_logger(__name__)

88
89
90
91
92
93
94
95
96
# Type alias: union of the request models this module groups as
# "completion-like".
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]
97
98
99

# Type alias: union of the request models this module groups as "chat-like".
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]
100
# Type alias: either a transcription or a translation request.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
101
102
103
104
105
106
107
# Type alias: every request model referenced by this module; used as the
# bound of `RequestT` and as the request annotation on the serving helpers.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
    IOProcessorRequest,
]
108

109
110
111
112
113
114
115
116
117
118
119
# Type alias: the successful (non-error) response models that the handler
# methods in this module are annotated to return.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

120
121
122

class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as its token ids."""

    # The prompt text.
    prompt: str
    # The token ids associated with `prompt`.
    prompt_token_ids: list[int]
124
125


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class EmbedsPrompt(TypedDict):
    """A prompt supplied as an embeddings tensor under `prompt_embeds`."""
    prompt_embeds: torch.Tensor


# Type alias: a request prompt is a list of token ids, a plain string, or one
# of the two TypedDict prompt forms.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Type guard for `TextTokensPrompt`.

    True only for a dict that has a "prompt_token_ids" key and no
    "prompt_embeds" key.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Type guard for `EmbedsPrompt`.

    True only for a dict that has a "prompt_embeds" key and no
    "prompt_token_ids" key.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return not has_token_ids and has_embeds

142

143
144
145
146
147
# Type variable for the request model carried by a `ServeContext`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Request-side state for a serving context: the prompts prepared from
    the request and the prompts handed to the engine.
    """

    # Needed because the fields hold non-Pydantic types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Prompts as prepared from the incoming request.
    request_prompts: Optional[Sequence[RequestPrompt]] = []
    # Prompts in the form passed to the engine client.
    engine_prompts: Optional[list[EngineTokensPrompt]] = []


class ResponseGenerationMixin(BaseModel):
    """
    Response-side state for a serving context: the merged result
    generator and the final per-prompt results collected from it.
    """

    # Needed because the fields hold non-Pydantic types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Yields `(prompt_index, output)` pairs; unset until scheduling is done.
    result_generator: Optional[AsyncGenerator[
        tuple[int, Union[RequestOutput, PoolingRequestOutput]],
        None,
    ]] = None
    # Final outputs, filled in once the generator has been drained.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)


172
173
174
175
176
177
class ServeContext(
        RequestProcessingMixin,
        ResponseGenerationMixin,
        BaseModel,
        Generic[RequestT],
):
    """
    Per-request state threaded through the serving pipeline, combining
    the request-processing and response-generation mixins with the
    request itself and its identifying metadata.
    """

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )

    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Defaults to the integer Unix time at construction.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None


# `ServeContext` specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, extended with the chat
    template and the chat template content format."""
    # Optional chat template; defaults to None.
    chat_template: Optional[str] = None
    # Required: has no default value.
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to the forward reference
# of MultiModalDataDict in TokensPrompt: each model is rebuilt here, after
# all of the names above are defined.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

212

213
class OpenAIServing:
214
215
216
217
    # NOTE(review): the default value assigned here is the descriptive text
    # itself, not a short prefix; subclasses are presumably expected to
    # override it (e.g. with "embd" or "classify") -- confirm.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
218

219
220
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """
        Store the engine client, model configuration and serving options,
        and create the tokenizer executor and the (initially empty)
        async tokenizer pool.
        """
        super().__init__()

        # Engine and model handles.
        self.engine_client = engine_client
        self.models = models
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        # Serving options.
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage
        self.log_error_stack = log_error_stack

        # Tokenization helpers.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}
247

248
249
250
251
252
253
254
255
256
257
    async def _get_processor(self) -> Processor:
        """
        Return the `Processor`, building it on first use and caching it
        on `self._processor`.
        """
        if hasattr(self, "_processor"):
            return self._processor

        vllm_config = await self.engine_client.get_vllm_config()
        # No tokenizer is created when tokenizer init is skipped.
        tokenizer = (None if self.model_config.skip_tokenizer_init else
                     init_tokenizer_from_configs(self.model_config))
        self._processor = Processor(vllm_config, tokenizer)
        return self._processor

258
259
260
261
262
263
264
265
266
267
    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
        """
        Build a `CompletionRenderer` for the given tokenizer.

        The renderer receives this instance's model config and its shared
        async tokenizer pool.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer

268
269
270
271
272
273
274
275
276
277
278
279
280
    def _build_render_config(self, request: Any) -> RenderConfig:
        """
        Produce the `RenderConfig` for an endpoint.

        The renderer uses it to decide how prompts are prepared
        (e.g., tokenization and length handling). This base version is
        not implemented; endpoints should provide their own with logic
        suited to their request type.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

281
282
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Look up the `AsyncMicrobatchTokenizer` for `tokenizer` in the pool,
        creating and storing one on a miss.
        """
        pool = self._async_tokenizer_pool
        cached = pool.get(tokenizer)
        if cached is not None:
            return cached

        created = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = created
        return created
291

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
    async def _preprocess(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """
        Preprocessing hook; the default does nothing and reports no error.

        Subclasses may override it to prepare `ctx` (classification,
        embedding, etc.) and return an `ErrorResponse` on failure.
        """
        return None

    def _build_response(
            self, ctx: ServeContext) -> Union[AnyResponse, ErrorResponse]:
        """
        Response-building hook; the default returns an error response
        saying the endpoint is unimplemented.

        Subclasses may override it to return the appropriate response
        object.
        """
        unimplemented = self.create_error_response("unimplemented endpoint")
        return unimplemented

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Run the pipeline for `ctx` and return the first response it yields.

        Returns:
            The first item yielded by `_pipeline`, or an error response
            if the pipeline finishes without yielding anything.
        """
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Only the first item is consumed, so the generator would
            # otherwise be left suspended until garbage collection.
            # Close it explicitly; this is a no-op if it already finished.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """
        Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation,
        preprocessing, generator preparation, batch collection, and
        response building. The first stage that reports an error yields
        that error and ends the pipeline, so later stages never run on a
        request that already failed, even if the consumer keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """
        Reject a request whose `truncate_prompt_tokens` exceeds the
        model's maximum length.

        Returns:
            An error response in that case; otherwise None (including
            when the request has no `truncate_prompt_tokens`).
        """
        limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if limit is None or limit <= self.max_model_len:
            return None

        return self.create_error_response(
            "truncate_prompt_tokens value is "
            "greater than max_model_len."
            " Please, select a smaller truncation size.")

360
361
362
363
364
365
366
367
368
369
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        """
        Return the request's pooling parameters, or an error response if
        the request has no `to_pooling_params` attribute.
        """
        request = ctx.request
        if hasattr(request, "to_pooling_params"):
            return request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters")

370
371
372
373
374
375
376
377
378
379
380
381
382
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Submit every engine prompt in `ctx` to the engine client and
        store the merged result stream on `ctx.result_generator`.

        Each prompt gets the request id "<ctx.request_id>-<index>".

        Returns:
            None on success; an error response if pooling parameters or
            engine prompts are unavailable, or if any step raises.
        """
        streams: list[AsyncGenerator[Union[RequestOutput,
                                           PoolingRequestOutput],
                                     None]] = []

        try:
            if ctx.raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(
                    ctx.raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for index, engine_prompt in enumerate(ctx.engine_prompts):
                item_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    item_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        item_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    ))

            ctx.result_generator = merge_async_iterators(*streams)
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return None

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Drain `ctx.result_generator` and store one result per engine
        prompt, in prompt order, on `ctx.final_res_batch`.

        Returns:
            None on success; an error response if the prompts or the
            generator are missing, if any prompt got no result, or if
            draining raises.
        """
        try:
            prompts = ctx.engine_prompts
            if prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            # One slot per prompt, filled by index as results arrive.
            slots: list[Optional[Union[RequestOutput,
                                       PoolingRequestOutput]]]
            slots = [None] * len(prompts)

            stream = ctx.result_generator
            if stream is None:
                return self.create_error_response(
                    "Result generator not available")

            async for index, output in stream:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [
                output for output in slots if output is not None
            ]
        except Exception as e:
            return self.create_error_response(str(e))

        return None

455
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """
        Build an `ErrorResponse` from a message, error type and HTTP status.

        When `self.log_error_stack` is set, also print the active
        exception's traceback, or the current call stack if no exception
        is being handled.
        """
        if self.log_error_stack:
            handling_exception = sys.exc_info()[0] is not None
            if handling_exception:
                traceback.print_exc()
            else:
                traceback.print_stack()

        info = ErrorInfo(message=message,
                         type=err_type,
                         code=status_code.value)
        return ErrorResponse(error=info)
469

470
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """
        Build the same error as `create_error_response` and return it
        serialized as a JSON string.
        """
        error = self.create_error_response(message=message,
                                           err_type=err_type,
                                           status_code=status_code)
        return json.dumps(error.model_dump())

482
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """
        Check that `request.model` names something this server can serve.

        Returns:
            None if the model is supported, is a registered LoRA, or
            resolves to a LoRA at runtime (when runtime LoRA updating is
            enabled). Otherwise the bad-request error from LoRA
            resolution if there is one, else a not-found error.
        """
        model = request.model

        if self._is_model_supported(model):
            return None
        if model in self.models.lora_requests:
            return None

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model:
            load_result = await self.models.resolve_lora(model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                is_bad_request = (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code
                    == HTTPStatus.BAD_REQUEST.value)
                if is_bad_request:
                    return load_result

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
505

506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """
        Return the single registered LoRA whose name matches a message
        content type in the request, or None if zero or several match.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        matched = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        return matched.pop() if len(matched) == 1 else None

530
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Optional[LoRARequest]:
        """
        Pick the LoRA request, if any, to use for `request`.

        Lookup order: a LoRA registered under `request.model`; then, when
        `supports_default_mm_loras` is set, the single matching default
        multimodal LoRA; then None for a supported base model.

        Raises:
            ValueError: if the model is neither a registered LoRA nor a
                supported model.
        """
        registered = self.models.lora_requests
        if request.model in registered:
            return registered[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be
            # unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
550

551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose "content" is a list are inspected, and
        within that list only dict parts with a string "type" contribute
        (e.g. "audio_url" -> "audio"). Anything else is skipped, so a
        request with no `messages`, `messages=None`, or non-dict content
        parts yields fewer (possibly zero) types instead of raising.
        """
        message_types: set[str] = set()

        for message in getattr(request, "messages", None) or ():
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for part in content:
                # A bare string part would make `"type" in part` a
                # substring test and `part["type"]` a TypeError.
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])

        return message_types

569
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """
        Tokenize a text prompt and validate the result.

        The prompt is lower-cased first when the encoder config sets
        `do_lower_case`. Truncation follows the request's
        `truncate_prompt_tokens`: absent means no truncation, a negative
        value caps at the model's max length, otherwise the value itself
        is the cap.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if (encoder_config is not None
                and encoder_config.get("do_lower_case", False)):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                         None)

        encode_kwargs: dict[str, Any] = {
            "add_special_tokens": add_special_tokens,
        }
        if truncate_prompt_tokens is not None:
            encode_kwargs["truncation"] = True
            # Negative means we cap at the model's max length
            encode_kwargs["max_length"] = (self.max_model_len
                                           if truncate_prompt_tokens < 0 else
                                           truncate_prompt_tokens)

        encoded = await async_tokenizer(prompt, **encode_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

610
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: Optional[AnyTokenizer],
    ) -> TextTokensPrompt:
        """
        Truncate a pre-tokenized prompt, decode it to text when a
        tokenizer is available, and validate the result.

        Truncation keeps the trailing tokens and follows the request's
        `truncate_prompt_tokens`: absent means no truncation, a negative
        value keeps at most the model's max length, zero keeps no tokens,
        and a positive value keeps that many.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                         None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            input_ids = prompt_ids[-self.max_model_len:]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` is the whole list, which would silently
            # disable truncation; keeping zero tokens means an empty list.
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if tokenizer is None:
            # Without a tokenizer there is nothing to decode with.
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """
        Check the tokenized prompt against the model's context length and
        return it as a `TextTokensPrompt`.

        Embedding, score, rerank and classification requests may use the
        full context length. Tokenize/detokenize requests are not checked.
        All other requests must leave room for at least one generated
        token and, if `max_tokens` is set, for that many.

        Raises:
            ValueError: if the input (or input plus `max_tokens`) does
                not fit in the model's context length.
        """
        token_num = len(input_ids)
        max_len = self.max_model_len
        validated = TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        non_generating = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationRequest,
        )
        if isinstance(request, non_generating):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= max_len:
                return validated

            operations: dict[type[AnyRequest], str] = {
                ScoreRequest: "score",
                ClassificationRequest: "classification",
            }
            operation = operations.get(type(request), "embedding generation")
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input.")

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenization_only = (TokenizeCompletionRequest, TokenizeChatRequest,
                             DetokenizeRequest)
        if isinstance(request, tokenization_only):
            return validated

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.")

        if max_tokens is not None and token_num + max_tokens > max_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {max_len}"
                f" - {token_num}).")

        return validated

708
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Tokenize one prompt input by delegating to the multi-input
        variant with a single-element list.

        Raises:
            ValueError: if the delegate yields nothing.
        """
        results = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for first in results:
            return first
        raise ValueError("No results yielded from tokenization")
726

727
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        Yield one normalized `TextTokensPrompt` per input, in order.

        String inputs are tokenized; any other input is treated as a
        list of token ids.
        """
        for prompt in prompt_inputs:
            if not isinstance(prompt, str):
                # Already token ids: truncate/decode/validate only.
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
                )
                continue

            yield await self._normalize_prompt_text_to_input(
                request,
                prompt=prompt,
                tokenizer=tokenizer,
                add_special_tokens=add_special_tokens,
            )

752
753
    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[
            list[ConversationMessage],
            Sequence[RequestPrompt],
            list[EngineTokensPrompt],
    ]:
        """
        Turn chat messages into a request prompt and an engine prompt.

        The method:

        1. resolves the chat template content format and parses `messages`
           into a conversation plus a future for the multi-modal data;
        2. renders the prompt with the Mistral or HF chat template helper
           (or uses a fixed placeholder string when `tokenizer` is None);
        3. lets the tool parser adjust the request when tool parsing applies;
        4. tokenizes the rendered prompt if it is a string, and builds the
           `EngineTokensPrompt`, attaching multi-modal data, multi-modal
           UUIDs, `mm_processor_kwargs` and `cache_salt` when available.

        Args:
            request: The incoming request. `mm_processor_kwargs` is read
                directly; `tool_choice` and `cache_salt` only if present.
            tokenizer: Tokenizer used for templating and tokenization. The
                code also handles `None`.
            messages: Raw chat messages.
            chat_template: Optional chat template override.
            chat_template_content_format: Requested content format.
            add_generation_prompt: Forwarded to the chat template helper.
            continue_final_message: Forwarded to the chat template helper.
            tool_dicts: Tool definitions, passed to the template as `tools`.
            documents: Documents passed to the template.
            chat_template_kwargs: Extra template kwargs; these override the
                explicit arguments above on key collisions.
            tool_parser: Factory for the tool parser, called with the
                tokenizer.
            add_special_tokens: Used when tokenizing a string prompt.

        Returns:
            `(conversation, [request_prompt], [engine_prompt])`.

        Raises:
            NotImplementedError: If tool parsing applies but `request` is
                not a `ChatCompletionRequest`.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs win over the explicit arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            # The Mistral helper receives the raw messages, not the parsed
            # conversation.
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM)
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if tokenizer is None:
            # The message is a single string; it used to be a tuple of two
            # strings, which rendered as a tuple in the AssertionError.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised")
            # No tokenizer: a single dummy token id stands in for the prompt.
            prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                             prompt_token_ids=[1])
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        # Not every request type defines cache_salt, hence the hasattr guard.
        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

869
870
871
872
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        Args:
            request_id: Identifier forwarded to the processor.
            engine_prompt: Prompt to be processed.
            params: Sampling or pooling parameters; their
                `truncate_prompt_tokens` is validated against
                `self.max_model_len`.
            lora_request: Optional LoRA request forwarded to the processor.
            trace_headers: Optional tracing headers forwarded to the
                processor.
            priority: Request priority forwarded to the processor.

        Returns:
            The processed engine request and the tokenization kwargs dict
            that was passed to both the validator and the processor.
        """
        tokenization_kwargs: dict[str, Any] = {}
        # The dict is handed to the validator, presumably so it can record
        # truncation settings in it -- TODO confirm against
        # _validate_truncation_size.
        _validate_truncation_size(self.max_model_len,
                                  params.truncate_prompt_tokens,
                                  tokenization_kwargs)

        processor = await self._get_processor()
        engine_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return engine_request, tokenization_kwargs

897
898
899
900
901
902
903
904
905
906
907
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate, running built-in tool calls between turns as needed.

        Each loop iteration logs the inputs, processes them, and streams the
        engine outputs into `context`, yielding `context` after every
        output. When the context reports that a built-in tool call is
        needed, the tool is called, its output is appended to the context,
        and a new prompt is rendered for the next turn. The loop ends when
        no tool call is requested.

        Args:
            request_id: Identifier reused for every turn.
            request_prompt: Initial prompt, used for logging and to derive
                the prompt text.
            engine_prompt: Initial engine prompt.
            sampling_params: Sampling parameters. `max_tokens` is
                overwritten in place before each follow-up turn.
            context: Conversation context that accumulates outputs and
                drives tool calls.
            lora_request: Optional LoRA request.
            priority: Priority of the first turn.
            **kwargs: Forwarded to `engine_client.generate`;
                `trace_headers`, if present, is also passed to
                `_process_inputs`.

        Yields:
            `context`, after each engine output has been appended to it.
        """
        # NOTE(review): prompt_text is taken from the initial prompt only
        # and is reused unchanged on follow-up turns -- confirm that is
        # intended.
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        orig_priority = priority
        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            trace_headers = kwargs.get("trace_headers")
            engine_request, tokenization_kwargs = (await self._process_inputs(
                request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            ))

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            prompt_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(
                prompt_token_ids=prompt_token_ids)
            request_prompt = prompt_token_ids
            # Update the sampling params.
            # NOTE(review): this mutates the caller's sampling_params and is
            # not clamped, so it can become <= 0 when the rendered prompt
            # reaches max_model_len -- confirm the engine rejects that case.
            sampling_params.max_tokens = self.max_model_len - len(
                prompt_token_ids)
            # OPTIMIZATION
            # Every follow-up turn runs at orig_priority - 1 (this does not
            # accumulate across turns). Presumably a lower value is served
            # earlier -- TODO confirm against the scheduler.
            priority = orig_priority - 1

966
967
    def _get_prompt_components(
        self,
        prompt: Union[RequestPrompt, PromptType],
    ) -> PromptComponents:
        """Split a prompt into its components.

        A plain list is taken to be token ids and wrapped directly in
        `PromptComponents(token_ids=...)`; every other prompt is delegated
        to `get_prompt_components`.
        """
        if isinstance(prompt, list):
            return PromptComponents(token_ids=prompt)

        return get_prompt_components(prompt)  # type: ignore[arg-type]
974

975
976
977
    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[RequestPrompt, PromptType],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Log the inputs of a request through `self.request_logger`.

        Does nothing when no request logger is configured. Otherwise the
        prompt is split with `_get_prompt_components` and the text, token
        ids and embeddings are passed to `request_logger.log_inputs` along
        with `params` and `lora_request`.
        """
        if self.request_logger is None:
            return

        prompt, prompt_token_ids, prompt_embeds = (
            self._get_prompt_components(inputs))

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
997

998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers of a request if tracing is enabled.

        Args:
            headers: The HTTP headers of the incoming request.

        Returns:
            The result of `extract_trace_headers(headers)` when the engine
            client reports tracing as enabled, otherwise `None`.
        """
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        # Tracing is disabled: warn if the client sent trace headers anyway.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

1012
    @staticmethod
1013
    def _base_request_id(raw_request: Optional[Request],
1014
1015
1016
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
1017
1018
1019
1020
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
1021

1022
    @staticmethod
1023
1024
1025
1026
1027
1028
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
1029
1030
1031
        if return_as_token_id:
            return f"token_id:{token_id}"

1032
1033
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1034
        return tokenizer.decode(token_id)
1035

1036
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1037
1038
        if not model_name:
            return True
1039
        return self.models.is_base_model(model_name)
1040

1041
1042
1043

def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs, None],
) -> Union[PromptLogprobs, None]:
    """Replace `-inf` prompt logprobs with `-9999.0`, in place.

    Args:
        prompt_logprobs: Per-position logprob dicts (entries may be `None`),
            or `None`.

    Returns:
        The same object that was passed in; `None` is returned unchanged.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        # Positions without logprobs are skipped.
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float("-inf"):
                logprob_values.logprob = -9999.0
    return prompt_logprobs