serving_engine.py 37.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import sys
5
import time
6
import traceback
7
8
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
9
from http import HTTPStatus
10
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
11

12
import torch
13
from fastapi import Request
14
from pydantic import BaseModel, ConfigDict, Field
15
from starlette.datastructures import Headers
16
17
from typing_extensions import TypeIs

18
19
20
21
22
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor

23
24
25
26
27
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

28
import vllm.envs as envs
29
from vllm.config import ModelConfig
30
from vllm.engine.protocol import EngineClient
31
32
33
34
35
36
37
38
39
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
40
from vllm.entrypoints.context import ConversationContext
41
from vllm.entrypoints.logger import RequestLogger
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
    ResponsesRequest,
    ScoreRequest,
    ScoreResponse,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
69
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
70
from vllm.entrypoints.openai.tool_parsers import ToolParser
71
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
72
from vllm.inputs.data import PromptType
73
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
74
from vllm.inputs.parse import PromptComponents, get_prompt_components
75
from vllm.logger import init_logger
76
from vllm.logprobs import Logprob, PromptLogprobs
77
from vllm.lora.request import LoRARequest
78
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
79
80
81
    MultiModalDataDict,
    MultiModalUUIDDict,
)
82
from vllm.outputs import PoolingRequestOutput, RequestOutput
83
from vllm.pooling_params import PoolingParams
84
from vllm.sampling_params import BeamSearchParams, SamplingParams
85
86
87
88
89
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
90
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
91
92
93
94
95
96
97
from vllm.utils import (
    AsyncMicrobatchTokenizer,
    is_list_of,
    make_async,
    merge_async_iterators,
    random_uuid,
)
98
99
100

logger = init_logger(__name__)

# Request types grouped as completion-style (non-chat) inputs.
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]

# Request types grouped as chat-style inputs.
ChatLikeRequest = Union[
    ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest
]

# Transcription / translation requests.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Union of every request type handled by `OpenAIServing`; also the bound of
# `RequestT`.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
    IOProcessorRequest,
]

# Union of the non-error response types a handler may build.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

134
135
136

class TextTokensPrompt(TypedDict):
    """A prompt carried both as text and as its token IDs."""

    prompt: str
    prompt_token_ids: list[int]


class EmbedsPrompt(TypedDict):
    """A prompt given directly as precomputed embeddings."""

    prompt_embeds: torch.Tensor


# Any request-level prompt representation: token IDs, raw text, text with
# token IDs, or prompt embeddings.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True if `prompt` is a `TextTokensPrompt`.

    A prompt qualifies when it is a dict that has `prompt_token_ids` and
    does not have `prompt_embeds`.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )
153
154
155


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True if `prompt` is an `EmbedsPrompt`.

    A prompt qualifies when it is a dict that has `prompt_embeds` and does
    not have `prompt_token_ids`.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )
161

162

163
164
165
166
167
# Concrete request type held in `ServeContext.request`; bounded by `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Prompts as prepared from the request (token IDs, text, text + token
    # IDs, or embeddings). A fresh list per instance, matching
    # `ResponseGenerationMixin.final_res_batch`.
    request_prompts: Optional[Sequence[RequestPrompt]] = Field(
        default_factory=list
    )
    # Prompts handed to the engine, one engine request per entry.
    engine_prompts: Optional[list[EngineTokensPrompt]] = Field(
        default_factory=list
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """

    # Merged stream of `(prompt_index, output)` pairs; None until generators
    # have been prepared.
    result_generator: Optional[
        AsyncGenerator[tuple[int, Union[RequestOutput, PoolingRequestOutput]], None]
    ] = None
    # Collected outputs, ordered by prompt index.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)


194
class ServeContext(
    RequestProcessingMixin,
    ResponseGenerationMixin,
    BaseModel,
    Generic[RequestT],
):
    """Per-request state threaded through the `OpenAIServing` pipeline."""

    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Creation timestamp in whole seconds (from `time.time()`).
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# `ServeContext` specialized to `ClassificationRequest`.
ClassificationServeContext = ServeContext[ClassificationRequest]


# `ServeContext` for `EmbeddingRequest`, extended with chat-template settings
# (presumably used when the embedding input is given as chat messages —
# TODO confirm in the embedding handler).
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    chat_template: Optional[str] = None
    # Required field: no default is provided.
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
# (this is why `MultiModalDataDict` is imported with a `noqa: F401` above).
# The base mixin and the generic model are rebuilt first, then the
# specialized subclasses.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

234

235
class OpenAIServing:
236
237
238
239
    # NOTE(review): the value below is a prose description, not a usable
    # prefix: it spans several lines and contains non-ASCII quotes.
    # Presumably subclasses override it with a short tag such as "embd" or
    # "classify" — confirm. If it were used unmodified, request IDs would
    # embed this text.
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
240

241
242
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the engine handle, model configuration and serving options.

        Args:
            engine_client: Client used to submit requests to the engine.
            model_config: Model configuration; `max_model_len` is cached.
            models: Registry of served models and their LoRA adapters.
            request_logger: Optional logger for request inputs.
            return_tokens_as_token_ids: Flag stored on the instance.
            enable_force_include_usage: Flag stored on the instance.
            log_error_stack: If True, `create_error_response` prints the
                active traceback (or the current stack) when building errors.
        """
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

        # Dedicated single-thread executor on which `make_async` runs
        # `apply_mistral_chat_template`.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=self._tokenizer_executor
        )

        # Async tokenizer wrappers keyed by tokenizer; shared with renderers
        # created by `_get_renderer`.
        self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack
271

272
273
274
275
276
277
278
279
280
281
    async def _get_processor(self) -> Processor:
        """Return the lazily created `Processor` for this server.

        The first call fetches the vLLM config from the engine client and
        builds the processor. No tokenizer is attached when
        `skip_tokenizer_init` is set. Later calls reuse the cached instance.
        """
        if hasattr(self, "_processor"):
            return self._processor

        vllm_config = await self.engine_client.get_vllm_config()
        tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else init_tokenizer_from_configs(self.model_config)
        )
        self._processor = Processor(vllm_config, tokenizer)
        return self._processor

282
283
284
285
286
287
288
289
    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
        """
        Get a Renderer instance with the provided tokenizer.
        Uses shared async tokenizer pool for efficiency.

        A new `CompletionRenderer` is created on each call. Only the async
        tokenizer pool is shared with this serving instance.
        """
        return CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
292

293
294
295
296
297
298
299
300
301
302
303
304
305
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """Return the `RenderConfig` the renderer should use for `request`.

        The config governs prompt preparation (tokenization, length
        handling). There is no generic default, so every endpoint must
        override this method with logic suited to its request type.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

306
307
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """Look up (or create and memoize) the async wrapper for `tokenizer`.

        Wrappers live in `self._async_tokenizer_pool`, so repeated calls
        with the same tokenizer share a single `AsyncMicrobatchTokenizer`.
        """
        pool = self._async_tokenizer_pool
        if (cached := pool.get(tokenizer)) is not None:
            return cached
        wrapper = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = wrapper
        return wrapper
316

317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Endpoint-specific hook for preparing `ctx` before scheduling.

        The base version leaves `ctx` untouched and reports no error.
        Subclasses (classification, embedding, ...) override it. An
        `ErrorResponse` returned from here is yielded by `_pipeline`.
        """
        # Nothing to prepare in the generic case.
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """Convert the results collected in `ctx` into the endpoint response.

        The base class has no response format of its own, so it answers with
        an "unimplemented endpoint" error. Subclasses override this.
        """
        unimplemented_msg = "unimplemented endpoint"
        return self.create_error_response(unimplemented_msg)

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """Run `_pipeline` for `ctx` and return the first response it yields.

        The pipeline yields either an early `ErrorResponse` or the final
        response, and only the first item is used. The pipeline generator is
        closed explicitly so it is not left suspended after the early return.

        Returns:
            The first yielded response, or an error response if the pipeline
            yields nothing.
        """
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Finalize the generator now rather than relying on the event
            # loop's async-generator finalizer at garbage-collection time.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection and response building. The
        first stage that reports an error yields that `ErrorResponse` and
        stops the pipeline, so later stages never run on a rejected request.
        Otherwise the built response is yielded.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """Reject a `truncate_prompt_tokens` larger than `max_model_len`.

        Requests without a `truncate_prompt_tokens` attribute, or with it set
        to None, pass unchecked.

        Returns:
            An `ErrorResponse` if validation fails, otherwise None.
        """
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

387
388
389
390
391
392
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        """Build `PoolingParams` via `ctx.request.to_pooling_params()`.

        Returns:
            The pooling parameters, or an `ErrorResponse` if the request type
            has no `to_pooling_params` method.
        """
        if not hasattr(ctx.request, "to_pooling_params"):
            return self.create_error_response(
                "Request type does not support pooling parameters"
            )

        return ctx.request.to_pooling_params()

398
399
400
401
402
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator.

        Submits one `engine_client.encode` call per entry of
        `ctx.engine_prompts`, using request ID `f"{ctx.request_id}-{i}"`. The
        merged output stream is stored in `ctx.result_generator`.

        Returns:
            None on success, otherwise an `ErrorResponse`. Any exception
            raised while scheduling is converted into an error response.
        """
        generators: list[
            AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator`, placing each output at its prompt
        index. The ordered outputs are stored in `ctx.final_res_batch`.

        Returns:
            None on success, otherwise an `ErrorResponse`. This covers
            missing inputs, a prompt that produced no output, and any raised
            exception.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[Optional[Union[RequestOutput, PoolingRequestOutput]]]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            # Identity test: `None in final_res_batch` would compare each
            # output object with `==`.
            if any(res is None for res in final_res_batch):
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [res for res in final_res_batch if res is not None]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

481
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` with the given message, type and status.

        When `self.log_error_stack` is set, the exception currently being
        handled is printed. Outside an `except` block, the current call stack
        is printed instead.
        """
        if self.log_error_stack:
            exc_type, _, _ = sys.exc_info()
            if exc_type is not None:
                traceback.print_exc()
            else:
                traceback.print_stack()
        return ErrorResponse(
            error=ErrorInfo(message=message, type=err_type, code=status_code.value)
        )
496

497
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return `create_error_response(...)` serialized to a JSON string.

        Presumably used by streaming endpoints, where the error must be sent
        as text.
        """
        json_str = json.dumps(
            self.create_error_response(
                message=message, err_type=err_type, status_code=status_code
            ).model_dump()
        )
        return json_str

510
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that `request.model` names a served model or LoRA adapter.

        When `VLLM_ALLOW_RUNTIME_LORA_UPDATING` is enabled, an unknown name is
        passed to `self.models.resolve_lora`. That call may resolve it to a
        `LoRARequest`.

        Returns:
            None if the model is available. Otherwise an `ErrorResponse`:
            the resolver's own error when it is a 400, else a 404.
        """
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            # Only a bad-request error from the resolver is surfaced as-is.
            # Anything else falls through to the generic 404 below.
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                error_response = load_result

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
538

539
    def _get_active_default_mm_loras(
        self, request: AnyRequest
    ) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras.

        A registered LoRA matches when its `lora_name` equals one of the
        content-type prefixes from `_get_message_types` (e.g. "audio" for
        "audio_url").

        Returns:
            The matching `LoRARequest` if exactly one LoRA matches, else None.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

564
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Optional[LoRARequest]:
        """Resolve the LoRA adapter to apply to `request`, if any.

        Args:
            request: Incoming request. `request.model` is looked up among the
                registered LoRA adapters.
            supports_default_mm_loras: If True, fall back to a single default
                modality-specific LoRA inferred from the message content.

        Returns:
            The matching `LoRARequest`, or None when `request.model` is a
            supported base model and no default LoRA applies.

        Raises:
            ValueError: If `request.model` is neither a LoRA adapter nor a
                supported model.
        """
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")
584

585
586
587
588
589
590
591
592
593
594
595
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose `content` is a list are inspected. Within
        them, only dict parts with a string `type` count (e.g. "audio_url"
        contributes "audio"). Requests without `messages` yield an empty set.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for content_part in content:
                # Skip non-dict parts: for a string, `"type" in part` would
                # be a substring test and `part["type"]` would then raise.
                if not isinstance(content_part, dict):
                    continue
                part_type = content_part.get("type")
                # Skip non-string types, which have no `.split`.
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])
        return message_types

606
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate its length for `request`.

        The text is lower-cased first when the model's encoder config sets
        `do_lower_case`. `request.truncate_prompt_tokens` (if present)
        controls truncation:
        - None: no truncation.
        - Negative: truncate to `max_model_len`.
        - Otherwise: passed as `max_length` with truncation enabled.

        Raises:
            ValueError: From `_validate_input` if the token count is not
                acceptable for the request type.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        if (
            self.model_config.encoder_config is not None
            and self.model_config.encoder_config.get("do_lower_case", False)
        ):
            prompt = prompt.lower()

        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            encoded = await async_tokenizer(
                prompt, add_special_tokens=add_special_tokens
            )
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=self.max_model_len,
            )
        else:
            encoded = await async_tokenizer(
                prompt,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=truncate_prompt_tokens,
            )

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

648
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: Optional[AnyTokenizer],
    ) -> TextTokensPrompt:
        """Truncate a token-ID prompt, decode it, and validate its length.

        `request.truncate_prompt_tokens` (if present) selects how many
        trailing tokens to keep:
        - None: keep all tokens.
        - Negative: keep the last `max_model_len` tokens.
        - `k >= 0`: keep the last `k` tokens, so 0 keeps none.

        Without a tokenizer, the decoded text is left empty.

        Raises:
            ValueError: From `_validate_input` if the token count is not
                acceptable for the request type.
        """
        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len :]
        elif truncate_prompt_tokens == 0:
            # `prompt_ids[-0:]` is the *whole* list, so keeping zero trailing
            # tokens must be handled explicitly.
            input_ids = []
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        if tokenizer is None:
            input_text = ""
        else:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the input length against `max_model_len` for `request`.

        - Embedding, score, rerank and classification requests may use the
          whole context length.
        - Tokenize/detokenize requests are not length-checked.
        - All other requests must stay below `max_model_len`, and
          `input tokens + max_tokens` (or `max_completion_tokens` for chat
          completions) must not exceed it.

        Returns:
            A `TextTokensPrompt` built from `input_text` and `input_ids`.

        Raises:
            ValueError: If a length limit is exceeded.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(
            request,
            (
                EmbeddingChatRequest,
                EmbeddingCompletionRequest,
                ScoreRequest,
                RerankRequest,
                ClassificationRequest,
            ),
        ):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num > self.max_model_len:
                operations: dict[type[AnyRequest], str] = {
                    ScoreRequest: "score",
                    ClassificationRequest: "classification",
                }
                operation = operations.get(type(request), "embedding generation")
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input."
                )
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(
            request,
            (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
        ):
            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and token_num + max_tokens > self.max_model_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{self.max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
                f" - {token_num})."
            )

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

743
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation that tokenizes a single prompt input.

        Delegates to `_tokenize_prompt_inputs_async` with a one-element list
        and returns its first result.

        Raises:
            ValueError: If no result is produced.
        """
        async for result in self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        ):
            return result
        raise ValueError("No results yielded from tokenization")
761

762
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        A simpler implementation that tokenizes multiple prompt inputs.

        String prompts are tokenized (`add_special_tokens` applies only to
        them). Token-ID prompts are truncated and decoded. One validated
        `TextTokensPrompt` is yielded per input, in input order.
        """
        for prompt in prompt_inputs:
            if isinstance(prompt, str):
                yield await self._normalize_prompt_text_to_input(
                    request,
                    prompt=prompt,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt,
                    tokenizer=tokenizer,
                )

787
788
789
790
791
792
793
    def _validate_chat_template(
        self,
        request_chat_template: Optional[str],
        chat_template_kwargs: Optional[dict[str, Any]],
        trust_request_chat_template: bool,
    ) -> Optional[ErrorResponse]:
        """Refuse request-supplied chat templates unless they are trusted.

        A template counts as supplied when `request_chat_template` is set or
        `chat_template_kwargs["chat_template"]` is not None.

        Returns:
            An `ErrorResponse` when a template is supplied but
            `trust_request_chat_template` is False, otherwise None.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

807
808
    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render a chat request into a prompt and an engine-ready input.

        Parses ``messages`` (multimodal data is returned as a future and
        awaited later), applies the chat template (Mistral or HF depending on
        the tokenizer), optionally lets the tool parser adjust the request,
        and finally tokenizes the rendered prompt.

        Args:
            request: The incoming chat-like or Responses API request.
            tokenizer: Tokenizer used for templating and tokenization. The
                code also handles ``None``, in which case a placeholder
                prompt is produced.
            messages: Raw chat messages from the request.
            chat_template: Optional chat template override.
            chat_template_content_format: Requested content format; resolved
                via ``resolve_chat_template_content_format``.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Tool definitions, forwarded to the template as
                ``tools``.
            documents: Documents forwarded to the template.
            chat_template_kwargs: Extra template kwargs; they override the
                defaults built from the arguments above.
            tool_parser: Tool-parser factory. When set and the request's
                ``tool_choice`` is not ``"none"``, it may adjust the request.
            add_special_tokens: Forwarded to prompt tokenization.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: If tool parsing applies to a request that is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        # Multimodal data comes back as a future; it is awaited only after
        # the chat template has been applied.
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        # Caller-supplied chat_template_kwargs take precedence over these.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none". (If tool_choice is "none" but a
        # tool_parser is set, we want to prevent parsing a tool_call
        # hallucinated by the LLM.)
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request
            )

        if tokenizer is None:
            # Adjacent literals (no comma) so the assertion message is a
            # single string rather than a tuple.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised"
            )
            # Without a tokenizer a dummy single-token id list is used.
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Convert a prompt into an ``EngineCoreRequest`` using the Processor.

        Returns the engine request plus the tokenization kwargs dict that was
        passed through ``_validate_truncation_size`` and on to the processor.
        """
        truncation_kwargs: dict[str, Any] = {}
        # The dict is handed over by reference; presumably the validator
        # records truncation settings in it — confirm in
        # vllm.entrypoints.utils.
        _validate_truncation_size(
            self.max_model_len, params.truncate_prompt_tokens, truncation_kwargs
        )

        input_processor = await self._get_processor()
        core_request = input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=truncation_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return core_request, truncation_kwargs

    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """Stream generation turns, running built-in tool calls in between.

        Each turn logs the current prompt, builds an engine request for it,
        and yields ``context`` after every engine output is appended to it.
        When a turn ends and ``context.need_builtin_tool_call()`` is true, the
        tool is called, its output appended, and the next prompt is rendered
        from ``context``. Otherwise the generator finishes.

        Side effects: ``sampling_params.max_tokens`` is overwritten in place
        before every follow-up turn, and follow-up turns use a priority value
        of ``priority - 1``.

        NOTE(review): ``prompt_text`` is derived once from the initial
        ``request_prompt`` and passed unchanged for every turn, even after the
        prompt has been re-rendered from ``context`` — confirm intended.
        """
        initial_prompt_text, _, _ = self._get_prompt_components(request_prompt)
        base_priority = priority
        trace_headers = kwargs.get("trace_headers")

        turn_prompt = request_prompt
        turn_engine_prompt = engine_prompt
        turn_priority = priority

        while True:
            self._log_inputs(
                request_id,
                turn_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                request_id,
                turn_engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=turn_priority,
            )

            # Stopping each turn is left to the engine.
            async for output in self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=turn_priority,
                prompt_text=initial_prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            ):
                context.append_output(output)
                yield context

            if not context.need_builtin_tool_call():
                # No tool requested by the model: the conversation is done.
                return

            # Run the requested tool and feed its result back into context.
            context.append_output(await context.call_tool())

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Re-render the conversation as token ids for the next turn.
            next_token_ids = context.render_for_completion()
            turn_engine_prompt = EngineTokensPrompt(prompt_token_ids=next_token_ids)
            turn_prompt = next_token_ids
            # Whatever room is left in the model context becomes the budget.
            sampling_params.max_tokens = self.max_model_len - len(next_token_ids)
            # OPTIMIZATION: follow-up turns use a priority value one below
            # the caller's original value.
            turn_priority = base_priority - 1

    def _get_prompt_components(
        self,
        prompt: Union[RequestPrompt, PromptType],
    ) -> PromptComponents:
        """Split ``prompt`` into its text / token-id / embeds components.

        A bare list is taken to be token ids; any other prompt shape is
        delegated to ``get_prompt_components``.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]
        return PromptComponents(token_ids=prompt)

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[RequestPrompt, PromptType],
        params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Send a request's prompt components to the request logger.

        Does nothing when no ``request_logger`` is configured.
        """
        req_logger = self.request_logger
        if req_logger is not None:
            text, token_ids, embeds = self._get_prompt_components(inputs)
            req_logger.log_inputs(
                request_id,
                text,
                token_ids,
                embeds,
                params=params,
                lora_request=lora_request,
            )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers to propagate, or ``None``.

        Returns ``None`` when the engine reports tracing as disabled. In that
        case, if the incoming headers nonetheless carry trace headers, the
        tracing-disabled warning is emitted.
        """
        tracing_on = await self.engine_client.is_tracing_enabled()
        if not tracing_on:
            if contains_trace_headers(headers):
                log_tracing_disabled_warning()
            return None
        return extract_trace_headers(headers)

    @staticmethod
    def _base_request_id(
        raw_request: Optional[Request], default: Optional[str] = None
    ) -> Optional[str]:
        """Pull the request id to use from a header, if provided.

        Args:
            raw_request: The incoming HTTP request, or ``None``.
            default: Fallback id; when falsy, a ``random_uuid()`` is used.

        Returns:
            The ``X-Request-Id`` header value when ``raw_request`` carries
            one. Otherwise ``default``, or a fresh ``random_uuid()`` when
            ``default`` is falsy.
        """
        if raw_request is not None:
            header_id = raw_request.headers.get("X-Request-Id")
            if header_id is not None:
                return header_id
        # Only generate a UUID when it is actually going to be returned.
        return default or random_uuid()

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
        """Return the display string for ``token_id``.

        With ``return_as_token_id`` the result is ``"token_id:<id>"``.
        Otherwise ``logprob.decoded_token`` is used when it is set, and the
        tokenizer decodes the id only when it is ``None``.
        """
        if not return_as_token_id:
            cached = logprob.decoded_token
            return cached if cached is not None else tokenizer.decode(token_id)
        return f"token_id:{token_id}"

    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1094
1095
        if not model_name:
            return True
1096
        return self.models.is_base_model(model_name)
1097

1098
1099

def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs, None],
) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    Positions whose logprob dict is ``None`` are skipped. The same (mutated)
    object is returned; ``None`` passes through unchanged. Presumably this
    keeps the values JSON-serializable — confirm with callers.
    """
    if prompt_logprobs is not None:
        neg_inf = float("-inf")
        for position in (p for p in prompt_logprobs if p is not None):
            for entry in position.values():
                if entry.logprob == neg_inf:
                    entry.logprob = -9999.0
    return prompt_logprobs