"vscode:/vscode.git/clone" did not exist on "cb64c6bce1332416e621efe28751ed46ecb6c166"
chat_utils.py 65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import inspect
6
import json
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict, deque
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from pathlib import Path
12
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
13

14
15
16
import jinja2
import jinja2.ext
import jinja2.meta
17
import jinja2.nodes
18
19
import jinja2.parser
import jinja2.sandbox
20
import transformers.utils.chat_template_utils as hf_chat_utils
21
from openai.types.chat import (
22
23
24
25
26
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
27
    ChatCompletionFunctionToolParam,
28
29
30
31
32
33
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
34
from openai.types.chat import (
35
36
37
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
38
from openai.types.responses import ResponseInputImageParam
39
from openai_harmony import Message as OpenAIHarmonyMessage
40
41
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
42
43
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

44
# pydantic needs the TypedDict from typing_extensions
45
from typing_extensions import Required, TypedDict
46

47
from vllm import envs
48
from vllm.config import ModelConfig
49
from vllm.logger import init_logger
50
from vllm.model_executor.models import SupportsMultiModal
51
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
52
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
53
from vllm.tokenizers import TokenizerLike
54
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
55
from vllm.transformers_utils.processor import cached_get_processor
56
from vllm.utils import random_uuid
57
from vllm.utils.collection_utils import is_list_of
58
from vllm.utils.func_utils import supports_kw
59
60
61
62
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
63
64

    from vllm.tokenizers.mistral import MistralTokenizer
65
66
else:
    torch = LazyLoader("torch", globals(), "torch")
67
68
69

logger = init_logger(__name__)

70
71
72
73
74
75
# Generic placeholder token for each raw modality. The content parsers in
# this module use these tokens as the keys of their placeholder storage.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

76

77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Payload of an `audio_url` content part; `url` is the only (required) key.
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


# Content part that references audio through an `AudioURL` payload.
# Both keys are required.
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


91
# Content part whose payload is base64-encoded image embeddings.
# Only `type` is required; the other keys may be omitted (total=False).
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    image_embeds: str | dict[str, str] | None
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
105
106


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Content part whose payload is base64-encoded audio embeddings.
# Only `type` is required; the other keys may be omitted (total=False).
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Payload of a `video_url` content part; `url` is the only (required) key.
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


# Content part that references a video through a `VideoURL` payload.
# Both keys are required.
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


137
138
139
140
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # `Image.Image` is not a pydantic-native type, so arbitrary types must be
    # allowed for this model to hold it.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    # Both keys are optional (total=False) and may also be None.
    image_pil: PILImage | None
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
161
162


163
164
165
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.

    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """

    # Both keys are optional (total=False) and may also be None.
    image_url: str | None
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
179
180
181
182


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """

    # NOTE(review): unlike the simple image/video params in this file, this
    # one declares no `uuid` key -- confirm whether that is intentional.
    audio_url: str | None
191
192


193
194
195
196
197
198
199
200
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
201

202
203
    video_url: str | None
    uuid: str | None
204
205
206
207
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
208
209


Julien Denize's avatar
Julien Denize committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    # `thinking` and `type` are required; `closed` may be omitted.
    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


231
232
233
234
235
236
237
238
239
# Union of every content-part shape accepted in a message's `content` list:
# the OpenAI-defined parts, the custom parts declared in this module, and a
# bare string.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
246
247
248
249


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    # Only `role` is required; every other key may be omitted (total=False).
    role: Required[str]
    """The role of the message's author."""

    content: str | list[ChatCompletionContentPartParam]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

276

277
278
279
280
281
# A chat message may be an OpenAI-defined message, a custom-role message as
# declared in this module, or an OpenAI Harmony message object.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
282
283


284
# TODO: Make fields ReadOnly once mypy supports it
285
286
287
288
# Only `role` is required; every other key may be omitted (total=False).
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

    content: str | None | list[dict[str, str]]
    """The contents of the message"""

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    name: str | None
    """The name of the function to call"""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

310

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# Passed in by user ("auto" requests detection from the chat template)
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Used internally (the concrete format once "auto" has been resolved)
_ChatTemplateContentFormat = Literal["string", "openai"]


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether `node` is a `Name` node that loads variable `varname`."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """
    Return whether `node` reads `key` from variable `varname`, written either
    as `varname.key` or as `varname["key"]` with a constant subscript.
    """
    # Attribute form: varname.key
    if isinstance(node, jinja2.nodes.Getattr):
        if not _is_var_access(node.node, varname):
            return False
        return node.attr == key

    # Subscript form: varname["key"] (only constant subscripts qualify)
    if isinstance(node, jinja2.nodes.Getitem):
        subscript = node.arg
        if not _is_var_access(node.node, varname):
            return False
        if not isinstance(subscript, jinja2.nodes.Const):
            return False
        return subscript.value == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    """
    Return whether `node` accesses `varname` (or `varname[key]` when `key` is
    given), looking through any filters, tests and slices wrapped around it.
    """
    inner = node

    # Peel off wrapper nodes until the innermost expression is reached.
    while True:
        if isinstance(inner, jinja2.nodes.Filter):
            # A filter without an operand cannot be an access to the variable
            if inner.node is None:
                return False
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Test):
            inner = inner.node
        elif isinstance(inner, jinja2.nodes.Getitem) and isinstance(
            inner.arg, jinja2.nodes.Slice
        ):
            inner = inner.node
        else:
            break

    if key:
        return _is_attr_access(inner, varname, key)

    return _is_var_access(inner, varname)
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """
    Yield `(node, name)` for `varname` and every variable derived from it.

    The first pair is `(root, varname)` itself. After that, each assignment
    whose right-hand side accesses an already-related variable yields
    `(assign_node, target_name)`, and the target becomes related as well.

    Each variable name is expanded at most once, so assignment cycles such as
    `a = messages` followed by `messages = a` terminate instead of looping
    forever.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS over variable names; `seen_varnames` holds every name that
    # has already been queued so that no name is processed twice.
    seen_varnames = {varname}
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment and for longer
                # assignment cycles between several variables
                if lhs.name not in seen_varnames:
                    seen_varnames.add(lhs.name)
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """
    Yield `(for_node, target_name)` for each `for` loop in the template that
    iterates over `messages` or over a variable derived from it.
    """
    messages_aliases = [
        alias for _, alias in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterated = for_node.iter
        bound = for_node.target

        # Each loop is reported at most once, however many aliases match it
        if any(_is_var_or_elems_access(iterated, alias) for alias in messages_aliases):
            assert isinstance(bound, jinja2.nodes.Name)
            yield for_node, bound.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """
    Yield `(for_node, target_name)` for each `for` loop in the template that
    iterates over the `content` of a message loop variable.
    """
    per_message_names = []
    for _, name in _iter_nodes_assign_messages_item(root):
        per_message_names.append(name)

    # Search for {%- for content in message['content'] -%} loops
    for for_node in root.find_all(jinja2.nodes.For):
        iterated, bound = for_node.iter, for_node.target

        matched = False
        for name in per_message_names:
            if _is_var_or_elems_access(iterated, name, "content"):
                matched = True
                break

        if matched:
            assert isinstance(bound, jinja2.nodes.Name)
            yield for_node, bound.name


417
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    """
    Parse `chat_template` into a Jinja AST.

    Returns `None` (after logging the exception) when compiling or parsing
    the template fails for any reason.
    """
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        # Parse with the compiled template's own environment
        environment = compiled.environment
        parsed = environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return parsed


426
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """
    Detect which content format `chat_template` expects.

    Returns "openai" when the template contains a loop over a message's
    `content`, "string" when it does not, and `default` when the template
    cannot be parsed or analysed. Results are cached per argument pair.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Only the existence of a match matters. Matches are tuples, so
        # `None` unambiguously means "no match".
        first_match = next(_iter_nodes_assign_content_item(jinja_ast), None)
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    if first_match is None:
        return "string"

    return "openai"


447
def resolve_mistral_chat_template(
448
    chat_template: str | None,
449
    **kwargs: Any,
450
) -> str | None:
451
452
453
454
    if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
        raise ValueError(
            "'chat_template' or 'chat_template_kwargs' cannot be overridden "
            "for mistral tokenizer."
455
        )
456

457
458
    return None

459

460
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
461
462
463
464
465
466
467
468
469
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    model_config: ModelConfig,
) -> str | None:
    """
    Return the chat template of the processor that belongs to `tokenizer`.

    The outcome, including `None` for "no template" or "loading failed", is
    stored in `_PROCESSOR_CHAT_TEMPLATES` so the processor is loaded at most
    once per `(name_or_path, trust_remote_code)` pair.
    """
    name_or_path = tokenizer.name_or_path
    trust_remote_code = model_config.trust_remote_code
    cache_key = (name_or_path, trust_remote_code)

    try:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]
    except KeyError:
        pass

    resolved: str | None = None
    try:
        processor = cached_get_processor(
            name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        # Only a real processor (not a bare tokenizer) can contribute a template
        if isinstance(processor, ProcessorMixin):
            resolved = getattr(processor, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = resolved
    return resolved


505
def resolve_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
) -> str | None:
    """
    Pick the chat template to use for a HuggingFace tokenizer.

    Sources are tried in this order, and the first one that yields a
    template wins:

    1. The explicitly given `chat_template`.
    2. The processor's chat template (skipped when `tools` is given).
    3. The tokenizer's own chat template.
    4. A predefined fallback template file for the model.

    Returns `None` when none of the sources provides a template.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        processor_template = _try_get_processor_chat_template(
            tokenizer, model_config
        )
        if processor_template is not None:
            return processor_template

    # 3rd priority: AutoTokenizer chat template
    # (`chat_template` is necessarily None at this point)
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    fallback_path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if fallback_path is None:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )
        return None

    logger.info_once(
        "Loading chat template fallback for %s as there isn't one "
        "defined on HF Hub.",
        tokenizer.name_or_path,
    )
    return load_chat_template(fallback_path)
550
551


552
def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: TokenizerLike | None,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """
    Detect the content format from whichever chat template will be used.

    For HuggingFace tokenizers the template is resolved through
    `resolve_hf_chat_template`; otherwise, or when that yields no string,
    `chat_template` itself is loaded as a literal template. With no template
    at all the format is "string".
    """
    hf_chat_template = None
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )

    if isinstance(hf_chat_template, str):
        jinja_text = hf_chat_template
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        return "string"

    return _detect_content_format(jinja_text, default="string")
582
583
584


@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """
    Log the detected content format, and warn when an explicitly given format
    disagrees with it.

    `chat_template` is not used in the body; it is only part of the
    `lru_cache` key, so a repeated argument combination is not logged again.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    # Nothing to warn about if the user asked for detection or agrees with it
    if given_format in ("auto", detected_format):
        return

    logger.warning(
        "You specified `--chat-template-content-format %s` "
        "which is different from the detected format '%s'. "
        "If our automatic detection is incorrect, please consider "
        "opening a GitHub issue so that we can improve it: "
        "https://github.com/vllm-project/vllm/issues/new/choose",
        given_format,
        detected_format,
    )

607
608

def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: TokenizerLike | None,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """
    Resolve the content format option to a concrete format.

    An explicit `given_format` ("string" or "openai") is returned as-is.
    "auto" triggers detection from the chat template, and the detected
    format is logged before being returned.
    """
    if given_format == "auto":
        detected = _resolve_chat_template_content_format(
            chat_template,
            tools,
            tokenizer,
            model_config=model_config,
        )
        _log_chat_template_content_format(
            chat_template,
            given_format=given_format,
            detected_format=detected,
        )
        return detected

    return given_format
633

634

635
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
636
637
638
_T = TypeVar("_T")


639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def _extract_embeds(tensors: list[torch.Tensor]):
    if len(tensors) == 0:
        return tensors

    if len(tensors) == 1:
        tensors[0]._is_single_item = True  # type: ignore
        return tensors[0]  # To keep backwards compatibility for single item input

    first_shape = tensors[0].shape
    if all(t.shape == first_shape for t in tensors):
        return torch.stack(tensors)

    return tensors


def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
    """
    Collect the embedding inputs stored under `"<modality>_embeds"`.

    Returns:
        - The list unchanged when it is empty, or when its items are neither
          all tensors nor all dicts.
        - The result of `_extract_embeds` when every item is a tensor.
        - A dict mapping each key to `_extract_embeds` of that key's values
          when every item is a dict. Keys follow the order of the first item.

    Raises:
        KeyError: If `items_by_modality` has no `"<modality>_embeds"` entry.
        ValueError: If the dict items do not all share the same keys.
    """
    embeds = items_by_modality[f"{modality}_embeds"]

    if len(embeds) == 0:
        return embeds
    if is_list_of(embeds, torch.Tensor):
        return _extract_embeds(embeds)
    if is_list_of(embeds, dict):
        # `embeds` is known to be non-empty here, so indexing is safe
        first_item = embeds[0]
        first_keys = set(first_item.keys())
        if any(set(item.keys()) != first_keys for item in embeds[1:]):
            raise ValueError(
                "All dictionaries in the list of embeddings must have the same keys."
            )

        # Iterate the first dict rather than the key set so that the order of
        # the resulting keys is deterministic.
        return {k: _extract_embeds([item[k] for item in embeds]) for k in first_item}

    return embeds


677
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Items and their UUIDs are kept in parallel lists per modality
        self._items_by_modality = defaultdict[str, list[_T | None]](list)
        self._uuids_by_modality = defaultdict[str, list[str | None]](list)

    @property
    def model_config(self) -> ModelConfig:
        """The model config this tracker was created with."""
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """The model class for `model_config`, resolved on first access."""
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def allowed_local_media_path(self):
        """`allowed_local_media_path` as set on the model config."""
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        """`allowed_media_domains` as set on the model config."""
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        """The multi-modal registry used to create the processor."""
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """The multi-modal processor for the model, created on first access."""
        return self.mm_registry.create_processor(self.model_config)

    def add(
        self,
        modality: ModalityStr,
        item: _T | None,
        uuid: str | None = None,
    ) -> str | None:
        """
        Record a multi-modal item for the current prompt and return the
        placeholder string to use for it, if any.

        The item count is validated before the item is stored. `uuid` is an
        optional unique identifier of the media, stored alongside the item.
        """
        tracked_items = self._items_by_modality[modality]
        position = len(tracked_items) + 1

        # Limits are checked against the raw modality, e.g. "image" for
        # "image_embeds"
        self.mm_processor.validate_num_items(
            modality.replace("_embeds", ""), position
        )

        tracked_items.append(item)
        self._uuids_by_modality[modality].append(uuid)

        return self.model_cls.get_placeholder_str(modality, position)

    def all_mm_uuids(self) -> MultiModalUUIDDict | None:
        """
        Return the recorded UUIDs keyed by raw modality, or `None` when no
        item has been tracked.

        Raises:
            ValueError: If raw and embedding inputs of the same modality
                were both added.
        """
        if not self._items_by_modality:
            return None

        recorded = dict(self._uuids_by_modality)
        if {"image", "image_embeds"} <= recorded.keys():
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if {"audio", "audio_embeds"} <= recorded.keys():
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        # (output key, recorded key): embedding UUIDs are reported under the
        # raw modality name.
        key_pairs = (
            ("image", "image_embeds"),
            ("image", "image"),
            ("audio", "audio_embeds"),
            ("audio", "audio"),
            ("video", "video"),
        )
        mm_uuids = {}
        for output_key, recorded_key in key_pairs:
            if recorded_key in recorded:
                mm_uuids[output_key] = recorded[recorded_key]

        return mm_uuids

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds items into this tracker."""
        raise NotImplementedError


771
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Synchronous tracker whose items are the already-loaded media objects."""

    def all_mm_data(self) -> MultiModalDataDict | None:
        """
        Return the tracked items keyed by raw modality, or `None` when no
        item has been tracked.

        Raises:
            ValueError: If raw and embedding inputs of the same modality
                were both added.
        """
        if not self._items_by_modality:
            return None

        tracked = dict(self._items_by_modality)
        if {"image", "image_embeds"} <= tracked.keys():
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if {"audio", "audio_embeds"} <= tracked.keys():
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        mm_inputs = {}
        # Image and audio accept either embeddings or raw items (a list)
        for modality in ("image", "audio"):
            if f"{modality}_embeds" in tracked:
                mm_inputs[modality] = _get_embeds_data(tracked, modality)
            if modality in tracked:
                mm_inputs[modality] = tracked[modality]
        # Video only has raw items
        if "video" in tracked:
            mm_inputs["video"] = tracked["video"]

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create a synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)


800
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Asynchronous tracker whose items are awaitables for the media objects."""

    async def all_mm_data(self) -> MultiModalDataDict | None:
        """
        Await every tracked item and return the results keyed by raw
        modality, or `None` when no item has been tracked.

        Raises:
            ValueError: If raw and embedding inputs of the same modality
                were both added.
        """
        if not self._items_by_modality:
            return None

        # First collect the awaitables for all modalities. A missing item is
        # replaced by a no-op awaitable, so it resolves to `None`.
        pending_by_modality = {}
        for modality, items in self._items_by_modality.items():
            pending_by_modality[modality] = [
                item or asyncio.sleep(0) for item in items
            ]

        # Then resolve them, one modality at a time
        resolved: dict[str, list[object | None]] = {}
        for modality, pending in pending_by_modality.items():
            resolved[modality] = await asyncio.gather(*pending)

        if {"image", "image_embeds"} <= resolved.keys():
            raise ValueError("Mixing raw image and embedding inputs is not allowed")
        if {"audio", "audio_embeds"} <= resolved.keys():
            raise ValueError("Mixing raw audio and embedding inputs is not allowed")

        mm_inputs = {}
        # Image and audio accept either embeddings or raw items (a list)
        for modality in ("image", "audio"):
            if f"{modality}_embeds" in resolved:
                mm_inputs[modality] = _get_embeds_data(resolved, modality)
            if modality in resolved:
                mm_inputs[modality] = resolved[modality]
        # Video only has raw items
        if "video" in resolved:
            mm_inputs["video"] = resolved["video"]

        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create an asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """
    Abstract base for content parsers. Subclasses implement the `parse_*`
    hooks; this class stores the placeholders they register.
    """

    def __init__(self) -> None:
        super().__init__()

        # Maps each general MM placeholder to the list of model placeholders
        # registered for it:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Store `placeholder` under `modality`; empty placeholders are dropped."""
        # The lookup comes first so an unmapped modality raises KeyError even
        # when there is nothing to store.
        storage_key = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return

        self._placeholder_storage[storage_key].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a plain-dict shallow copy of the placeholder storage."""
        return {**self._placeholder_storage}

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Handle an image given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle image embeddings given as a string or a dict of strings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Handle an image given as a PIL image object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Handle audio given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Handle audio given as an `InputAudio` payload."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle audio embeddings given as a string or a dict of strings."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

896
897
898
899
900
901

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous content parser.

    Each `parse_*` method loads the referenced media through the media
    connector (when a payload is given), registers the item with the tracker
    and records the placeholder string returned by the tracker.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # `media_io_kwargs` is optional on the multimodal config (and the
        # config itself may be None), hence the defensive getattr.
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image behind `image_url` (if any) and register it.

        A falsy URL registers `None` as the item, leaving only `uuid` to
        identify it.
        """
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", image, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Load image embeddings and register them as an "image_embeds" item.

        Args:
            image_embeds: A single encoded embedding, a dict of encoded
                embeddings keyed by name, or `None` when only `uuid` is given.
            uuid: Optional user-supplied identifier of the item.

        Raises:
            ValueError: If `enable_mm_embeds` is not set on the multimodal
                config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # Mutually exclusive branches with a final fallback (same shape as
        # `parse_audio_embeds`), so `placeholder` is always bound.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", embeds, uuid)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", embedding, uuid)
        else:
            placeholder = self._tracker.add("image_embeds", None, uuid)

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Load audio embeddings and register them as an "audio_embeds" item.

        Args:
            audio_embeds: A single encoded embedding, a dict of encoded
                embeddings keyed by name, or `None` when only `uuid` is given.
            uuid: Optional user-supplied identifier of the item.

        Raises:
            ValueError: If `enable_mm_embeds` is not set on the multimodal
                config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", embeds, uuid)
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", embedding, uuid)
        else:
            placeholder = self._tracker.add("audio_embeds", None, uuid)

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register an already-loaded PIL image (or `None`) as an image item."""
        placeholder = self._tracker.add("image", image_pil, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio behind `audio_url` (if any) and register it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an `input_audio` payload to a data URL and parse it as audio.

        The payload's "data" and "format" entries are combined into a
        `data:audio/<format>;base64,<data>` URL. Missing or empty data yields
        no URL at all.
        """
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video behind `video_url` (if any) and register it."""
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
1007

1008
1009
1010
1011
1012
1013

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that registers awaitables instead of loaded media.

    URL-based parts hand the connector's `*_async` result to the tracker
    without awaiting it here. Embedding and PIL parts are wrapped in an
    already-resolved `asyncio.Future`.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # `media_io_kwargs` is optional on the multimodal config (and the
        # config itself may be None), hence the defensive getattr.
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register the pending fetch of `image_url` (or `None`) as an image."""
        image_coro = self._connector.fetch_image_async(image_url) if image_url else None

        placeholder = self._tracker.add("image", image_coro, uuid)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Load image embeddings and register them as an "image_embeds" item.

        The embeddings are loaded eagerly through the connector and handed to
        the tracker inside an already-resolved future.

        Raises:
            ValueError: If `enable_mm_embeds` is not set on the multimodal
                config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()

        # Exactly one branch runs, so the future is always resolved; an
        # unresolved future would make whoever awaits it wait forever.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result(embeds)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result(embedding)
        else:
            future.set_result(None)

        placeholder = self._tracker.add("image_embeds", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Load audio embeddings and register them as an "audio_embeds" item.

        The embeddings are loaded eagerly through the connector and handed to
        the tracker inside an already-resolved future. Progress is reported
        at DEBUG level only, since this runs once per content part.

        Raises:
            ValueError: If `enable_mm_embeds` is not set on the multimodal
                config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        logger.debug(
            "🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
            "is_str=%s, is_none=%s",
            type(audio_embeds).__name__,
            uuid,
            isinstance(audio_embeds, dict),
            isinstance(audio_embeds, str),
            audio_embeds is None,
        )

        future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()

        # Exactly one branch runs, so the future is always resolved; an
        # unresolved future would make whoever awaits it wait forever.
        if isinstance(audio_embeds, dict):
            logger.debug(
                "🎵 Processing dict audio_embeds with %d entries",
                len(audio_embeds),
            )
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            future.set_result(embeds)
            logger.debug(
                "🎵 Successfully loaded %d audio embeddings from dict",
                len(embeds),
            )
        elif isinstance(audio_embeds, str):
            base64_size = len(audio_embeds)
            logger.debug(
                "🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
                base64_size,
                base64_size / 1024,
            )
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            future.set_result(embedding)
            logger.debug(
                "🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
                embedding.shape,
                embedding.dtype,
            )
        else:
            if audio_embeds is None:
                logger.debug("🎵 Audio embeds is None (UUID-only reference)")
            future.set_result(None)

        placeholder = self._tracker.add("audio_embeds", future, uuid)
        self._add_placeholder("audio", placeholder)
        logger.debug("🎵 Added audio_embeds placeholder with uuid=%s", uuid)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register a PIL image (or `None`) wrapped in a resolved future."""
        future: asyncio.Future[Image.Image | None] = asyncio.Future()
        if image_pil:
            future.set_result(image_pil)
        else:
            future.set_result(None)

        placeholder = self._tracker.add("image", future, uuid)
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register the pending fetch of `audio_url` (or `None`) as audio."""
        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", audio_coro, uuid)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an `input_audio` payload to a data URL and parse it as audio.

        The payload's "data" and "format" entries are combined into a
        `data:audio/<format>;base64,<data>` URL. Missing or empty data yields
        no URL at all.
        """
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register the pending fetch of `video_url` (or `None`) as video."""
        video = (
            self._connector.fetch_video_async(video_url=video_url)
            if video_url
            else None
        )

        placeholder = self._tracker.add("video", video, uuid)
        self._add_placeholder("video", placeholder)
1167

1168

1169
def validate_chat_template(chat_template: Path | str | None):
1170
1171
1172
1173
1174
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1175
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1176
1177
1178

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1179
1180
1181
1182
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1183
1184
1185
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1186
            )
1187

1188
1189
1190
1191
1192
1193
1194
1195
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1196
    else:
1197
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1198
1199


1200
def _load_chat_template(
1201
    chat_template: Path | str | None,
1202
1203
    *,
    is_literal: bool = False,
1204
) -> str | None:
1205
1206
    if chat_template is None:
        return None
1207
1208
1209

    if is_literal:
        if isinstance(chat_template, Path):
1210
1211
1212
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1213

1214
        return chat_template
1215

1216
    try:
1217
        with open(chat_template) as f:
1218
            return f.read()
1219
    except OSError as e:
1220
1221
1222
        if isinstance(chat_template, Path):
            raise

1223
1224
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1225
1226
1227
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1228
            )
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1242

1243
1244
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1245
1246
1247
1248
1249
1250
1251
        return _load_chat_template(chat_template, is_literal=True)


# Memoized variant of `_load_chat_template`: results are cached per
# (chat_template, is_literal) argument pair with `lru_cache`'s default size.
# NOTE: a template file that changes on disk is not re-read while its entry
# is still cached.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None, *, is_literal: bool = False
) -> str | None:
    """Return the chat template text for `chat_template`, using the cache.

    Thin public wrapper that forwards both arguments to
    `_cached_load_chat_template`.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
1257
1258


1259
1260
1261
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1262
1263
1264
1265
1266
1267
1268
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1269
# TODO: Let user specify how to insert multimodal tokens into prompt
1270
# (similar to chat template)
1271
1272
1273
1274
1275
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1276
    """Combine multimodal prompts for a multimodal language model."""
1277

1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1295
    # Look through the text prompt to check for missing placeholders
1296
    missing_placeholders: list[str] = []
1297
1298
1299
1300
1301
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1302
1303
1304
1305
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1306
1307
                "when manually placing image placeholders.",
                interleave_strings,
1308
1309
            )
            logger.debug("Input prompt: %s", text_prompt)
1310
1311
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1312
1313
                "actual multimodal data items."
            )
1314

1315
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1316

1317
1318
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1319
    return "\n".join(missing_placeholders + [text_prompt])
1320
1321


1322
1323
# No need to validate using Pydantic again: these are `typing.cast` partials
# and return the part unchanged at runtime.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Define a mapping from part types to their corresponding parsing functions.
# Each function takes the raw part and returns its payload (or None).
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    # Textual parts
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "output_text": lambda part: _TextParser(part).get("text"),
    # Image parts
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url"),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    # Audio parts
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url"),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    # Video parts
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a known 'type' and no 'uuid' are parsed by the matching entry
    of `MM_PARSER_MAP`. Parts without a 'type', or with a 'uuid', are
    identified by the media field they carry ('image_url', 'image_pil',
    'image_embeds', 'audio_embeds', 'audio_url', 'input_audio', 'video_url',
    checked in that order).

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For a 'type' that
          is a string but not in `MM_PARSER_MAP`, a fixed marker string is
          returned as content.

    Raises:
        ValueError: If no media field is found for a part that has no 'type'
            or carries a 'uuid', or if 'type' is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
            logger.warning(
                "'image_url.detail' is currently not supported and will be ignored."
            )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        input_audio = part.get("input_audio")
        if input_audio is not None:
            # Return the inner payload (the dict holding "data"/"format"),
            # exactly like MM_PARSER_MAP["input_audio"] does. Returning the
            # enclosing part would hide the audio data from the consumer,
            # which reads "data" and "format" directly from this value.
            return "input_audio", cast(InputAudio, input_audio)
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1455
# Part types that are dropped (with a warning) when their parsed content is
# None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1459

1460

1461
1462
1463
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a single conversation entry.

    Each part is parsed with a parser created from `mm_tracker`; falsy parse
    results are dropped. With `wrap_dicts` the parsed parts are kept as a
    list; otherwise they are merged into one text prompt, including any
    multimodal placeholders that were recorded while parsing.
    """
    mm_parser = mm_tracker.create_parser()

    content: list[_ContentPart] = [
        parsed
        for part in parts
        if (
            parsed := _parse_chat_message_content_part(
                part,
                mm_parser,
                wrap_dicts=wrap_dicts,
                interleave_strings=interleave_strings,
            )
        )
    ]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(placeholder_storage, texts, interleave_strings)
        if placeholder_storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> _ContentPart | None:
    """Parse a single part of a conversation.

    Plain strings are returned unchanged. Text-like parts are returned as a
    string, or as `{"type": "text", "text": ...}` when `wrap_dicts` is set.
    Media parts are handed to `mm_parser`; the return value is then
    `{"type": <modality>}` when `wrap_dicts` is set, the generic modality
    marker when `interleave_strings` is set, and `None` otherwise.

    Raises:
        NotImplementedError: If the part type is not handled here.
    """
    if isinstance(part, str):  # Handle plain text parts
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # Parts of the types listed in PART_TYPES_TO_SKIP_NONE_CONTENT are
    # skipped with a warning when nothing could be parsed from them.
    if content is None and part_type in PART_TYPES_TO_SKIP_NONE_CONTENT:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
        text = cast(str, content)
        return {"type": "text", "text": text} if wrap_dicts else text

    # For media items, forward the user-provided uuid (as a string) if there
    # is one; otherwise pass None.
    raw_uuid = part.get("uuid", None)
    uuid = None if raw_uuid is None else str(raw_uuid)

    if part_type == "image_pil":
        pil_image = cast(Image.Image, content) if content is not None else None
        mm_parser.parse_image_pil(pil_image, uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content), uuid)
        modality = "image"
    elif part_type == "image_embeds":
        image_embeds = (
            cast(str | dict[str, str], content) if content is not None else None
        )
        mm_parser.parse_image_embeds(image_embeds, uuid)
        modality = "image"
    elif part_type == "audio_embeds":
        audio_embeds = (
            cast(str | dict[str, str], content) if content is not None else None
        )
        mm_parser.parse_audio_embeds(audio_embeds, uuid)
        modality = "audio"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content), uuid)
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content), uuid)
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content), uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1577
1578


1579
1580
1581
1582
1583
# No need to validate using Pydantic again: these are `typing.cast` partials,
# so they only narrow the static type and return the message unchanged at
# runtime.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1584
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation entries.

    The message content (absent, a plain string, or a list of parts) is parsed
    into conversation entries, which are then enriched with role-specific
    fields copied from `message`: tool calls and reasoning for "assistant",
    the tool call id for "tool", tools for "developer", and the name for any
    role when it is a string.
    """
    role = message["role"]
    reasoning = message.get("reasoning") or message.get("reasoning_content")

    # Normalise the content to a list of parts.
    raw_content = message.get("content")
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    name = message["name"] if "name" in message else None

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The explicit None check keeps compatibility with callers that
            # send "tool_calls": null instead of omitting the key.
            if assistant_msg.get("tool_calls") is not None:
                result_msg["tool_calls"] = list(assistant_msg["tool_calls"])

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                # keep compatibility
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1633

1634
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1635
1636
1637
1638
1639
1640
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
        if message["role"] == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if not isinstance(tool_calls, list):
                continue

            if len(tool_calls) == 0:
                # Drop empty tool_calls to keep templates on the normal assistant path.
                message.pop("tool_calls", None)
                continue

            for item in tool_calls:
1652
1653
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
1654
1655
                    if not isinstance(content, (dict, list)):
                        item["function"]["arguments"] = json.loads(content)
1656
1657
                else:
                    item["function"]["arguments"] = {}
1658
1659


1660
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse request messages into a flat conversation plus multimodal data.

    Each input message is expanded by ``_parse_chat_message_content`` into
    zero or more conversation messages, sharing one ``MultiModalItemTracker``
    across the whole request. The combined conversation is then normalized
    in place by ``_postprocess_messages``.

    Args:
        messages: Chat messages from the request.
        model_config: Model configuration; used to build the multimodal
            tracker and to read ``multimodal_config.interleave_mm_strings``.
        content_format: Content format expected by the chat template.

    Returns:
        A tuple of the parsed conversation, ``mm_tracker.all_mm_data()`` and
        ``mm_tracker.all_mm_uuids()``.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = MultiModalItemTracker(model_config)

    # Depends only on the call arguments, so evaluate it once instead of
    # once per message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )
        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1689
1690


1691
def parse_chat_messages_futures(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: _ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    Awaitable[MultiModalDataDict | None],
    MultiModalUUIDDict | None,
]:
    """Parse request messages, returning the multimodal data as an awaitable.

    Same flow as the synchronous parser but built on
    ``AsyncMultiModalItemTracker``: each input message is expanded by
    ``_parse_chat_message_content`` and the combined conversation is
    normalized in place by ``_postprocess_messages``.

    Args:
        messages: Chat messages from the request.
        model_config: Model configuration; used to build the multimodal
            tracker and to read ``multimodal_config.interleave_mm_strings``.
        content_format: Content format expected by the chat template.

    Returns:
        A tuple of the parsed conversation, the awaitable returned by
        ``mm_tracker.all_mm_data()`` and ``mm_tracker.all_mm_uuids()``.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # Depends only on the call arguments, so evaluate it once instead of
    # once per message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
        )
        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
1720
1721


1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja2 extension that lets templates containing a ``generation`` block parse.

    Registers the ``generation`` tag and parses everything up to the matching
    ``endgeneration`` tag into a ``CallBlock`` node.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse one ``generation`` ... ``endgeneration`` block into a node."""
        # Consume the opening tag token first; the resulting node is stamped
        # with that token's line number.
        opening_token = next(parser.stream)
        block_body = parser.parse_statements(
            ["name:endgeneration"],
            drop_needle=True,
        )
        node = jinja2.nodes.CallBlock(
            self.call_method("_generation_support"),
            [],
            [],
            block_body,
        )
        return node.set_lineno(opening_token.lineno)


1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the undeclared variable names referenced by ``chat_template``.

    The template source is parsed (not rendered) in an immutable sandboxed
    Jinja2 environment with the ``AssistantTracker`` and ``loopcontrols``
    extensions enabled, and the parsed tree is handed to
    ``jinja2.meta.find_undeclared_variables``.
    """
    sandbox_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
        lstrip_blocks=True,
        trim_blocks=True,
    )
    return jinja2.meta.find_undeclared_variables(sandbox_env.parse(chat_template))


# Memoized variant, keyed on the template source string.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)


1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    """Return the named parameters of HF's base ``apply_chat_template``.

    The names are read dynamically from the signature of
    ``PreTrainedTokenizer.apply_chat_template`` - the single source of
    truth - so that tokenizers which receive the standard parameters through
    ``**kwargs`` stay compatible. The ``*args`` / ``**kwargs`` placeholders
    themselves are excluded. The result is cached after the first call.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    )
    signature = inspect.signature(PreTrainedTokenizer.apply_chat_template)
    named_params = [
        param.name
        for param in signature.parameters.values()
        if param.kind not in variadic_kinds
    ]
    return frozenset(named_params)


1769
def resolve_chat_template_kwargs(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    """Filter request-supplied chat template kwargs down to the accepted ones.

    A key is accepted when it is an explicit keyword of
    ``tokenizer.apply_chat_template``, an undeclared variable of the chat
    template, or a named parameter of HF's base ``apply_chat_template``.
    ``chat_template`` and ``tokenize`` are never accepted.

    Args:
        tokenizer: Tokenizer whose ``apply_chat_template`` signature is
            inspected.
        chat_template: The already-resolved chat template source.
        chat_template_kwargs: Candidate keyword arguments.
        raise_on_unexpected: When true, reject ``chat_template`` / ``tokenize``
            keys with an error instead of silently dropping them.

    Returns:
        A new dict holding only the accepted entries of
        ``chat_template_kwargs``.

    Raises:
        ValueError: If ``raise_on_unexpected`` is true and
            ``chat_template_kwargs`` contains ``chat_template`` or
            ``tokenize``.
    """
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template", "tokenize"}
    if raise_on_unexpected and (
        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
    ):
        raise ValueError(
            "Found unexpected chat template kwargs from request: "
            f"{unexpected_in_kwargs}"
        )

    # Keys the tokenizer declares explicitly (a bare **kwargs does not count).
    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
    hf_base_params = _get_hf_base_chat_template_params()

    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
1798
1799


1800
def apply_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    conversation: list[ConversationMessage],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
    **kwargs: Any,
) -> str:
    """Render ``conversation`` to text with the tokenizer's chat template.

    The template is resolved with ``resolve_hf_chat_template``, the extra
    ``kwargs`` are filtered through ``resolve_chat_template_kwargs``, and the
    tokenizer's ``apply_chat_template`` is called with ``tokenize=False``.

    Args:
        tokenizer: HF tokenizer providing ``apply_chat_template``.
        conversation: Parsed conversation messages.
        chat_template: Explicit chat template, or ``None`` to let
            ``resolve_hf_chat_template`` choose one.
        tools: Tool definitions forwarded to the template, if any.
        model_config: Model configuration forwarded to template resolution.
        **kwargs: Candidate chat template keyword arguments.

    Returns:
        The value returned by ``tokenizer.apply_chat_template``.

    Raises:
        ValueError: If no chat template can be resolved, if ``kwargs``
            contains ``chat_template`` or ``tokenize``, or if applying the
            template raises any exception (the original exception is
            chained as the cause).
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if hf_chat_template is None:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    resolved_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=hf_chat_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=False,
            **resolved_kwargs,
        )

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
1847

1848

1849
def apply_mistral_chat_template(
    tokenizer: "MistralTokenizer",
    messages: list[ChatCompletionMessageParam],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    **kwargs: Any,
) -> list[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Args:
        tokenizer: Mistral tokenizer providing ``apply_chat_template``.
        messages: Chat messages from the request.
        chat_template: Passed to ``resolve_mistral_chat_template`` only; it
            is not forwarded to the tokenizer.
        tools: Tool definitions forwarded to the tokenizer, if any.
        **kwargs: Forwarded both to ``resolve_mistral_chat_template`` and to
            ``tokenizer.apply_chat_template``.

    Returns:
        The value returned by ``tokenizer.apply_chat_template``.

    Raises:
        ValueError: If applying the template raises any exception,
            including ``AssertionError`` and ``MistralCommonException``
            (the original exception is chained as the cause).
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
1887

1888

1889
1890
1891
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]) -> int:
    """Count the tool calls made by assistant messages in ``conversation``.

    Only messages whose role is ``"assistant"`` are considered; a missing or
    ``None`` ``tool_calls`` entry contributes zero.

    Args:
        conversation: Parsed conversation messages.

    Returns:
        The total number of tool calls across all assistant messages.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            # tool_calls may be any iterable, so materialize it before counting.
            total += len(list(tool_calls))
    return total


1898
1899
1900
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
1901
1902
1903
    else:
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"