serving_engine.py 37.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import sys
5
import time
6
import traceback
7
8
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
9
from http import HTTPStatus
10
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
11

12
import torch
13
from fastapi import Request
14
from pydantic import BaseModel, ConfigDict, Field
15
from starlette.datastructures import Headers
16
17
from typing_extensions import TypeIs

18
19
20
21
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor

22
23
24
25
26
if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

27
import vllm.envs as envs
28
from vllm.config import ModelConfig
29
from vllm.engine.protocol import EngineClient
30
31
32
33
34
35
36
37
38
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    ConversationMessage,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages_futures,
    resolve_chat_template_content_format,
)
39
from vllm.entrypoints.context import ConversationContext
40
from vllm.entrypoints.logger import RequestLogger
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorRequest,
    PoolingResponse,
    RerankRequest,
    ResponsesRequest,
    ScoreRequest,
    ScoreResponse,
    TokenizeChatRequest,
    TokenizeCompletionRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
)
68
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
69
from vllm.entrypoints.openai.tool_parsers import ToolParser
70
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
71
from vllm.inputs.data import PromptType
72
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
73
from vllm.inputs.parse import PromptComponents, get_prompt_components
74
from vllm.logger import init_logger
75
from vllm.logprobs import Logprob, PromptLogprobs
76
from vllm.lora.request import LoRARequest
77
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
78
79
80
    MultiModalDataDict,
    MultiModalUUIDDict,
)
81
from vllm.outputs import PoolingRequestOutput, RequestOutput
82
from vllm.pooling_params import PoolingParams
83
from vllm.sampling_params import BeamSearchParams, SamplingParams
84
85
86
87
88
from vllm.tracing import (
    contains_trace_headers,
    extract_trace_headers,
    log_tracing_disabled_warning,
)
89
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
90
91
92
93
94
95
96
from vllm.utils import (
    AsyncMicrobatchTokenizer,
    is_list_of,
    make_async,
    merge_async_iterators,
    random_uuid,
)
97
98
99

logger = init_logger(__name__)

100
101
102
103
104
105
106
107
108
# Union alias grouping the request models that are handled as
# "completion-like"; it is one of the members of `AnyRequest`.
CompletionLikeRequest = Union[
    CompletionRequest,
    DetokenizeRequest,
    EmbeddingCompletionRequest,
    RerankRequest,
    ClassificationRequest,
    ScoreRequest,
    TokenizeCompletionRequest,
]
109

110
111
112
# Union alias grouping the chat-style request models; it is one of the
# members of `AnyRequest`.
ChatLikeRequest = Union[
    ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest
]
113
# Union alias for the transcription and translation request models.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
114
115
116
117
118
119
120
# Union of every request model referenced in this module; used as the
# upper bound of the `RequestT` type variable and as the parameter type
# of the request-handling helpers below.
AnyRequest = Union[
    CompletionLikeRequest,
    ChatLikeRequest,
    SpeechToTextRequest,
    ResponsesRequest,
    IOProcessorRequest,
]
121

122
123
124
125
126
127
128
129
130
131
132
# Union of the (non-error) response models that the serving pipeline
# may return; error cases are expressed separately as `ErrorResponse`.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]

133
134
135

class TextTokensPrompt(TypedDict):
    """Prompt that carries both its text and the matching token ids."""

    # Prompt text.
    prompt: str
    # Token ids associated with the prompt.
    prompt_token_ids: list[int]
137
138


139
140
141
142
143
144
145
146
class EmbedsPrompt(TypedDict):
    """Prompt supplied as a tensor of embeddings instead of text/token ids."""

    prompt_embeds: torch.Tensor


# A request prompt is one of: a list of token ids, a text string,
# a `TextTokensPrompt`, or an `EmbedsPrompt`.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]


def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when `prompt` is a dict shaped like a `TextTokensPrompt`.

    The check is structural: the value must be a dict that has a
    "prompt_token_ids" key and does not have a "prompt_embeds" key.
    """
    if not isinstance(prompt, dict):
        return False

    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds
152
153
154


def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when `prompt` is a dict shaped like an `EmbedsPrompt`.

    The check is structural: the value must be a dict that has a
    "prompt_embeds" key and does not have a "prompt_token_ids" key.
    """
    if not isinstance(prompt, dict):
        return False

    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return (not has_token_ids) and has_embeds
160

161

162
163
164
165
166
# Type variable for the request model carried by `ServeContext`;
# constrained to the members of `AnyRequest`.
RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """

    # Both fields default to a fresh empty list per instance. A factory is
    # used (rather than a shared `[]` literal) to match the style of
    # `ResponseGenerationMixin.final_res_batch`.
    request_prompts: Optional[Sequence[RequestPrompt]] = Field(default_factory=list)
    engine_prompts: Optional[list[EngineTokensPrompt]] = Field(default_factory=list)

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation: holds the async generator that
    produces engine outputs and the list of collected final results.
    """

    # Yields `(index, output)` pairs; unset (None) until one is assigned.
    result_generator: Optional[
        AsyncGenerator[
            tuple[int, Union[RequestOutput, PoolingRequestOutput]],
            None,
        ]
    ] = None

    # Collected outputs; starts as a fresh empty list per instance.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list
    )

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)


193
class ServeContext(
    RequestProcessingMixin,
    ResponseGenerationMixin,
    BaseModel,
    Generic[RequestT],
):
    """Per-request state bundle, generic over the request model type.

    Combines the prompt fields of `RequestProcessingMixin` and the result
    fields of `ResponseGenerationMixin` with request metadata.
    """

    # --- Shared across all requests ---
    # The parsed request model.
    request: RequestT
    # The underlying FastAPI request, when one is available.
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    # Integer UNIX timestamp captured when the context is instantiated.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # --- Shared across most requests ---
    tokenizer: Optional[AnyTokenizer] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


# `ServeContext` specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, extended with chat-template
    settings (an optional template string and a required content format).
    """

    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

233

234
class OpenAIServing:
235
236
237
238
    request_id_prefix: ClassVar[str] = """
    A short string prepended to every request’s ID (e.g. "embd", "classify")
    so you can easily tell “this ID came from Embedding vs Classification.”
    """
239

240
241
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_force_include_usage: bool = False,
        log_error_stack: bool = False,
    ):
        """Store the serving dependencies and set up tokenizer helpers.

        Args:
            engine_client: Client used to talk to the engine.
            model_config: Model configuration; `max_model_len` is cached
                from it on the instance.
            models: Model registry (served models and LoRA requests).
            request_logger: Optional logger for request inputs.
            return_tokens_as_token_ids: Stored as-is on the instance.
            enable_force_include_usage: Stored as-is on the instance.
            log_error_stack: When true, `create_error_response` prints a
                traceback / stack trace before building the response.
        """
        super().__init__()

        # Engine and model wiring.
        self.engine_client = engine_client
        self.models = models
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        # Behaviour toggles and logging.
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage
        self.log_error_stack = log_error_stack

        # Single-worker thread pool handed to `make_async` so the Mistral
        # chat template can be applied through an awaitable wrapper.
        tokenizer_executor = ThreadPoolExecutor(max_workers=1)
        self._tokenizer_executor = tokenizer_executor
        self._apply_mistral_chat_template_async = make_async(
            apply_mistral_chat_template, executor=tokenizer_executor
        )

        # Cache of async tokenizer wrappers, keyed by the tokenizer object.
        self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
270

271
272
273
    async def _get_processor(self) -> Processor:
        """Return the `Processor`, creating and caching it on first use.

        The processor is built from the vLLM config fetched from the engine
        client and stored on the instance as `_processor`.
        """
        if hasattr(self, "_processor"):
            return self._processor

        config = await self.engine_client.get_vllm_config()
        self._processor = Processor(config)
        return self._processor

278
279
280
281
282
283
284
285
    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
        """
        Build a `CompletionRenderer` for the given tokenizer.

        The renderer receives this instance's model config and its shared
        async tokenizer pool.
        """
        renderer = CompletionRenderer(
            model_config=self.model_config,
            tokenizer=tokenizer,
            async_tokenizer_pool=self._async_tokenizer_pool,
        )
        return renderer
288

289
290
291
292
293
294
295
296
297
298
299
300
301
    def _build_render_config(
        self,
        request: Any,
    ) -> RenderConfig:
        """
        Build and return a `RenderConfig` for an endpoint.

        Used by the renderer to control how prompts are prepared
        (e.g., tokenization and length handling). Endpoints should
        implement this with logic appropriate to their request type.

        Raises:
            NotImplementedError: Always, in this base implementation;
                subclasses are expected to override the method.
        """
        raise NotImplementedError

302
303
    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Look up the `AsyncMicrobatchTokenizer` for `tokenizer` in the
        instance pool, creating and storing one if none is cached yet.
        """
        pool = self._async_tokenizer_pool
        cached = pool.get(tokenizer)
        if cached is not None:
            return cached

        created = AsyncMicrobatchTokenizer(tokenizer)
        pool[tokenizer] = created
        return created
312

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).

        Returns:
            `None` in this base implementation. `_pipeline` treats an
            `ErrorResponse` return value as a failure to report.
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Default response builder. Subclass may override this method
        to return the appropriate response object.

        Returns:
            In this base implementation, an error response with the
            message "unimplemented endpoint".
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """Run the pipeline for `ctx` and return its first response.

        Only the first item yielded by `_pipeline` is used. The generator
        is explicitly closed afterwards so it is not left suspended until
        garbage collection.

        Returns:
            The first yielded response, or an error response if the
            pipeline yields nothing.
        """
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        try:
            async for response in generation:
                return response
        finally:
            # Deterministically finalize the (possibly suspended) generator;
            # this is a no-op when it has already run to completion.
            await generation.aclose()

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline yielding responses.

        Stages run in order: model check, request validation, preprocessing,
        generator preparation, batch collection, response building. The
        first stage that reports an `ErrorResponse` yields it and ends the
        pipeline, so later stages never run on a request that has already
        failed, even if the consumer keeps iterating.
        """
        if error := await self._check_model(ctx.request):
            yield error
            return
        if error := self._validate_request(ctx):
            yield error
            return

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret
            return

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret
            return

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret
            return

        yield self._build_response(ctx)

    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        """Reject requests whose `truncate_prompt_tokens` exceeds the
        model's maximum length.

        Returns:
            An error response when the request sets `truncate_prompt_tokens`
            to a value greater than `self.max_model_len`; otherwise `None`
            (including when the request has no such attribute).
        """
        requested_limit = getattr(ctx.request, "truncate_prompt_tokens", None)
        if requested_limit is None:
            return None

        if requested_limit > self.max_model_len:
            return self.create_error_response(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )

        return None

383
384
385
386
387
388
    def _create_pooling_params(
        self,
        ctx: ServeContext,
    ) -> Union[PoolingParams, ErrorResponse]:
        """Obtain pooling parameters from the request.

        Returns:
            The result of `ctx.request.to_pooling_params()` when the request
            exposes that method, otherwise an error response.
        """
        if hasattr(ctx.request, "to_pooling_params"):
            return ctx.request.to_pooling_params()

        return self.create_error_response(
            "Request type does not support pooling parameters"
        )

394
395
396
397
398
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator.

        One `engine_client.encode` call is issued per engine prompt, with a
        request id of the form "<request_id>-<index>". The resulting
        generators are merged and stored on `ctx.result_generator`.

        Returns:
            `None` on success, or an error response when pooling params or
            engine prompts are unavailable or any step raises.
        """
        generators: list[
            AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]
        ] = []

        try:
            # Trace headers can only be extracted when a raw request exists.
            if ctx.raw_request is None:
                trace_headers = None
            else:
                trace_headers = await self._get_trace_headers(
                    ctx.raw_request.headers
                )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            for index, engine_prompt in enumerate(engine_prompts):
                sub_request_id = f"{ctx.request_id}-{index}"

                self._log_inputs(
                    sub_request_id,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                )

                generators.append(
                    self.engine_client.encode(
                        engine_prompt,
                        pooling_params,
                        sub_request_id,
                        lora_request=ctx.lora_request,
                        trace_headers=trace_headers,
                        priority=getattr(ctx.request, "priority", 0),
                    )
                )

            ctx.result_generator = merge_async_iterators(*generators)
            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator.

        Drains `ctx.result_generator`, placing each output at the index it
        was yielded with, and stores the completed list on
        `ctx.final_res_batch`.

        Returns:
            `None` on success, or an error response when prompts or the
            generator are missing, a slot stays unfilled, or any step raises.
        """
        try:
            engine_prompts = ctx.engine_prompts
            if engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # One slot per prompt; a slot left as None means no result arrived.
            slots: list[Optional[Union[RequestOutput, PoolingRequestOutput]]]
            slots = [None] * len(engine_prompts)

            result_generator = ctx.result_generator
            if result_generator is None:
                return self.create_error_response("Result generator not available")

            async for index, output in result_generator:
                slots[index] = output

            if None in slots:
                return self.create_error_response(
                    "Failed to generate results for all prompts"
                )

            ctx.final_res_batch = [output for output in slots if output is not None]
            return None

        except Exception as exc:
            return self.create_error_response(str(exc))

477
    def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> ErrorResponse:
        """Build an `ErrorResponse` from a message, type and HTTP status.

        When `self.log_error_stack` is set, diagnostic output is printed
        first: the active exception's traceback if one is being handled,
        otherwise the current call stack.
        """
        if self.log_error_stack:
            if sys.exc_info()[0] is None:
                traceback.print_stack()
            else:
                traceback.print_exc()

        info = ErrorInfo(message=message, type=err_type, code=status_code.value)
        return ErrorResponse(error=info)
492

493
    def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    ) -> str:
        """Return the error response for the given arguments as a JSON string.

        The response is built with `create_error_response`, dumped to a
        plain structure with `model_dump()`, and serialized with
        `json.dumps`.
        """
        error_response = self.create_error_response(
            message=message, err_type=err_type, status_code=status_code
        )
        return json.dumps(error_response.model_dump())

506
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        """Check that `request.model` names something this server can serve.

        Returns:
            `None` when the model is supported, is a known LoRA request, or
            (with runtime LoRA updating enabled) resolves to a `LoRARequest`.
            If resolution yields a bad-request `ErrorResponse`, that response
            is returned. Otherwise a 404 "NotFoundError" response is returned.
        """
        model = request.model

        if self._is_model_supported(model) or model in self.models.lora_requests:
            return None

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and model:
            load_result = await self.models.resolve_lora(model)
            if load_result:
                if isinstance(load_result, LoRARequest):
                    return None
                # Only bad-request errors from the resolver are surfaced;
                # anything else falls through to the generic 404 below.
                if (
                    isinstance(load_result, ErrorResponse)
                    and load_result.error.code == HTTPStatus.BAD_REQUEST.value
                ):
                    return load_result

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
        )
534

535
    def _get_active_default_mm_loras(
        self, request: AnyRequest
    ) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras.

        Returns:
            The single registered LoRA whose name matches one of the
            request's message content types, or `None` when zero or more
            than one LoRA matches.
        """
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)

        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        matched = {
            lora
            for lora in self.models.lora_requests.values()
            if lora.lora_name in message_types
        }

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        return matched.pop() if len(matched) == 1 else None

560
    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Optional[LoRARequest]:
        """Select the LoRA request (if any) that applies to `request`.

        Lookup order: a LoRA registered under `request.model`; then, when
        `supports_default_mm_loras` is set, a uniquely matching default
        multimodal LoRA; then `None` for a supported base model.

        Raises:
            ValueError: If the model is neither a known LoRA nor supported.
        """
        known_loras = self.models.lora_requests
        if request.model in known_loras:
            return known_loras[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if not self._is_model_supported(request.model):
            # if _check_model has been called earlier, this will be unreachable
            raise ValueError(f"The model `{request.model}` does not exist.")

        return None
580

581
582
583
584
585
586
587
588
589
590
591
    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.

        Only dict messages whose "content" is a list are inspected, and
        within that list only dict parts with a string "type" contribute.
        Anything else (a request without `messages`, `messages=None`,
        plain-string content parts, non-string types) is skipped rather
        than raising.

        Returns:
            The prefixes (text before the first "_") of all content part
            types found, e.g. "audio" for "audio_url".
        """
        message_types: set[str] = set()

        for message in getattr(request, "messages", None) or ():
            if not isinstance(message, dict):
                continue

            content = message.get("content")
            if not isinstance(content, list):
                continue

            for part in content:
                # A plain-string part would pass a `"type" in part` substring
                # check and then fail on indexing, so require a dict here.
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if isinstance(part_type, str):
                    message_types.add(part_type.split("_")[0])

        return message_types

602
    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        prompt: str,
        tokenizer: AnyTokenizer,
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        """Tokenize a text prompt and validate the resulting token ids.

        The prompt is lower-cased first when the model's encoder config sets
        "do_lower_case". Truncation follows the request's
        `truncate_prompt_tokens`: unset means no truncation, a negative
        value truncates to `self.max_model_len`, any other value truncates
        to that many tokens.

        Returns:
            The result of `_validate_input` for the tokenized prompt.
        """
        async_tokenizer = self._get_async_tokenizer(tokenizer)

        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        limit = getattr(request, "truncate_prompt_tokens", None)

        tokenize_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
        if limit is not None:
            tokenize_kwargs["truncation"] = True
            # Negative means we cap at the model's max length
            tokenize_kwargs["max_length"] = self.max_model_len if limit < 0 else limit

        encoded = await async_tokenizer(prompt, **tokenize_kwargs)

        return self._validate_input(request, encoded.input_ids, prompt)

644
    async def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        prompt_ids: list[int],
        tokenizer: Optional[AnyTokenizer],
    ) -> TextTokensPrompt:
        """Truncate a token-id prompt, decode it, and validate the result.

        Truncation keeps the trailing tokens: unset
        `truncate_prompt_tokens` keeps everything, a negative value keeps
        the last `self.max_model_len` ids, any other value keeps that many
        trailing ids. The text is decoded with the async tokenizer, or is
        the empty string when no tokenizer is given.

        Returns:
            The result of `_validate_input` for the (truncated) ids.
        """
        limit = getattr(request, "truncate_prompt_tokens", None)

        if limit is None:
            input_ids = prompt_ids
        else:
            keep = self.max_model_len if limit < 0 else limit
            # NOTE(review): a value of 0 slices `[-0:]`, which keeps the
            # whole list rather than nothing - confirm this is intended.
            input_ids = prompt_ids[-keep:]

        if tokenizer is not None:
            async_tokenizer = self._get_async_tokenizer(tokenizer)
            input_text = await async_tokenizer.decode(input_ids)
        else:
            input_text = ""

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Check the prompt length against the model limit for this request.

        Three cases, by request type:
        - embedding / score / rerank / classification: input may fill the
          whole context (`token_num <= max_model_len`);
        - tokenize / detokenize: no length validation;
        - everything else: input must be shorter than the context, and
          input plus requested `max_tokens` must fit within it.

        Returns:
            A `TextTokensPrompt` holding `input_text` and `input_ids`.

        Raises:
            ValueError: If the applicable length limit is exceeded.
        """
        token_num = len(input_ids)
        max_model_len = self.max_model_len
        validated = TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest doesn't have max_tokens
        non_generating_types = (
            EmbeddingChatRequest,
            EmbeddingCompletionRequest,
            ScoreRequest,
            RerankRequest,
            ClassificationRequest,
        )
        if isinstance(request, non_generating_types):
            # Note: input length can be up to the entire model context length
            # since these requests don't generate tokens.
            if token_num <= max_model_len:
                return validated

            operations: dict[type[AnyRequest], str] = {
                ScoreRequest: "score",
                ClassificationRequest: "classification",
            }
            operation = operations.get(type(request), "embedding generation")
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input."
            )

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        unchecked_types = (
            TokenizeCompletionRequest,
            TokenizeChatRequest,
            DetokenizeRequest,
        )
        if isinstance(request, unchecked_types):
            return validated

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages."
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise ValueError(
                "'max_tokens' or 'max_completion_tokens' is too large: "
                f"{max_tokens}. This model's maximum context length is "
                f"{max_model_len} tokens and your request has "
                f"{token_num} input tokens ({max_tokens} > {max_model_len}"
                f" - {token_num})."
            )

        return validated

739
    async def _tokenize_prompt_input_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        Tokenize a single prompt input (text or token ids).

        Delegates to `_tokenize_prompt_inputs_async` with a one-element
        list and returns the first result it yields.

        Raises:
            ValueError: If the underlying generator yields nothing.
        """
        tokenized_stream = self._tokenize_prompt_inputs_async(
            request,
            tokenizer,
            [prompt_input],
            add_special_tokens=add_special_tokens,
        )
        async for tokenized in tokenized_stream:
            return tokenized

        raise ValueError("No results yielded from tokenization")
757

758
    async def _tokenize_prompt_inputs_async(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        add_special_tokens: bool = True,
    ) -> AsyncGenerator[TextTokensPrompt, None]:
        """
        Tokenize several prompt inputs, yielding one result per input.

        String inputs go through `_normalize_prompt_text_to_input`;
        anything else is treated as token ids and goes through
        `_normalize_prompt_tokens_to_input`.
        """
        for prompt_input in prompt_inputs:
            if not isinstance(prompt_input, str):
                normalized = await self._normalize_prompt_tokens_to_input(
                    request,
                    prompt_ids=prompt_input,
                    tokenizer=tokenizer,
                )
            else:
                normalized = await self._normalize_prompt_text_to_input(
                    request,
                    prompt=prompt_input,
                    tokenizer=tokenizer,
                    add_special_tokens=add_special_tokens,
                )
            yield normalized

783
784
785
786
787
788
789
    def _validate_chat_template(
        self,
        request_chat_template: Optional[str],
        chat_template_kwargs: Optional[dict[str, Any]],
        trust_request_chat_template: bool,
    ) -> Optional[ErrorResponse]:
        """Refuse request-supplied chat templates unless they are trusted.

        A template counts as request-supplied when `request_chat_template`
        is set, or when `chat_template_kwargs` carries a non-None
        "chat_template" entry.

        Returns:
            An error response when such a template is present and
            `trust_request_chat_template` is false; otherwise `None`.
        """
        if trust_request_chat_template:
            return None

        if request_chat_template is None:
            template_in_kwargs = (
                bool(chat_template_kwargs)
                and chat_template_kwargs.get("chat_template") is not None
            )
            if not template_in_kwargs:
                return None

        return self.create_error_response(
            "Chat template is passed with request, but "
            "--trust-request-chat-template is not set. "
            "Refused request with untrusted chat template."
        )

803
804
    async def _preprocess_chat(
        self,
        request: Union[ChatLikeRequest, ResponsesRequest],
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[
        list[ConversationMessage],
        Sequence[RequestPrompt],
        list[EngineTokensPrompt],
    ]:
        """Render chat messages into a request prompt and an engine prompt.

        Steps, in order: resolve the chat-template content format, parse the
        messages, apply the chat template (Mistral or HF code path depending
        on the tokenizer type), await the multi-modal data future, optionally
        let the tool parser adjust the request, tokenize the rendered prompt,
        and assemble the ``EngineTokensPrompt``.

        Args:
            request: The incoming request. ``tool_choice`` and ``cache_salt``
                are read only when the attribute exists;
                ``mm_processor_kwargs`` is read unconditionally.
            tokenizer: Tokenizer used for templating and tokenization. When
                it is ``None`` a placeholder prompt is produced instead.
            messages: The chat messages to render.
            chat_template: Chat template forwarded to the template helpers.
            chat_template_content_format: Requested content format, resolved
                through ``resolve_chat_template_content_format``.
            add_generation_prompt: Forwarded to the chat template.
            continue_final_message: Forwarded to the chat template.
            tool_dicts: Forwarded to the chat template as ``tools``.
            documents: Forwarded to the chat template as ``documents``.
            chat_template_kwargs: Extra template kwargs; these override the
                explicit arguments above on key collision.
            tool_parser: Factory for a tool parser whose ``adjust_request``
                is applied when tool parsing is active.
            add_special_tokens: Forwarded to prompt tokenization when the
                rendered prompt is a string.

        Returns:
            ``(conversation, [request_prompt], [engine_prompt])``.

        Raises:
            NotImplementedError: If tool parsing is active but ``request`` is
                not a ``ChatCompletionRequest``.
        """
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        # Explicit arguments go in first so that caller-supplied
        # ``chat_template_kwargs`` win on key collision.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]

        if tokenizer is None:
            request_prompt = "placeholder"
        elif isinstance(tokenizer, MistralTokenizer):
            # The Mistral path receives the raw ``messages``; the HF path
            # below receives the parsed ``conversation``.
            request_prompt = await self._apply_mistral_chat_template_async(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        should_parse_tools = tool_parser is not None and (
            hasattr(request, "tool_choice") and request.tool_choice != "none"
        )

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request
            )

        if tokenizer is None:
            # The message is a single string: stray commas previously turned
            # it into a tuple, which would render as a tuple repr on failure.
            assert isinstance(request_prompt, str), (
                "Prompt has to be a string "
                "when the tokenizer is not initialised"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=request_prompt, prompt_token_ids=[1]
            )
        elif isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids"
            )
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt,
            )

        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"]
        )
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

925
926
927
928
    async def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        trace_headers: Optional[Mapping[str, str]],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Run the Processor over a prompt to build the AsyncLLM request.

        Returns:
            The engine-core request produced by ``process_inputs`` and the
            tokenization kwargs dict that was passed both to
            ``_validate_truncation_size`` and to the processor.
        """
        tok_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            self.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        input_processor = await self._get_processor()
        core_request = input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            trace_headers=trace_headers,
            priority=priority,
        )
        return core_request, tok_kwargs

953
954
955
956
957
958
959
960
961
962
963
    async def _generate_with_builtin_tools(
        self,
        request_id: str,
        request_prompt: RequestPrompt,
        engine_prompt: EngineTokensPrompt,
        sampling_params: SamplingParams,
        context: ConversationContext,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
        **kwargs,
    ):
        """Generate turn by turn, running built-in tool calls in between.

        Each turn logs the inputs, processes them into an engine request and
        streams the engine outputs into ``context``, yielding ``context``
        after every output. When ``context`` reports that a built-in tool
        call is needed, the tool is invoked, its output is appended to
        ``context``, and a new prompt is rendered for the next turn.
        Otherwise the generator finishes.

        ``sampling_params.max_tokens`` is mutated in place before each
        follow-up turn. ``kwargs`` are forwarded to
        ``engine_client.generate``; ``kwargs["trace_headers"]``, if present,
        is also passed to ``_process_inputs``.
        """
        # The prompt text is taken from the initial request prompt only.
        prompt_text, _token_ids, _embeds = self._get_prompt_components(
            request_prompt
        )
        trace_headers = kwargs.get("trace_headers")
        initial_priority = priority

        while True:
            self._log_inputs(
                request_id,
                request_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )
            engine_request, tokenization_kwargs = await self._process_inputs(
                request_id,
                engine_prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            output_stream = self.engine_client.generate(
                engine_request,
                sampling_params,
                request_id,
                lora_request=lora_request,
                priority=priority,
                prompt_text=prompt_text,
                tokenization_kwargs=tokenization_kwargs,
                **kwargs,
            )

            async for output in output_stream:
                context.append_output(output)
                # NOTE(woosuk): The stop condition is handled by the engine.
                yield context

            if not context.need_builtin_tool_call():
                # No tool call requested by the model: generation is complete.
                break

            # Run the tool and record its result in the context.
            tool_result = await context.call_tool()
            context.append_output(tool_result)

            # TODO: uncomment this and enable tool output streaming
            # yield context

            # Render the token ids for the next turn and rebuild the inputs.
            next_token_ids = context.render_for_completion()
            engine_prompt = EngineTokensPrompt(prompt_token_ids=next_token_ids)
            request_prompt = next_token_ids
            # The generation budget is whatever the model length leaves over.
            sampling_params.max_tokens = self.max_model_len - len(next_token_ids)
            # OPTIMIZATION: follow-up turns use the original priority minus 1.
            priority = initial_priority - 1

1020
1021
    def _get_prompt_components(
        self,
        prompt: Union[RequestPrompt, PromptType],
    ) -> PromptComponents:
        """Split a prompt into its components.

        A ``list`` prompt is wrapped as ``PromptComponents(token_ids=...)``;
        anything else is delegated to ``get_prompt_components``.
        """
        if not isinstance(prompt, list):
            return get_prompt_components(prompt)  # type: ignore[arg-type]

        return PromptComponents(token_ids=prompt)
1028

1029
1030
1031
    def _log_inputs(
        self,
        request_id: str,
1032
        inputs: Union[RequestPrompt, PromptType],
1033
        params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
1034
1035
1036
1037
        lora_request: Optional[LoRARequest],
    ) -> None:
        if self.request_logger is None:
            return
1038

1039
        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
1040
1041
1042
1043
1044

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
1045
            prompt_embeds,
1046
1047
1048
            params=params,
            lora_request=lora_request,
        )
1049

1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        """Return the trace headers when the engine has tracing enabled.

        When tracing is disabled, ``None`` is returned; if the request
        nevertheless carries trace headers, the tracing-disabled warning
        helper is invoked first.
        """
        if await self.engine_client.is_tracing_enabled():
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()
        return None

1064
    @staticmethod
1065
1066
1067
    def _base_request_id(
        raw_request: Optional[Request], default: Optional[str] = None
    ) -> Optional[str]:
1068
1069
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
1070
1071
1072
1073
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)
1074

1075
    @staticmethod
1076
1077
1078
1079
1080
1081
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
        return_as_token_id: bool = False,
    ) -> str:
1082
1083
1084
        if return_as_token_id:
            return f"token_id:{token_id}"

1085
1086
        if logprob.decoded_token is not None:
            return logprob.decoded_token
1087
        return tokenizer.decode(token_id)
1088

1089
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
1090
1091
        if not model_name:
            return True
1092
        return self.models.is_base_model(model_name)
1093

1094
1095

def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs, None],
) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` logprob values with ``-9999.0``.

    The entries are modified in place and the same object that was passed in
    is returned. ``None`` inputs and ``None`` positions are left untouched.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float("-inf")
    entries = (
        entry
        for position in prompt_logprobs
        if position is not None
        for entry in position.values()
    )
    for entry in entries:
        if entry.logprob == negative_infinity:
            entry.logprob = -9999.0
    return prompt_logprobs