serving_engine.py 37.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import sys
5
import time
6
import traceback
7
8
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
9
from http import HTTPStatus
10
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
11

12
import torch
13
from fastapi import Request
14
from pydantic import BaseModel, ConfigDict, Field
15
from starlette.datastructures import Headers
16
17
from typing_extensions import TypeIs

18
19
20
21
22
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor

23
24
25
26
27
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

28
import vllm.envs as envs
29
from vllm.config import ModelConfig
30
from vllm.engine.protocol import EngineClient
31

32
33
# yapf conflicts with isort for this block
# yapf: disable
34
35
36
37
38
39
40
41
42
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
43
from vllm.entrypoints.context import ConversationContext
44
from vllm.entrypoints.logger import RequestLogger
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
    ResponsesRequest,
    ScoreRequest,
    ScoreResponse,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
72
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
73
from vllm.entrypoints.openai.tool_parsers import ToolParser
74
75
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig

76
# yapf: enable
77
from vllm.inputs.data import PromptType
78
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
79
from vllm.inputs.parse import PromptComponents, get_prompt_components
80
from vllm.logger import init_logger
81
from vllm.logprobs import Logprob, PromptLogprobs
82
from vllm.lora.request import LoRARequest
83
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
84
85
86
    MultiModalDataDict,
    MultiModalUUIDDict,
)
87
from vllm.outputs import PoolingRequestOutput, RequestOutput
88
from vllm.pooling_params import PoolingParams
89
from vllm.sampling_params import BeamSearchParams, SamplingParams
90
91
92
93
94
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
95
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
96
97
98
99
100
101
102
from vllm.utils import (
    AsyncMicrobatchTokenizer,
    is_list_of,
    make_async,
    merge_async_iterators,
    random_uuid,
)
103
104
105

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

106
107
108
109
110
111
112
113
114
# Alias for the request models grouped as "completion-like"; used purely
# for type annotations.
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]
115

116
117
118
# Alias for the request models grouped as "chat-like".
ChatLikeRequest = Union[
    ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest
]
119
# Alias covering the transcription and translation request models.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
120
121
122
123
124
125
126
# Union of every request model accepted by the helpers in this module;
# it is also the bound of the `RequestT` type variable.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
    IOProcessorRequest,
]
127

128
129
130
131
132
133
134
135
136
137
138
# Union of the successful (non-error) response models used in the return
# annotations of the response-building methods below.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

139
140
141

class TextTokensPrompt(TypedDict):
    """Prompt represented by its text together with its token ids."""

    # Text of the prompt.
    prompt: str
    # Token ids that go with `prompt`.
    prompt_token_ids: list[int]
143
144


145
146
147
148
149
150
151
152
class EmbedsPrompt(TypedDict):
    # Prompt supplied directly as an embeddings tensor instead of text or
    # token ids.
    prompt_embeds: torch.Tensor


# Any of the prompt forms handled here: raw token ids, raw text, a
# text+tokens dict, or an embeddings dict.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Tell whether `prompt` is a dict carrying token ids but no embeddings."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds
158
159
160


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Tell whether `prompt` is a dict carrying embeddings but no token ids."""
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_token_ids
166

167

168
169
170
171
172
# Type variable for the concrete request model stored in a `ServeContext`;
# restricted to the members of `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Request-processing mixin: holds the prepared prompts and the
    corresponding engine inputs.
    """

    # Prompts in the form they were prepared from the request.
    request_prompts: Optional[Sequence[RequestPrompt]] = []
    # Token prompts to hand to the engine; may be None when not available.
    engine_prompts: Optional[list[EngineTokensPrompt]] = []

    # The prompt types include non-Pydantic classes.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Response-generation mixin: holds the merged result generator and
    the final per-prompt batch of results.
    """

    # Async stream of `(prompt_index, output)` pairs; None until scheduled.
    result_generator: Optional[
        AsyncGenerator[tuple[int, Union[RequestOutput, PoolingRequestOutput]], None]
    ] = None
    # Collected outputs, one entry per prompt.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list
    )

    # The output types are not Pydantic models.
    model_config = ConfigDict(arbitrary_types_allowed=True)


199
class ServeContext(
    RequestProcessingMixin,
    ResponseGenerationMixin,
    BaseModel,
    Generic[RequestT],
):
    """State for one request as it is passed through the serving pipeline."""

    # --- Present for every request ---
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Whole-second Unix timestamp taken when the context is created.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # --- Present for most requests ---
    tokenizer: Optional[AnyTokenizer] = None

    # An empty `protected_namespaces` silences the Pydantic v2 warning
    # about fields that clash with the protected "model_" namespace.
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# Serve context specialised for classification requests; it adds no fields.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    # Serve context for embedding requests, extended with chat-template
    # settings.
    # Optional chat template; defaults to None.
    chat_template: Optional[str] = None
    # Required: how message content is formatted for the chat template.
    chat_template_content_format: ChatTemplateContentFormatOption


# Rebuild the models explicitly to resolve the Pydantic error caused by
# the forward reference to MultiModalDataDict inside TokensPrompt.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

239

240
class OpenAIServing:
241
242
243
244
    # NOTE: the default value below is the explanatory text itself, not a
    # usable prefix. Presumably each serving subclass overrides it with a
    # short tag such as "embd" or "classify" -- TODO confirm in subclasses.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
245

246
247
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the shared serving dependencies and set up tokenizer helpers.

        Args:
            engine_client: Client used to talk to the engine.
            model_config: Model configuration; `max_model_len` is copied
                from it onto the instance.
            models: Registry of served models and LoRA adapters.
            request_logger: Optional logger for request inputs.
            return_tokens_as_token_ids: Stored as-is on the instance.
            enable_force_include_usage: Stored as-is on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback or the current stack.
        """
        super().__init__()

        # Core collaborators.
        self.engine_client = engine_client
        self.models = models
        self.request_logger = request_logger

        # Model configuration and the context limit derived from it.
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        # Behaviour flags.
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage
        self.log_error_stack = log_error_stack

        # Single-worker executor handed to `make_async` for the Mistral
        # chat template helper.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # One async tokenizer wrapper per tokenizer, created on demand.
        self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
276

277
278
279
280
281
282
283
284
285
286
    async def _get_processor(self) -> Processor:
        if not hasattr(self, "_processor"):
            vllm_config = await self.engine_client.get_vllm_config()
            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = init_tokenizer_from_configs(self.model_config)
            self._processor = Processor(vllm_config, tokenizer)
        return self._processor

287
288
289
290
291
292
293
294
    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
        """
        Create a `CompletionRenderer` for `tokenizer`.

        The renderer is given this instance's model config and its shared
        async tokenizer pool.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer
297

298
299
300
301
302
303
304
305
306
307
308
309
310
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.
        """
        raise NotImplementedError

311
312
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
313
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
314
315
316
317
318
319
320
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
        if async_tokenizer is None:
            async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
            self._async_tokenizer_pool[tokenizer] = async_tokenizer
        return async_tokenizer
321

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
379
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
380

381
382
383
384
        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
385
386
387
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
388
389
                " Please, select a smaller truncation size."
            )
390
391
        return None

392
393
394
395
396
397
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
398
399
                "Request type does not support pooling parameters"
            )
400
401
402

        return ctx.request.to_pooling_params()

403
404
405
406
407
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Submit every engine prompt for encoding and merge the result streams.

        On success `ctx.result_generator` is set to the merged iterator and
        None is returned. Any failure, including an exception raised along
        the way, is reported as an `ErrorResponse`.
        """
        try:
            # Trace headers can only be read when the raw HTTP request exists.
            if ctx.raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(ctx.raw_request.headers)

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            streams: list[
                AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]
            ] = []
            for index, engine_prompt in enumerate(ctx.engine_prompts):
                # Each prompt gets its own id derived from the request id.
                sub_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                streams.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        sub_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*streams)
            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
462
                return self.create_error_response("Engine prompts not available")
463
464

            num_prompts = len(ctx.engine_prompts)
465
            final_res_batch: list[Optional[Union[RequestOutput, PoolingRequestOutput]]]
466
467
468
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
469
                return self.create_error_response("Result generator not available")
470
471
472
473
474
475

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
476
477
                    "Failed to generate results for all prompts"
                )
478

479
            ctx.final_res_batch = [res for res in final_res_batch if res is not None]
480
481
482
483
484
485

            return None

        except Exception as e:
            return self.create_error_response(str(e))

486
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message, error type and status.

        When `log_error_stack` is enabled, the active exception's traceback
        is printed if an exception is being handled; otherwise the current
        call stack is printed.
        """
        if self.log_error_stack:
            handling_exception = sys.exc_info()[0] is not None
            if handling_exception:
                traceback.print_exc()
            else:
                traceback.print_stack()

        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)
501

502
    def create_streaming_error_response(
503
504
505
506
507
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
508
        json_str = json.dumps(
509
510
511
512
            self.create_error_response(
                message=message, err_type=err_type, status_code=status_code
            ).model_dump()
        )
513
514
        return json_str

515
    async def _check_model(
516
517
        self,
        request: AnyRequest,
518
    ) -> Optional[ErrorResponse]:
519
520
        error_response = None

521
        if self._is_model_supported(request.model):
522
            return None
523
        if request.model in self.models.lora_requests:
524
            return None
525
526
527
528
529
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
530
531
            if isinstance(load_result, LoRARequest):
                return None
532
533
534
535
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
536
537
538
                error_response = load_result

        return error_response or self.create_error_response(
539
540
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
541
542
            status_code=HTTPStatus.NOT_FOUND,
        )
543

544
    def _get_active_default_mm_loras(
545
546
        self, request: AnyRequest
    ) -> Optional[LoRARequest]:
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

569
    def _maybe_get_adapters(
570
571
572
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
573
    ) -> Optional[LoRARequest]:
574
        if request.model in self.models.lora_requests:
575
            return self.models.lora_requests[request.model]
576
577
578
579
580
581

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
582
                return default_mm_lora
583
584

        if self._is_model_supported(request.model):
585
            return None
586

587
        # if _check_model has been called earlier, this will be unreachable
588
        raise ValueError(f"The model `{request.model}` does not exist.")
589

590
591
592
593
594
595
596
597
598
599
600
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
601
602
603
604
605
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
606
607
608
609
610
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

611
    async def _normalize_prompt_text_to_input(
612
613
614
        self,
        request: AnyRequest,
        prompt: str,
615
        tokenizer: AnyTokenizer,
616
617
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
618
619
        async_tokenizer = self._get_async_tokenizer(tokenizer)

620
621
622
623
        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
624
625
            prompt = prompt.lower()

626
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
627

628
        if truncate_prompt_tokens is None:
629
            encoded = await async_tokenizer(
630
631
                prompt, add_special_tokens=add_special_tokens
            )
632
633
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
634
635
636
637
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
638
639
                max_length=self.max_model_len,
            )
640
        else:
641
642
643
644
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
645
646
                max_length=truncate_prompt_tokens,
            )
647
648
649
650
651
652

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

653
    async def _normalize_prompt_tokens_to_input(
654
655
        self,
        request: AnyRequest,
656
        prompt_ids: list[int],
657
        tokenizer: Optional[AnyTokenizer],
658
    ) -> TextTokensPrompt:
659
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
660

661
        if truncate_prompt_tokens is None:
662
            input_ids = prompt_ids
663
        elif truncate_prompt_tokens < 0:
664
            input_ids = prompt_ids[-self.max_model_len :]
665
666
667
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

668
669
670
671
672
        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
673

674
675
676
677
678
        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model limit for this request.

        Three request groups are handled:
        - embedding, score, rerank and classification requests may use the
          whole context length because they generate no tokens;
        - tokenize and detokenize requests are not length-checked;
        - all other requests must leave room for at least one generated
          token and, when a token budget is given, for that budget too.

        Returns:
            A `TextTokensPrompt` with the given text and token ids.

        Raises:
            ValueError: when the input, or input plus requested output,
                does not fit into `max_model_len`.
        """
        num_input_tokens = len(input_ids)
        context_limit = self.max_model_len

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        pooling_request_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationRequest,
        )
        if isinstance(request, pooling_request_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if num_input_tokens > context_limit:
                operation_names: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification",
                }
                operation = operation_names.get(type(request), "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{context_limit} tokens. However, you requested "
                    f"{num_input_tokens} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        tokenize_request_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, tokenize_request_types):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if num_input_tokens >= context_limit:
            raise ValueError(
                f"This model's maximum context length is "
                f"{context_limit} tokens. However, your request has "
                f"{num_input_tokens} input tokens. Please reduce the length of "
                "the input messages."
            )

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        if max_tokens is not None and num_input_tokens + max_tokens > context_limit:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{context_limit} tokens and your request has "
                f"{num_input_tokens} input tokens ({max_tokens} > {context_limit}"
                f" - {num_input_tokens})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

748
    async def _tokenize_prompt_input_async(
749
750
751
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
752
        prompt_input: Union[str, list[int]],
753
754
755
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
756
        A simpler implementation that tokenizes a single prompt input.
757
        """
758
        async for result in self._tokenize_prompt_inputs_async(
759
760
            request,
            tokenizer,
761
            [prompt_input],
762
            add_special_tokens=add_special_tokens,
763
764
765
        ):
            return result
        raise ValueError("No results yielded from tokenization")
766

767
    async def _tokenize_prompt_inputs_async(
768
769
770
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
771
        prompt_inputs: Iterable[Union[str, list[int]]],
772
        add_special_tokens: bool = True,
773
    ) -> AsyncGenerator[TextTokensPrompt, None]:
774
        """
775
        A simpler implementation that tokenizes multiple prompt inputs.
776
        """
777
778
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
779
                yield await self._normalize_prompt_text_to_input(
780
                    request,
781
782
                    prompt=prompt,
                    tokenizer=tokenizer,
783
784
785
                    add_special_tokens=add_special_tokens,
                )
            else:
786
                yield await self._normalize_prompt_tokens_to_input(
787
                    request,
788
789
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
790
791
                )

792
793
794
795
796
797
798
    def _validate_chat_template(
        self,
        request_chat_template: Optional[str],
        chat_template_kwargs: Optional[dict[str, Any]],
        trust_request_chat_template: bool,
    ) -> Optional[ErrorResponse]:
        if not trust_request_chat_template and (
799
800
801
802
803
804
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
805
806
807
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
808
809
                "Refused request with untrusted chat template."
            )
810
811
        return None

812
813
    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a prompt and build the engine prompt.

        The messages are parsed into a conversation (plus a future for any
        multi-modal data), the chat template is applied, and the rendered
        prompt is tokenized unless the template already produced token ids.

        Args:
            request: The incoming request. It is read for ``tool_choice``,
                ``mm_processor_kwargs`` and ``cache_salt`` and may be replaced
                by the tool parser's adjusted request.
            tokenizer: Tokenizer used for templating and tokenization. When it
                is ``None`` a placeholder prompt is produced instead.
            messages: Raw chat messages from the request.
            chat_template: Chat template forwarded to the template renderer.
            chat_template_content_format: Requested content format; resolved
                via ``resolve_chat_template_content_format``.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Tool definitions, forwarded to the template as
                ``tools``.
            documents: Documents forwarded to the chat template.
            chat_template_kwargs: Extra template kwargs; these override the
                explicit arguments above on key collision.
            tool_parser: Factory for a tool parser; only used when the request
                has a ``tool_choice`` other than ``"none"``.
            add_special_tokens: Forwarded to tokenization of string prompts.

        Returns:
            A tuple ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: If tool parsing applies but ``request`` is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        # Caller-supplied kwargs take precedence over the explicit arguments.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            # No tokenizer available: nothing can be rendered, use a stub.
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request
            )

        if tokenizer is None:
            # The message must be a single string; a parenthesised pair of
            # strings would be a tuple and render as one in the error.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string when the tokenizer is not initialised"
            )
            # Stub token ids to accompany the placeholder prompt.
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

934
935
936
937
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for AsyncLLM.

        Args:
            request_id: Identifier forwarded to the processor.
            engine_prompt: Prompt to be processed.
            params: Sampling or pooling parameters; their
                ``truncate_prompt_tokens`` is validated against
                ``self.max_model_len``.
            lora_request: Optional LoRA request forwarded to the processor.
            trace_headers: Optional tracing headers forwarded to the processor.
            priority: Request priority forwarded to the processor.

        Returns:
            The engine request produced by the processor together with the
            tokenization kwargs that were passed to it.
        """
        tokenization_kwargs: dict[str, Any] = {}
        # The return value is unused; the helper is expected to validate and
        # fill `tokenization_kwargs` in place (confirm against the helper).
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
        )

        processor = await self._get_processor()
        engine_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return engine_request, tokenization_kwargs

962
963
964
965
966
967
968
969
970
971
972
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate, running built-in tool calls between turns until done.

        Each iteration logs the inputs, processes them into an engine request
        and streams the engine outputs into ``context``, yielding ``context``
        after every output. If the context then asks for a built-in tool call,
        the tool is invoked, its output is appended, and a new prompt is
        rendered from the context for the next turn.

        Args:
            request_id: Identifier used for logging and generation.
            request_prompt: Prompt for the first turn; replaced by the
                rendered token ids on later turns.
            engine_prompt: Engine prompt for the first turn.
            sampling_params: Sampling parameters. ``max_tokens`` is mutated in
                place before every follow-up turn.
            context: Conversation state that accumulates outputs and performs
                the tool calls.
            lora_request: Optional LoRA request.
            priority: Priority of the first turn.
            **kwargs: Forwarded to ``engine_client.generate``;
                ``trace_headers`` is also read from it for input processing.

        Yields:
            ``context`` after each engine output has been appended to it.
        """
        # NOTE: the prompt text comes from the initial prompt only and is
        # reused unchanged for every subsequent turn.
        prompt_text, _, _ = self._get_prompt_components(request_prompt)
        orig_priority = priority
        # `kwargs` is never modified below, so read the headers once.
        trace_headers = kwargs.get("trace_headers")
        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            generator = self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for res in generator:
                context.append_output(res)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # The model did not ask for a tool call, so we're done.
                break

            # Call the tool and update the context with the result.
            tool_output = await context.call_tool()
            context.append_output(tool_output)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Create inputs for the next turn.
            # Render the next prompt token ids.
            prompt_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
            request_prompt = prompt_token_ids
            # Update the sampling params: the remaining budget is whatever
            # the model length leaves after the re-rendered prompt.
            sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
            # OPTIMIZATION: follow-up turns run at `orig_priority - 1`. The
            # value is derived from the original priority, so it does not keep
            # decreasing across turns. Presumably a lower value is served
            # earlier by the scheduler -- TODO confirm.
            priority = orig_priority - 1

1029
1030
    def _get_prompt_components(
        self,
        prompt: Union[RequestPrompt, PromptType],
    ) -> PromptComponents:
        """Split a prompt into its components.

        A ``list`` prompt is treated as token ids and wrapped directly; any
        other prompt is delegated to ``get_prompt_components``.
        """
        if isinstance(prompt, list):
            return PromptComponents(token_ids=prompt)

        return get_prompt_components(prompt)  # type: ignore[arg-type]
1037

1038
1039
1040
    def _log_inputs(
        self,
        request_id: str,
1041
        inputs: Union[RequestPrompt, PromptType],
1042
        params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
1043
1044
1045
1046
        lora_request: Optional[LoRARequest],
    ) -> None:
        if self.request_logger is None:
            return
1047

1048
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1049
1050
1051
1052
1053

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1054
            prompt_embeds,
1055
1056
1057
            params=params,
            lora_request=lora_request,
        )
1058

1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers to propagate, or ``None``.

        When the engine client reports tracing as enabled, the trace headers
        are extracted from ``headers`` and returned. Otherwise ``None`` is
        returned, after emitting the tracing-disabled warning if the request
        carried trace headers anyway.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        # Tracing is off: the headers are dropped, but tell the operator if
        # the client actually sent some.
        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1073
    @staticmethod
1074
1075
1076
    def _base_request_id(
        raw_request: Optional[Request], default: Optional[str] = None
    ) -> Optional[str]:
1077
1078
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
1079
1080
1081
1082
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
1083

1084
    @staticmethod
1085
1086
1087
1088
1089
1090
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
1091
1092
1093
        if return_as_token_id:
            return f"token_id:{token_id}"

1094
1095
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1096
        return tokenizer.decode(token_id)
1097

1098
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1099
1100
        if not model_name:
            return True
1101
        return self.models.is_base_model(model_name)
1102

1103
1104

def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs, None],
) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``.

    The entries are modified in place and the same object is returned.
    ``None`` input and ``None`` positions inside the sequence are passed
    through untouched.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float("-inf"):
                logprob_values.logprob = -9999.0
    return prompt_logprobs