chat_utils.py 57.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from itertools import accumulate
12
from pathlib import Path
13
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
14

15
from openai.types.chat import (
16
17
18
19
20
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
21
    ChatCompletionFunctionToolParam,
22
23
24
25
26
27
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
28
from openai.types.chat import (
29
30
31
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
32
from openai.types.responses import ResponseInputImageParam
33
from openai_harmony import Message as OpenAIHarmonyMessage
34
35
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
36

37
# pydantic needs the TypedDict from typing_extensions
38
from typing_extensions import Required, TypedDict
39

40
from vllm import envs
41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
49
50
51
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
52
53
)
from vllm.multimodal.processing import BaseMultiModalProcessor
54
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
55
from vllm.utils import random_uuid
56
57
58
59
60
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
61
    import transformers
62
else:
63
    transformers = LazyLoader("transformers", globals(), "transformers")
64
    torch = LazyLoader("torch", globals(), "torch")
65
66
67

logger = init_logger(__name__)

68

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def __getattr__(name: str):
    """Module-level attribute hook for names that moved out of this module.

    Only ``resolve_hf_chat_template`` is handled: it is forwarded to
    ``vllm.renderers.hf.resolve_chat_template`` and a DeprecationWarning is
    emitted. Every other name raises AttributeError.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template


86
87
88
89
90
91
92
93
class ChatTemplateResolutionError(ValueError):
    """Signals that no usable chat template could be resolved.

    It derives from ValueError so that existing handlers which already
    catch ValueError continue to work unchanged.
    """


94
95
96
97
98
99
# Generic, model-agnostic placeholder token per raw media modality. The
# content parsers use these as keys when collecting model-specific
# placeholder strings.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an ``audio_url`` object."""

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The content part type; must be "audio_url"."""


115
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
116
    image_embeds: str | dict[str, str] | None
117
118
119
120
121
122
123
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
124
    uuid: str | None
125
126
127
128
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
129
130


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


147
148
149
150
151
152
153
154
155
156
157
158
159
160
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a ``video_url`` object."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The content part type; must be "video_url"."""


161
162
163
164
class PILImage(BaseModel):
    """
    Pydantic model holding a single PIL.Image.Image object.
    """

    # PIL images are not pydantic-native types, so arbitrary types must be
    # allowed for the field below to validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simplified content part that accepts only a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    Optional caller-supplied UUID of the media. The caller is responsible for
    generating it correctly and keeping it unique across different media.
    """
185
186


187
188
189
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
190

191
192
193
194
195
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
196

197
198
    image_url: str | None
    uuid: str | None
199
200
201
202
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
203
204
205
206


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
207

208
209
210
211
212
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
213

214
    audio_url: str | None
215
216


217
218
219
220
221
222
223
224
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
225

226
227
    video_url: str | None
    uuid: str | None
228
229
230
231
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
232
233


Julien Denize's avatar
Julien Denize committed
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


255
256
257
258
259
260
261
262
263
# Every content-part shape accepted inside a chat message.
# NOTE(review): member order is kept exactly as before, since union
# validation (e.g. by pydantic) may be order-sensitive.
ChatCompletionContentPartParam: TypeAlias = (
    # OpenAI-defined parts
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    # vLLM-specific shortcuts and embedding inputs
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    # Bare text and thinking parts
    | str
    | CustomThinkCompletionContentParam
)
270
271
272
273


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat message that permits custom (non-OpenAI) roles."""

    role: Required[str]
    """Who authored the message."""

    content: str | list[ChatCompletionContentPartParam]
    """The message body: plain text or a list of content parts."""

    name: str
    """An optional name for the participant.

    Lets the model tell apart participants that share the same role.
    """

    tool_call_id: str | None
    """ID of the tool call this message responds to."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (such as function calls) generated by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions attached to a developer-role message."""

300

301
302
303
304
305
# Union of the supported chat message types: OpenAI messages, messages with
# custom roles, and openai-harmony messages (order preserved).
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
306
307


308
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """A single message entry of a chat conversation."""

    role: Required[str]
    """Who authored the message."""

    content: str | None | list[dict[str, str]]
    """The message body."""

    tool_call_id: str | None
    """ID of the tool call this message responds to."""

    name: str | None
    """The name of the function to call."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (such as function calls) generated by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions attached to a developer-role message."""

334

335
336
337
# Content format as requested by the user ("auto" is resolved later).
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Content format once "auto" has been resolved to a concrete choice.
ChatTemplateContentFormat = Literal["string", "openai"]
340

341

Roger Wang's avatar
Roger Wang committed
342
343
344
# Modality names a multi-modal item can be tracked under.
ModalityStr = Literal[
    "image",
    "audio",
    "video",
    "image_embeds",
    "audio_embeds",
    "vision_chunk",
]

_T = TypeVar("_T")


348
349
350
class _BatchedSingleItemField(MultiModalSharedField):
    """Marker field for a single item that was passed in batched form.

    Exists for backward compatibility with single-item inputs. It behaves
    exactly like MultiModalSharedField; only the type differs.
    """
351
352


353
354
355
356
357
358
def _detect_field(
    # Quoted: evaluating torch.Tensor here would force the LazyLoader to
    # import torch as soon as this module is imported.
    tensors: list["torch.Tensor"],
    mm_processor: BaseMultiModalProcessor,
):
    """Guess how a list of per-item embedding tensors should be merged.

    Args:
        tensors: One tensor per multi-modal item. The list must be non-empty;
            the caller (``_merge_embeds``) guarantees this.
        mm_processor: Processor whose model config supplies the hidden size.

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` for the deprecated case of
          a single 3-D tensor shaped ``(1, ..., hidden_size)``;
        - ``MultiModalBatchedField()`` when all tensors share one shape;
        - otherwise a ``MultiModalFlatField`` with one slice per item. Each
          slice spans ``len(tensor)`` consecutive positions, in item order.
    """
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Items have different leading sizes: record where each one starts and
    # ends within the flattened data.
    size_per_item = [len(tensor) for tensor in tensors]
    slice_idxs = [0, *accumulate(size_per_item)]
    slices = [
        (slice(slice_idxs[i], slice_idxs[i + 1]),) for i in range(len(size_per_item))
    ]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dicts into one dict of combined values.

    The merge happens in two steps:

    1. Each key's tensors are merged using the field layout guessed by
       ``_detect_field``.
    2. The processor's own field config is consulted. For keys where it
       disagrees (except the deprecated batched-single-item case), the
       tensors are re-merged with the processor's field.

    If step 2 fails for any key, the step-1 result is returned unchanged for
    all keys.

    Args:
        data_items: One dict per multi-modal item; all dicts must share keys.
        mm_processor: Processor used for field detection and field config.

    Returns:
        A dict mapping each key to its merged data. An empty input yields ``{}``.

    Raises:
        ValueError: if the dicts do not all have the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    fields = {
        key: _detect_field([item[key] for item in data_items], mm_processor)
        for key in first_keys
    }
    data_merged = {
        key: field._reduce_data([item[key] for item in data_items], pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        # Build the re-merged values separately and apply them only after
        # every key succeeded. A failure midway then cannot leave a mix of
        # processor-defined and auto-detected layouts behind.
        overrides = {
            key: parsed_fields[key]._reduce_data(
                [item[key] for item in data_items], pin_memory=False
            )
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        }
    except Exception:
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )
    else:
        data_merged.update(overrides)

    return data_merged
433

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """Combine the embedding inputs of one modality into a single value.

    Args:
        modality: Base modality name (e.g. "image"). Bare tensor inputs are
            merged under the ``f"{modality}_embeds"`` key.
        data_items: One entry per item. Expected to be all tensors, all dicts
            of tensors, or all None.
        mm_processor: Passed through to ``_merge_embeds``.

    Returns:
        ``data_items`` itself if it is empty or holds only None; the merged
        tensor data for a list of tensors; the merged dict for a list of dicts.

    Raises:
        NotImplementedError: if the items match none of the supported forms.
    """
    # all() is also True for an empty list, so this covers both early exits.
    if all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    # Report the item types: the type of the container is always `list` and
    # says nothing about what was actually wrong.
    item_types = sorted({type(item).__name__ for item in data_items})
    raise NotImplementedError(
        f"Unsupported {modality} embedding inputs: expected a list of tensors "
        f"or a list of dicts, got item types {item_types}"
    )
455
456


Roger Wang's avatar
Roger Wang committed
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
def rebuild_mm_uuids_from_mm_data(
    mm_uuids: MultiModalUUIDDict,
    mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
    """Rebuild mm_uuids after vision_chunk processing.

    When videos are split into chunks, the per-item UUIDs must be replaced
    by one UUID per produced chunk.

    Args:
        mm_uuids: Original UUIDs dictionary.
        mm_data: Processed multimodal data with vision_chunk items.

    Returns:
        ``mm_uuids`` unchanged if ``mm_data`` has no "vision_chunk" entry.
        Otherwise a shallow copy of ``mm_uuids``. If any chunk has a UUID, its
        "vision_chunk" list is replaced by one entry per chunk, with None for
        chunks without a UUID. The list stays index-aligned with
        ``mm_data["vision_chunk"]``.
    """
    vision_chunks = mm_data.get("vision_chunk")
    if vision_chunks is None:
        return mm_uuids

    # Keep a slot for every chunk. Dropping UUID-less chunks would shift the
    # UUIDs of later chunks onto the wrong items. Non-dict entries appear
    # when _resolve_items falls back to raw media data; they carry no UUID.
    vision_chunk_uuids = [
        item.get("uuid") if isinstance(item, dict) else None
        for item in vision_chunks
    ]

    new_uuids = dict(mm_uuids)
    if any(uuid_val is not None for uuid_val in vision_chunk_uuids):
        new_uuids["vision_chunk"] = vision_chunk_uuids

    return new_uuids


def build_video_prompts_from_mm_data(
    mm_data: MultiModalDataDict,
) -> list[str]:
    """Build video prompts from vision_chunk data.

    Collects the prompts of all "video_chunk" items and concatenates them
    per ``video_idx``. Non-dict entries, which _resolve_items appends when it
    falls back to raw media data, are not video chunks and are skipped.

    Args:
        mm_data: Processed multimodal data with vision_chunk items.

    Returns:
        One prompt per video, ordered by ascending ``video_idx``. Returns an
        empty list if there is no "vision_chunk" entry.
    """
    vision_chunks = mm_data.get("vision_chunk")
    if vision_chunks is None:
        return []

    # Group chunk prompts by the video they came from, preserving chunk order.
    video_prompts_dict: dict[int, list[str]] = defaultdict(list)
    for item in vision_chunks:
        if not isinstance(item, dict) or item.get("type") != "video_chunk":
            continue
        video_prompts_dict[item.get("video_idx", 0)].append(item.get("prompt", ""))

    return ["".join(video_prompts_dict[idx]) for idx in sorted(video_prompts_dict)]


529
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Tracked items, keyed by the modality they are stored under.
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # For models using the unified vision_chunk modality: the original
        # modality ("image" or "video") of each vision_chunk item, in the
        # order the items were added.
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Check if model uses unified vision_chunk modality for images/videos."""
        return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Add a multi-modal item to the current prompt and return the
        model-specific placeholder string to use, if any.

        Args:
            modality: The modality the item was supplied as. If the model
                uses the unified vision_chunk modality, "image" and "video"
                items are validated and stored under "vision_chunk" instead.
            item: The item to track. Any per-item UUID travels inside
                ``item`` itself; this method does not inspect it.

        Returns:
            The placeholder string reported by the model class for this item.

        Raises:
            Propagates any exception from ``mm_processor.validate_num_items``
            (presumably raised when the per-prompt limit is exceeded). The
            item is not recorded in that case.
        """
        use_vision_chunk = self.use_unified_vision_chunk_modality and modality in (
            "video",
            "image",
        )

        if use_vision_chunk:
            # Models with use_unified_vision_chunk only accept the
            # vision_chunk modality, so validate and store under that key.
            storage_key = "vision_chunk"
            input_modality = "vision_chunk"
        else:
            storage_key = modality
            input_modality = modality.replace("_embeds", "")

        # Use .get() instead of indexing the defaultdict so that a failed
        # validation does not leave an empty list behind for this key.
        num_items = len(self._items_by_modality.get(storage_key, ())) + 1
        self.mm_processor.validate_num_items(input_modality, num_items)

        self._items_by_modality[storage_key].append(item)
        if use_vision_chunk:
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


619
620
621
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
Roger Wang's avatar
Roger Wang committed
622
    vision_chunk_modality_order: dict[str, list[str]],
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}
    if "image_embeds" in items_by_modality:
        mm_data["image"] = _get_embeds_data(
            "image",
            [data for data, uuid in items_by_modality["image_embeds"]],
            mm_processor,
        )
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
    if "image" in items_by_modality:
        mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
    if "audio_embeds" in items_by_modality:
        mm_data["audio"] = _get_embeds_data(
            "audio",
            [data for data, uuid in items_by_modality["audio_embeds"]],
            mm_processor,
        )
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
    if "audio" in items_by_modality:
        mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
    if "video" in items_by_modality:
        mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
        mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
Roger Wang's avatar
Roger Wang committed
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
    if "vision_chunk" in items_by_modality:
        # Process vision_chunk items - extract from (data, modality) tuples
        # and convert to VisionChunk types with proper UUID handling
        vision_chunk_items = items_by_modality["vision_chunk"]
        modality_order = vision_chunk_modality_order.get("vision_chunk", [])
        mm_uuids["vision_chunk"] = [
            uuid for data, uuid in items_by_modality["vision_chunk"]
        ]

        # Filter out None items (from asyncio.sleep(0) placeholders)
        filtered_items = [
            (idx, item)
            for idx, item in enumerate(vision_chunk_items)
            if item is not None
        ]

        assert len(filtered_items) == len(modality_order), (
            f"vision_chunk items ({len(filtered_items)}) and "
            f"modality_order ({len(modality_order)}) must have same length"
        )

        processed_chunks: list[VisionChunk] = []
        video_idx = 0
        for i, (idx, item) in enumerate(filtered_items):
            inner_modality = modality_order[i]
            data, uuid = item
            uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None
            if inner_modality == "image":
                # Cast data to proper type for image
                # Use .media (PIL.Image) directly to avoid redundant
                # bytes→PIL conversion in media_processor
                if hasattr(data, "media"):
                    image_data = data.media  # type: ignore[union-attr]
                    processed_chunks.append(
                        VisionChunkImage(type="image", image=image_data, uuid=uuid_val)
                    )
                else:
                    processed_chunks.append(data)  # type: ignore[arg-type]
            elif inner_modality == "video":
                # For video, we may need to split into chunks
                # if processor supports it
                # For now, just wrap as a video chunk placeholder
                if hasattr(mm_processor, "split_video_chunks") and data is not None:
                    try:
                        video_uuid = uuid_val or random_uuid()
                        # video await result is (video_data, video_meta) tuple
                        if isinstance(data, tuple) and len(data) >= 1:
                            video_data = data[0]
                        else:
                            video_data = data
                        video_chunks = mm_processor.split_video_chunks(video_data)
                        for i, vc in enumerate(video_chunks):
                            processed_chunks.append(
                                VisionChunkVideo(
                                    type="video_chunk",
                                    video_chunk=vc["video_chunk"],
                                    uuid=f"{video_uuid}-{i}",
                                    video_idx=video_idx,
                                    prompt=vc["prompt"],
                                )
                            )
                        video_idx += 1
                    except Exception as e:
                        logger.warning("Failed to split video chunks: %s", e)
                        processed_chunks.append(data)  # type: ignore[arg-type]
                else:
                    processed_chunks.append(data)  # type: ignore[arg-type]
        mm_data["vision_chunk"] = processed_chunks
722
723
724
725
726
727
728
729

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Synchronous tracker whose items are already-loaded ``(data, uuid)`` pairs."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return ``(mm_data, mm_uuids)``, or ``(None, None)`` if nothing was added."""
        tracked = dict(self._items_by_modality)
        if not tracked:
            return None, None
        return _resolve_items(tracked, self.mm_processor, self._modality_order)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create a synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)


741
742
743
744
745
746
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)`` pairs."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await all tracked items and build ``(mm_data, mm_uuids)``.

        Returns ``(None, None)`` if no items were added.
        """
        if not self._items_by_modality:
            return None, None

        # Await every modality at the same time instead of one modality after
        # another, so e.g. image and audio fetches overlap. Zipping the results
        # back onto the keys keeps both modality order and per-modality item
        # order.
        modalities = list(self._items_by_modality)
        results = await asyncio.gather(
            *(
                asyncio.gather(*self._items_by_modality[modality])
                for modality in modalities
            )
        )
        resolved_items_by_modality = dict(zip(modalities, results))

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create an asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn multi-modal content parts into tracked items.

    Subclasses implement one ``parse_*`` method per supported content type and
    record the resulting placeholder strings via ``_add_placeholder``.
    """

    def __init__(self) -> None:
        super().__init__()

        # Model-specific placeholders grouped under the generic placeholder of
        # their modality, for example:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record ``placeholder`` under the generic key for ``modality``.

        The key is looked up first, so an unknown modality raises KeyError.
        Falsy placeholders (None or "") are then ignored.
        """
        generic_key = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_key].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return the collected placeholders as a new plain dict (shallow copy)."""
        return {**self._placeholder_storage}

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Handle an image given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle image embeddings (a string or a dict of strings)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Handle an in-memory PIL image."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Handle audio given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Handle an OpenAI ``input_audio`` payload."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle audio embeddings (a string or a dict of strings)."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

823
824
825
826
827
828

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous content parser.

    Media referenced by URL is fetched immediately through the configured media
    connector, and each resulting ``(data, uuid)`` item is added to the tracker.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings and register them with the tracker.

        Raises:
            ValueError: If embeddings input is disabled in the multimodal
                config, or if ``image_embeds`` is not a str, dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        elif image_embeds is None:
            placeholder = self._tracker.add("image_embeds", (None, uuid))
        else:
            # Parts reach here via `cast`, not re-validation, so reject other
            # types explicitly instead of leaving `placeholder` unbound.
            raise ValueError(
                "Expected `image_embeds` to be a string, a dict of strings, "
                f"or None, but got {type(image_embeds).__name__}"
            )

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings and register them with the tracker.

        Raises:
            ValueError: If embeddings input is disabled in the multimodal
                config, or if ``audio_embeds`` is not a str, dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        elif audio_embeds is None:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))
        else:
            # Reject other types rather than silently registering no data.
            raise ValueError(
                "Expected `audio_embeds` to be a string, a dict of strings, "
                f"or None, but got {type(audio_embeds).__name__}"
            )

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an inline base64 ``input_audio`` payload into a data URL."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # Empty data is allowed when a UUID identifies the audio instead.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)
934

935
936
937
938
939
940

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Asynchronous content parser.

    URL-based media is registered on the tracker as coroutines that fetch it
    when awaited. Data that is already available (PIL images, decoded
    embeddings) is wrapped in already-completed futures.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings and register them as a completed future.

        Raises:
            ValueError: If embeddings input is disabled in the multimodal
                config, or if ``image_embeds`` is not a str, dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # Resolve the payload first so every path either completes the future
        # or raises; a future left pending would block whoever awaits it.
        result: tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            result = (embeds, uuid)
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            result = (embedding, uuid)
        elif image_embeds is None:
            result = (None, uuid)
        else:
            raise ValueError(
                "Expected `image_embeds` to be a string, a dict of strings, "
                f"or None, but got {type(image_embeds).__name__}"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()
        future.set_result(result)

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings and register them as a completed future.

        Raises:
            ValueError: If embeddings input is disabled in the multimodal
                config, or if ``audio_embeds`` is not a str, dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        # Resolve the payload first so every path either completes the future
        # or raises; a future left pending would block whoever awaits it.
        result: tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            result = (embeds, uuid)
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            result = (embedding, uuid)
        elif audio_embeds is None:
            result = (None, uuid)
        else:
            raise ValueError(
                "Expected `audio_embeds` to be a string, a dict of strings, "
                f"or None, but got {type(audio_embeds).__name__}"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()
        future.set_result(result)

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        # Falsy inputs (e.g. None when only a UUID is given) resolve to None.
        future.set_result((image_pil if image_pil else None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an inline base64 ``input_audio`` payload into a data URL."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # Empty data is allowed when a UUID identifies the audio instead.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)
1083

1084

1085
def validate_chat_template(chat_template: Path | str | None):
1086
1087
1088
1089
1090
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1091
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1092
1093
1094

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1095
1096
1097
1098
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1099
1100
1101
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1102
            )
1103

1104
1105
1106
1107
1108
1109
1110
1111
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1112
    else:
1113
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1114
1115


1116
def _load_chat_template(
1117
    chat_template: Path | str | None,
1118
1119
    *,
    is_literal: bool = False,
1120
) -> str | None:
1121
1122
    if chat_template is None:
        return None
1123
1124
1125

    if is_literal:
        if isinstance(chat_template, Path):
1126
1127
1128
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1129

1130
        return chat_template
1131

1132
    try:
1133
        with open(chat_template) as f:
1134
            return f.read()
1135
    except OSError as e:
1136
1137
1138
        if isinstance(chat_template, Path):
            raise

1139
1140
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1141
1142
1143
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1144
            )
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1158

1159
1160
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1161
1162
1163
1164
1165
1166
1167
        return _load_chat_template(chat_template, is_literal=True)


# Memoized front end for `_load_chat_template`: results are cached per exact
# argument combination (128 entries, the `lru_cache` default, made explicit).
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Return the chat template text for *chat_template*, memoized.

    Results are cached per ``(chat_template, is_literal)``, so a template file
    edited after its first load is not re-read within this process.
    """
    template_text = _cached_load_chat_template(chat_template, is_literal=is_literal)
    return template_text
1173
1174


1175
1176
1177
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1178
1179
1180
1181
1182
1183
1184
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1185
# TODO: Let user specify how to insert multimodal tokens into prompt
1186
# (similar to chat template)
1187
1188
1189
1190
1191
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1192
    """Combine multimodal prompts for a multimodal language model."""
1193

1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1211
    # Look through the text prompt to check for missing placeholders
1212
    missing_placeholders: list[str] = []
1213
1214
1215
1216
1217
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1218
1219
1220
1221
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1222
1223
                "when manually placing image placeholders.",
                interleave_strings,
1224
1225
            )
            logger.debug("Input prompt: %s", text_prompt)
1226
1227
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1228
1229
                "actual multimodal data items."
            )
1230

1231
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1232

1233
1234
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1235
    return "\n".join(missing_placeholders + [text_prompt])
1236
1237


1238
1239
# These parts were validated once already; `cast` is a runtime no-op that only
# narrows the static type, so no second Pydantic pass happens here.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts carrying url objects are run through Pydantic so malformed ones fail.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python

_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Content-part "type" -> function extracting that part's payload (None when
# the expected field is absent).
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text", None),
    "thinking": lambda part: _ThinkParser(part).get("thinking", None),
    "input_text": lambda part: _TextParser(part).get("text", None),
    "output_text": lambda part: _TextParser(part).get("text", None),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
    "refusal": lambda part: _RefusalParser(part).get("refusal", None),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type'
            field and an optional 'uuid'.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For media parts in
          the fallback path this may be None when only a 'uuid' is given; for
          'input_audio' it is the inner {"data", "format"} payload.

    Raises:
        ValueError: If the 'type' field is missing and no direct media field
            is found, or if 'type' is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'. We only support 'auto', which
        # is the default. Per the OpenAI schema 'detail' lives inside the
        # 'image_url' object; a top-level 'detail' is still checked too.
        if part_type == "image_url":
            image_url_obj = part.get("image_url")
            nested_detail = (
                image_url_obj.get("detail", "auto")
                if isinstance(image_url_obj, dict)
                else "auto"
            )
            if nested_detail != "auto" or part.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        if "input_audio" in part:
            # "input_audio" could be None if UUID is provided. Return the inner
            # payload (like MM_PARSER_MAP["input_audio"]), not the whole part:
            # the consumer reads "data"/"format" directly from this value.
            input_audio_params = cast(dict[str, Any], part)
            return "input_audio", input_audio_params.get("input_audio", None)
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1371
# Part types that are dropped with a warning, instead of forwarded, when their
# parsed content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1375

1376

1377
1378
1379
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse the content parts of one message into a conversation message.

    Multimodal parts are registered on *mm_tracker* through a fresh parser.
    With *wrap_dicts* the content stays a list of typed dicts; otherwise the
    text pieces are joined into one string, with model placeholders added
    for any multimodal items that were recorded.
    """
    mm_parser = mm_tracker.create_parser()

    # Parse every part in order, keeping only truthy results.
    content: list[_ContentPart] = [
        parsed
        for part in parts
        if (
            parsed := _parse_chat_message_content_part(
                part,
                mm_parser,
                wrap_dicts=wrap_dicts,
                interleave_strings=interleave_strings,
            )
        )
    ]

    if wrap_dicts:
        # Texts and media stay as interleaved dictionaries.
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    placeholder_storage = mm_parser.mm_placeholder_storage()
    if not placeholder_storage:
        text_prompt = "\n".join(texts)
    else:
        text_prompt = _get_full_multimodal_text_prompt(
            placeholder_storage, texts, interleave_strings
        )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1415
1416
1417
1418
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1419
    interleave_strings: bool,
1420
) -> _ContentPart | None:
1421
1422
1423
1424
1425
1426
1427
1428
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1429
        return part
1430
1431
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1432
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1433
    # content is None, log a warning and skip
1434
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1435
        logger.warning(
1436
            "Skipping multimodal part '%s' (type: '%s') "
1437
1438
1439
1440
            "with empty / unparsable content.",
            part,
            part_type,
        )
1441
1442
        return None

1443
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1444
1445
        str_content = cast(str, content)
        if wrap_dicts:
1446
            return {"type": "text", "text": str_content}
1447
1448
        else:
            return str_content
1449

1450
1451
1452
1453
1454
1455
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1456
    modality = None
1457
    if part_type == "image_pil":
1458
        image_content = cast(Image.Image, content) if content is not None else None
1459
        mm_parser.parse_image_pil(image_content, uuid)
1460
        modality = "image"
1461
    elif part_type in ("image_url", "input_image"):
1462
        str_content = cast(str, content)
1463
        mm_parser.parse_image(str_content, uuid)
1464
1465
        modality = "image"
    elif part_type == "image_embeds":
1466
        content = cast(str | dict[str, str], content) if content is not None else None
1467
        mm_parser.parse_image_embeds(content, uuid)
1468
        modality = "image"
1469
1470
1471
1472
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1473
    elif part_type == "audio_url":
1474
        str_content = cast(str, content)
1475
        mm_parser.parse_audio(str_content, uuid)
1476
1477
        modality = "audio"
    elif part_type == "input_audio":
1478
        dict_content = cast(InputAudio, content)
1479
        mm_parser.parse_input_audio(dict_content, uuid)
1480
1481
        modality = "audio"
    elif part_type == "video_url":
1482
        str_content = cast(str, content)
1483
        mm_parser.parse_video(str_content, uuid)
1484
1485
1486
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1487

1488
1489
1490
    return (
        {"type": modality}
        if wrap_dicts
1491
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1492
    )
1493
1494


1495
1496
1497
1498
1499
# Skip re-validation: `cast` is a runtime no-op that only narrows the static
# type of an already-parsed message for type checkers.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1500
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one request message into conversation message(s).

    The content (None, a string, or a list of parts) is parsed into text and
    multimodal items. Role-specific fields are then copied over:
    ``tool_calls`` and reasoning for assistants, ``tool_call_id`` for tools,
    ``name`` when it is a string, and ``tools`` for developer messages.
    """
    role = message["role"]
    raw_content = message.get("content")
    reasoning = message.get("reasoning") or message.get("reasoning_content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=content_format == "openai",
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)
            # Treat an explicit "tool_calls": null like an absent key; some
            # clients send it even though the OpenAI spec omits it.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Interleaved thinking: expose reasoning under both field names.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1549

1550
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalize assistant tool calls in place for chat templates.

    Tool-use chat templates (per the Transformers docs) expect
    ``function.arguments`` as a dict rather than the JSON string used by the
    OpenAI API. String arguments are therefore decoded with ``json.loads``,
    and missing or empty arguments become ``{}``. An empty ``tool_calls`` list
    is removed so templates take the plain assistant path.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if not tool_calls:
            message.pop("tool_calls", None)
            continue

        for tool_call in tool_calls:
            function = tool_call["function"]
            arguments = function.get("arguments")
            if not arguments:
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1574
1575


1576
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Convert request chat messages into a conversation plus multimodal data.

    Every entry of ``messages`` is passed to ``_parse_chat_message_content``
    together with one shared ``MultiModalItemTracker``. The returned
    conversation messages are concatenated in order. The combined
    conversation is then normalized in place by ``_postprocess_messages``,
    and finally the tracker is resolved.

    Returns:
        A ``(conversation, mm_data, mm_uuids)`` tuple, where ``mm_data`` and
        ``mm_uuids`` are whatever ``tracker.resolve_items()`` produces.
    """
    tracker = MultiModalItemTracker(model_config)

    # Same value for every message, so compute it once: string interleaving
    # only applies to the "string" content format, and only when the model's
    # multimodal config enables it.
    interleave = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for message in messages:
        conversation.extend(
            _parse_chat_message_content(
                message,
                tracker,
                content_format,
                interleave_strings=interleave,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = tracker.resolve_items()
    return conversation, mm_data, mm_uuids


1609
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Async counterpart of ``parse_chat_messages``.

    Parsing of each message is identical to the sync version, except that an
    ``AsyncMultiModalItemTracker`` is used. Its ``resolve_items()`` result is
    awaited at the end.

    Returns:
        A ``(conversation, mm_data, mm_uuids)`` tuple, where ``mm_data`` and
        ``mm_uuids`` come from awaiting ``tracker.resolve_items()``.
    """
    tracker = AsyncMultiModalItemTracker(model_config)

    # Loop-invariant: string interleaving requires the "string" content
    # format and an enabled interleave_mm_strings multimodal setting.
    interleave = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for message in messages:
        conversation.extend(
            _parse_chat_message_content(
                message,
                tracker,
                content_format,
                interleave_strings=interleave,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await tracker.resolve_items()
    return conversation, mm_data, mm_uuids

1641

1642
1643
1644
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return how many tool calls the assistant turns in ``conversation`` made.

    Only messages whose role is ``"assistant"`` are counted. An assistant
    message without a ``tool_calls`` entry, or with ``tool_calls`` set to
    ``None``, contributes zero.
    """
    total = 0
    for message in conversation:
        if message["role"] != "assistant":
            continue
        calls = message.get("tool_calls")
        if calls is None:
            continue
        # Materialize first so any iterable (not only sized containers) counts.
        total += len(list(calls))
    return total


1651
1652
1653
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a generated tool call.

    When ``id_type`` is ``"kimi_k2"``, the ID is built deterministically as
    ``functions.<func_name>:<idx>``. Any other ``id_type``, including the
    default ``"random"``, yields ``chatcmpl-tool-`` followed by the value of
    ``random_uuid()``. In that case ``func_name`` and ``idx`` are ignored.
    """
    if id_type != "kimi_k2":
        # Default path: a random ID that does not depend on func_name/idx.
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"