chat_utils.py 57.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from dataclasses import dataclass
11
from functools import cached_property, lru_cache, partial
12
from itertools import accumulate
13
from pathlib import Path
14
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
15

16
from openai.types.chat import (
17
18
19
20
21
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
22
    ChatCompletionFunctionToolParam,
23
24
25
26
27
28
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
29
from openai.types.chat import (
30
31
32
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
33
from openai.types.responses import ResponseInputImageParam
34
from openai_harmony import Message as OpenAIHarmonyMessage
35
36
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
37

38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypedDict
40

41
from vllm import envs
42
from vllm.config import ModelConfig
43
from vllm.logger import init_logger
44
from vllm.model_executor.models import SupportsMultiModal
45
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
50
51
52
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
53
)
54
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
55
from vllm.multimodal.processing import BaseMultiModalProcessor
56
from vllm.utils import random_uuid
57
58
59
60
61
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
62
    import transformers
63
else:
64
    transformers = LazyLoader("transformers", globals(), "transformers")
65
    torch = LazyLoader("torch", globals(), "torch")
66
67
68

logger = init_logger(__name__)

69

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def __getattr__(name: str):
    """
    Module-level attribute hook (PEP 562) that keeps a deprecated alias alive.

    Looking up ``resolve_hf_chat_template`` emits a DeprecationWarning and
    hands back ``vllm.renderers.hf.resolve_chat_template``. Any other unknown
    attribute raises AttributeError, as a normal module would.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported lazily so the renderer module is only loaded when the old
    # name is actually requested.
    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template


87
88
89
90
91
92
93
94
class ChatTemplateResolutionError(ValueError):
    """Signals that a chat template could not be resolved.

    It derives from ValueError so that pre-existing handlers which catch
    ValueError keep working unchanged.
    """


95
96
97
98
99
100
# Generic, model-agnostic placeholder string for each raw media modality.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

101

102
103
104
105
106
107
108
109
110
111
112
113
114
115
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an ``AudioURL`` object."""

    # Location of the audio.
    audio_url: Required[AudioURL]

    # Discriminator for this content part; always "audio_url".
    type: Required[Literal["audio_url"]]


116
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
117
    image_embeds: str | dict[str, str] | None
118
119
120
121
122
123
124
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
125
    uuid: str | None
126
127
128
129
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


148
149
150
151
152
153
154
155
156
157
158
159
160
161
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a ``VideoURL`` object."""

    # Location of the video.
    video_url: Required[VideoURL]

    # Discriminator for this content part; always "video_url".
    type: Required[Literal["video_url"]]


162
163
164
165
class PILImage(BaseModel):
    """Pydantic model wrapping a single ``PIL.Image.Image`` object."""

    # PIL images are not a pydantic-native type, so arbitrary types must be
    # allowed for the field below to validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Simplified content part whose only payload is a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    # The image. NOTE(review): annotated as the PILImage wrapper model, while
    # the example above passes a raw PIL image; confirm which form callers send.
    image_pil: PILImage | None
    # Optional caller-supplied UUID for the media. The caller must make sure
    # it is properly generated and unique across different media.
    uuid: str | None
186
187


188
189
190
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
191

192
193
194
195
196
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
197

198
199
    image_url: str | None
    uuid: str | None
200
201
202
203
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
204
205
206
207


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
208

209
210
211
212
213
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
214

215
    audio_url: str | None
216
217


218
219
220
221
222
223
224
225
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
226

227
228
    video_url: str | None
    uuid: str | None
229
230
231
232
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
233
234


Julien Denize's avatar
Julien Denize committed
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


256
257
258
259
260
261
262
263
264
# Union of every content-part shape accepted inside a chat message:
# - the OpenAI SDK's own content-part types;
# - vLLM's audio/video/embedding extensions;
# - the simplified "custom" shapes;
# - a bare string;
# - "thinking" parts.
# NOTE(review): member order is kept as-is on purpose; it may influence how
# pydantic resolves the union, so reorder only with care.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
271
272
273
274


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """
    Chat message shape that permits arbitrary (custom) roles, plus vLLM
    extras such as interleaved reasoning text and developer-role tools.
    """

    # Role of the message's author; any string is accepted.
    role: Required[str]

    # Message body: plain text or a list of content parts.
    content: str | list[ChatCompletionContentPartParam]

    # Optional participant name. It lets the model tell apart participants
    # that share the same role.
    name: str

    # ID of the tool call this message is responding to.
    tool_call_id: str | None

    # Tool calls generated by the model, such as function calls.
    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None

    # Reasoning content for interleaved thinking.
    reasoning: str | None

    # Tools made available to the developer role.
    tools: list[ChatCompletionFunctionToolParam] | None

301

302
303
304
305
306
# Any message accepted in a chat conversation:
# - a standard OpenAI SDK message;
# - a CustomChatCompletionMessageParam (arbitrary roles and vLLM extras);
# - an openai_harmony Message.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
307
308


309
# TODO: Make fields ReadOnly once mypy supports it
310
311
312
313
class ConversationMessage(TypedDict, total=False):
    """A single message of the conversation, as tracked by vLLM."""

    # Role of the message's author.
    role: Required[str]

    # Message body: text, nothing, or a list of string-valued dicts.
    content: str | None | list[dict[str, str]]

    # ID of the tool call this message is responding to.
    tool_call_id: str | None

    # Name of the function to call.
    name: str | None

    # Tool calls generated by the model, such as function calls.
    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None

    # Reasoning content for interleaved thinking.
    reasoning: str | None

    # Deprecated spelling of ``reasoning``.
    reasoning_content: str | None

    # Tools made available to the developer role.
    tools: list[ChatCompletionFunctionToolParam] | None

335

336
337
338
# Content format as requested by the caller; "auto" is still unresolved here.
ChatTemplateContentFormatOption = Literal[
    "auto",
    "string",
    "openai",
]

# Content format once "auto" has been resolved to a concrete choice.
ChatTemplateContentFormat = Literal[
    "string",
    "openai",
]
341

342

Roger Wang's avatar
Roger Wang committed
343
344
345
# Modality keys understood by the multi-modal item trackers.
ModalityStr = Literal[
    "image",
    "audio",
    "video",
    "image_embeds",
    "audio_embeds",
    "vision_chunk",
]

# Type of the items a tracker stores (concrete pairs or awaitables).
_T = TypeVar("_T")


349
350
351
# Backward compatibility for single item input
class _BatchedSingleItemField(MultiModalSharedField):
    """
    Marker subclass of MultiModalSharedField used for the deprecated form in
    which one item's embeddings arrive with a leading batch dimension of 1.

    It adds no behavior; its distinct type lets callers recognise the case.
    """
352
353


354
355
356
357
358
359
def _detect_field(
    tensors: list["torch.Tensor"],
    mm_processor: BaseMultiModalProcessor,
):
    """
    Guess how a list of per-item embedding tensors should be batched.

    Args:
        tensors: One tensor per multi-modal item. Must be non-empty.
        mm_processor: Processor whose model config supplies the inputs-embeds
            size, used to recognise the legacy batched single-item form.

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` when there is exactly one
          3-D tensor whose first dim is 1 and last dim is the embeds size
          (deprecated batched input; a warning is logged);
        - ``MultiModalBatchedField()`` when every tensor has the same shape;
        - otherwise a ``MultiModalFlatField`` with one slice per tensor over
          the accumulated leading-dimension lengths.
    """
    # The annotation above is quoted so defining this function does not touch
    # the lazily loaded ``torch`` module (consistent with ``_merge_embeds``).
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Items differ in shape: describe each one as a slice of the tensors
    # laid end to end along dim 0.
    size_per_item = [len(tensor) for tensor in tensors]
    slice_idxs = [0, *accumulate(size_per_item)]
    slices = [
        (slice(start, end),) for start, end in zip(slice_idxs, slice_idxs[1:])
    ]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """
    Merge per-item embedding dicts into a single dict of batched data.

    Each key's layout is first auto-detected with ``_detect_field``. The
    processor's ``_get_mm_fields_config`` is then consulted; for every key
    where it disagrees (except the legacy ``_BatchedSingleItemField`` case),
    the processor's field is used instead. If consulting the processor fails
    for any reason, the auto-detected layout is kept for *every* key.

    Args:
        data_items: One dict per multi-modal item; all dicts must share keys.
        mm_processor: Processor used for field detection and configuration.

    Returns:
        A dict mapping each key to its reduced (batched) data, or ``{}`` when
        ``data_items`` is empty.

    Raises:
        ValueError: If the dicts do not all have the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    # Collect each key's per-item values once; they are reused below.
    values_by_key = {key: [item[key] for item in data_items] for key in first_keys}

    fields = {
        key: _detect_field(values, mm_processor)
        for key, values in values_by_key.items()
    }
    data_merged = {
        key: field._reduce_data(values_by_key[key], pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        # Compute every correction before applying any, so a failure part-way
        # through cannot leave a mix of parsed and auto-detected layouts.
        updates = {
            key: parsed_fields[key]._reduce_data(
                values_by_key[key], pin_memory=False
            )
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        }
    except Exception:
        # Deliberately best-effort: keep the auto-detected layout.
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )
    else:
        data_merged.update(updates)

    return data_merged
434

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """
    Batch the embedding inputs collected for one modality.

    Args:
        modality: Base modality name, e.g. ``"image"`` or ``"audio"``.
        data_items: Per-item embedding data, as added by the content parser.
        mm_processor: Processor forwarded to ``_merge_embeds``.

    Returns:
        - ``data_items`` unchanged when it is empty or every entry is None;
        - for a list of tensors (as judged by ``is_list_of``), the merged
          value stored under the key ``f"{modality}_embeds"``;
        - for a list of dicts, the whole merged dict.

    Raises:
        NotImplementedError: For any other combination of item types.
    """
    if not data_items or all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    # Report the element types; type(data_items) would always just be `list`.
    item_types = sorted({type(item).__name__ for item in data_items})
    raise NotImplementedError(
        f"Unsupported {modality} embedding inputs: expected all items to be "
        f"tensors or all to be dicts of tensors, got item types {item_types}"
    )
456
457


458
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    ):
        super().__init__()

        self._model_config = model_config
        self._media_io_kwargs = media_io_kwargs

        # Items added so far, keyed by the modality they are stored under
        # (e.g. "image", "image_embeds", or "vision_chunk").
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # Track original modality for each vision_chunk item (image or video)
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Check if model uses unified vision_chunk modality for images/videos."""
        return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
        """Explicit kwargs if given, else the model's multimodal config kwargs."""
        return self._media_io_kwargs or (
            self._model_config.multimodal_config.media_io_kwargs
            if self._model_config.multimodal_config
            else None
        )

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Add a multi-modal item to the current prompt and return the
        placeholder string to use, if any.

        Args:
            modality: Modality of the item as reported by the content parser.
                When the model uses the unified vision_chunk modality, "image"
                and "video" items are stored (and counted) under
                "vision_chunk", and their original modality is recorded.
            item: The value to store. Any media UUID travels inside ``item``
                (e.g. as a ``(data, uuid)`` pair), not as a separate argument.

        Returns:
            The model's placeholder for this item, from
            ``model_cls.get_placeholder_str``.
        """
        use_vision_chunk = (
            self.use_unified_vision_chunk_modality and modality in ["video", "image"]
        )

        if use_vision_chunk:
            # Models with use_unified_vision_chunk_modality=True only accept
            # the vision_chunk modality, so validate and store under it.
            input_modality = "vision_chunk"
            storage_key = "vision_chunk"
        else:
            input_modality = modality.replace("_embeds", "")
            storage_key = modality

        # Use .get() so counting does not insert an empty list into the
        # defaultdict; otherwise a rejected item would leave a phantom entry.
        num_items = len(self._items_by_modality.get(storage_key, ())) + 1

        mm_config = self.model_config.multimodal_config
        if (
            mm_config is not None
            and mm_config.enable_mm_embeds
            and mm_config.get_limit_per_prompt(input_modality) == 0
            and modality.endswith("_embeds")
        ):
            # Skip validation: embeddings bypass limit when enable_mm_embeds=True
            pass
        else:
            self.mm_processor.info.validate_num_items(input_modality, num_items)

        self._items_by_modality[storage_key].append(item)
        if use_vision_chunk:
            # Remember whether this vision_chunk item was an image or a video.
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        raise NotImplementedError


573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    """
    Convert resolved ``vision_chunk`` items into VisionChunk entries.

    Args:
        vision_chunk_items: ``(data, uuid)`` pairs, in the order they were
            added.
        mm_processor: Processor used to split videos when it provides a
            ``split_video_chunks`` method.
        vision_chunks_modality_order: Original modality ("image" or "video")
            of each entry in ``vision_chunk_items``; must be the same length.

    Returns:
        ``(processed_chunks, uuids)``. ``uuids`` holds one entry per input
        item. Images with a ``.media`` attribute become ``VisionChunkImage``;
        videos that split successfully become one ``VisionChunkVideo`` per
        chunk; anything else is passed through unchanged.
    """
    # NOTE(review): a split video contributes several entries to
    # processed_chunks but only one to the returned uuids, so the two lists
    # can differ in length; confirm downstream consumers expect this.
    vision_chunks_uuids = [uuid for _, uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    processed_chunks: list[VisionChunk] = []
    video_idx = 0
    for inner_modality, (data, uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if inner_modality == "image":
            # Use .media (PIL.Image) directly to avoid redundant
            # bytes→PIL conversion in media_processor
            if hasattr(data, "media"):
                image_data = data.media  # type: ignore[union-attr]
                processed_chunks.append(
                    VisionChunkImage(type="image", image=image_data, uuid=uuid)
                )
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
        elif inner_modality == "video":
            if hasattr(mm_processor, "split_video_chunks") and data is not None:
                try:
                    video_uuid = uuid or random_uuid()
                    # video await result is (video_data, video_meta) tuple
                    if isinstance(data, tuple) and len(data) >= 1:
                        video_data = data[0]
                    else:
                        video_data = data
                    video_chunks = mm_processor.split_video_chunks(video_data)
                    # Build into a local list first: if anything fails
                    # part-way, no partial chunks must end up next to the
                    # raw fallback item appended below.
                    split_chunks = [
                        VisionChunkVideo(
                            type="video_chunk",
                            video_chunk=vc["video_chunk"],
                            uuid=f"{video_uuid}-{i}",
                            video_idx=video_idx,
                            prompt=vc["prompt"],
                        )
                        for i, vc in enumerate(video_chunks)
                    ]
                except Exception as e:
                    logger.warning("Failed to split video chunks: %s", e)
                    processed_chunks.append(data)  # type: ignore[arg-type]
                else:
                    processed_chunks.extend(split_chunks)
                    video_idx += 1
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
    return processed_chunks, vision_chunks_uuids


635
636
637
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
    modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    """
    Split resolved ``(data, uuid)`` pairs into per-modality data and UUIDs.

    Embedding inputs ("image_embeds" / "audio_embeds") are batched with
    ``_get_embeds_data`` and stored under their base modality. Raw image,
    audio and video data are stored as plain lists. "vision_chunk" items go
    through ``_resolve_vision_chunk_items``.

    Raises:
        ValueError: If raw and embedding inputs are mixed for image or audio.
    """
    for base in ("image", "audio"):
        if base in items_by_modality and f"{base}_embeds" in items_by_modality:
            raise ValueError(f"Mixing raw {base} and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}

    for base in ("image", "audio"):
        embeds_key = f"{base}_embeds"
        if embeds_key in items_by_modality:
            embeds_pairs = items_by_modality[embeds_key]
            mm_data[base] = _get_embeds_data(
                base,
                [payload for payload, _ in embeds_pairs],
                mm_processor,
            )
            mm_uuids[base] = [item_uuid for _, item_uuid in embeds_pairs]
        if base in items_by_modality:
            raw_pairs = items_by_modality[base]
            mm_data[base] = [payload for payload, _ in raw_pairs]
            mm_uuids[base] = [item_uuid for _, item_uuid in raw_pairs]

    if "video" in items_by_modality:
        video_pairs = items_by_modality["video"]
        mm_data["video"] = [payload for payload, _ in video_pairs]
        mm_uuids["video"] = [item_uuid for _, item_uuid in video_pairs]

    if "vision_chunk" in items_by_modality:
        chunks, chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
        )
        mm_data["vision_chunk"] = chunks
        mm_uuids["vision_chunk"] = chunk_uuids

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Synchronous tracker whose items are already-fetched ``(data, uuid)`` pairs."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return ``(mm_data, mm_uuids)``, or ``(None, None)`` if nothing was added."""
        if self._items_by_modality:
            snapshot = dict(self._items_by_modality)
            return _resolve_items(snapshot, self.mm_processor, self._modality_order)
        return None, None

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        parser = MultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs)
        return parser
699
700


701
702
703
704
705
706
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)`` pairs."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """
        Await every tracked item and return ``(mm_data, mm_uuids)``.

        Returns ``(None, None)`` if no item was added.
        """
        if not self._items_by_modality:
            return None, None

        # Snapshot the per-modality counts before awaiting, then await every
        # item in a single gather. This way fetches for different modalities
        # run concurrently rather than one modality after another.
        modalities = list(self._items_by_modality)
        counts = [len(self._items_by_modality[m]) for m in modalities]
        pending = [
            awaitable
            for modality in modalities
            for awaitable in self._items_by_modality[modality]
        ]
        results = await asyncio.gather(*pending)

        # gather preserves input order, so slice the results back per modality.
        resolved_items_by_modality: dict[str, list[tuple[object, str | None]]] = {}
        start = 0
        for modality, count in zip(modalities, counts):
            resolved_items_by_modality[modality] = list(results[start : start + count])
            start += count

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(
            self, mm_processor_kwargs=mm_processor_kwargs
        )
725
726
727
728
729
730


class BaseMultiModalContentParser(ABC):
    """
    Interface for parsing the multi-modal parts of a chat message.

    Subclasses register each item with a tracker. The model-specific
    placeholder strings returned by the tracker are collected here, grouped
    under the generic placeholder of their modality.
    """

    def __init__(self) -> None:
        super().__init__()

        # Generic placeholder -> model-specific placeholders, in order, e.g.
        # {"<##IMAGE##>": ["<image>", "<image>"], "<##AUDIO##>": ["<audio>"]}
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        # Look the generic key up first, so an unknown modality raises
        # KeyError even when there is no placeholder to record.
        generic_key = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_key].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return the collected placeholders as a plain (shallow-copied) dict."""
        return {key: values for key, values in self._placeholder_storage.items()}

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register an image given by URL (or only by UUID when URL is None)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed image embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register an in-memory PIL image."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register audio given by URL (or only by UUID when URL is None)."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Register audio given as an OpenAI ``input_audio`` payload."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed audio embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a video given by URL (or only by UUID when URL is None)."""
        raise NotImplementedError

787
788

class MultiModalContentParser(BaseMultiModalContentParser):
    """
    Synchronous content parser. Media is fetched immediately through the
    media connector and recorded on a MultiModalItemTracker as
    ``(data, uuid)`` pairs.
    """

    def __init__(
        self,
        tracker: MultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()

        self._tracker = tracker

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

        self._mm_processor_kwargs = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """
        Register image embeddings given as a base64 string, a dict of base64
        strings, or None.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set, or if
                ``image_embeds`` has an unsupported type.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        elif image_embeds is None:
            placeholder = self._tracker.add("image_embeds", (None, uuid))
        else:
            # Previously fell through to an UnboundLocalError on `placeholder`.
            raise ValueError(
                "`image_embeds` must be a base64 string, a dict of base64 "
                f"strings, or None; got {type(image_embeds).__name__}"
            )

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """
        Register audio embeddings given as a base64 string, a dict of base64
        strings, or None.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set, or if
                ``audio_embeds`` has an unsupported type.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        elif audio_embeds is None:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))
        else:
            # Previously any other type was silently treated as None.
            raise ValueError(
                "`audio_embeds` must be a base64 string, a dict of base64 "
                f"strings, or None; got {type(audio_embeds).__name__}"
            )

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an ``input_audio`` payload into a base64 data URL and parse it."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty; no URL then.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            # video_url is known to be non-empty inside this branch.
            audio = self._connector.fetch_audio(video_url)
            audio_placeholder = self._tracker.add("audio", (audio, uuid))
            self._add_placeholder("audio", audio_placeholder)

913
914

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Multimodal content parser that registers awaitables with the tracker.

    Each ``parse_*`` method hands the tracker an awaitable rather than a
    loaded item. URL-based media use a coroutine that calls the connector's
    ``*_async`` fetchers. Data that is available immediately (embeddings, PIL
    images) uses an already-completed :class:`asyncio.Future`. Every
    awaitable yields a ``(media, uuid)`` pair.
    """

    def __init__(
        self,
        tracker: AsyncMultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()

        self._tracker = tracker
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )
        self._mm_processor_kwargs: dict[str, Any] | None = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    @staticmethod
    def _completed_future(
        item: Any, uuid: str | None
    ) -> asyncio.Future[tuple[Any, str | None]]:
        """Return a future that is already resolved to ``(item, uuid)``."""
        future = asyncio.Future[tuple[Any, str | None]]()
        future.set_result((item, uuid))
        return future

    def _embeds_future(
        self,
        embeds: str | dict[str, str] | None,
        uuid: str | None,
        fetch: Callable[[str], Any],
        field_name: str,
    ) -> asyncio.Future[tuple[Any, str | None]]:
        """Load embeddings eagerly with ``fetch`` and wrap them in a future.

        A dict is fetched value-by-value (keys preserved), a str is fetched
        directly, and ``None`` is passed through (e.g. when only a UUID is
        given).

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set, or if ``embeds``
                has any other type. Without the type check the future would
                never be resolved, and awaiting it would not finish.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                f"You must set `--enable-mm-embeds` to input `{field_name}`"
            )

        if embeds is None:
            result = None
        elif isinstance(embeds, dict):
            result = {key: fetch(value) for key, value in embeds.items()}
        elif isinstance(embeds, str):
            result = fetch(embeds)
        else:
            raise ValueError(
                f"`{field_name}` must be a string or a dict of strings, "
                f"got {type(embeds).__name__}"
            )
        return self._completed_future(result, uuid)

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register image embeddings (a str, a dict of str, or ``None``)."""
        future = self._embeds_future(
            image_embeds, uuid, self._connector.fetch_image_embedding, "image_embeds"
        )
        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register audio embeddings (a str, a dict of str, or ``None``)."""
        future = self._embeds_future(
            audio_embeds, uuid, self._connector.fetch_audio_embedding, "audio_embeds"
        )
        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an already-decoded PIL image (falsy input becomes ``None``)."""
        future = self._completed_future(image_pil if image_pil else None, uuid)
        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Build a base64 data URL from an ``input_audio`` payload and parse it."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            audio_coro = self._audio_with_uuid_async(video_url, uuid)
            audio_placeholder = self._tracker.add("audio", audio_coro)
            self._add_placeholder("audio", audio_placeholder)


@dataclass
class ChatTemplateConfig:
    """Chat-template settings grouped together.

    Attributes:
        chat_template: The chat template, or ``None`` if none was supplied.
        chat_template_content_format: Content-format option for messages
            passed to the template; defaults to ``"auto"``.
        trust_request_chat_template: Presumably controls whether a chat
            template supplied with a request may be used (TODO confirm);
            defaults to ``False``.
    """

    chat_template: str | None = None
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    trust_request_chat_template: bool = False


def validate_chat_template(chat_template: Path | str | None):
1084
1085
1086
1087
1088
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1089
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1090
1091
1092

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1093
1094
1095
1096
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1097
1098
1099
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1100
            )
1101

1102
1103
1104
1105
1106
1107
1108
1109
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1110
    else:
1111
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1112
1113


1114
def _load_chat_template(
1115
    chat_template: Path | str | None,
1116
1117
    *,
    is_literal: bool = False,
1118
) -> str | None:
1119
1120
    if chat_template is None:
        return None
1121
1122
1123

    if is_literal:
        if isinstance(chat_template, Path):
1124
1125
1126
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1127

1128
        return chat_template
1129

1130
    try:
1131
        with open(chat_template) as f:
1132
            return f.read()
1133
    except OSError as e:
1134
1135
1136
        if isinstance(chat_template, Path):
            raise

1137
1138
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1139
1140
1141
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1142
            )
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1156

1157
1158
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1159
1160
1161
1162
1163
1164
1165
        return _load_chat_template(chat_template, is_literal=True)


# Memoized loader (128 most recent argument combinations). Each distinct
# (chat_template, is_literal) pair is loaded only once, so a template file is
# read on first use and later edits to it are not picked up.
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Load a chat template through the memoized ``_load_chat_template``.

    See ``_load_chat_template`` for how ``chat_template`` and ``is_literal``
    are interpreted.
    """
    cached_loader = _cached_load_chat_template
    return cached_loader(chat_template, is_literal=is_literal)


def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
    """Replace interleave markers in ``texts`` and join the result with newlines.

    Each entry of ``texts`` that is a key of ``placeholder_storage`` is
    replaced by the value popped from the front of that key's list, so
    repeated markers consume the values in order. Both ``texts`` and the lists
    in ``placeholder_storage`` are modified in place.
    """
    for position, fragment in enumerate(texts):
        if fragment not in placeholder_storage:
            continue
        texts[position] = placeholder_storage[fragment].pop(0)

    joined = "\n".join(texts)
    return joined


# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
    """Combine multimodal prompts for a multimodal language model.

    Builds the text prompt from ``texts`` (interleaving placeholders when
    ``interleave_strings`` is set). Any model placeholder that appears in the
    prompt fewer times than there are items for it is prepended, one per line,
    once for each missing occurrence.

    Raises:
        ValueError: If the prompt contains more occurrences of a placeholder
            than there are multimodal items for it.
    """
    # Count the expected occurrences of each placeholder before interleaving,
    # which pops the storage lists. Example: {"<|image|>": 2, "<|audio|>": 1}.
    expected = Counter(
        token for tokens in placeholder_storage.values() for token in tokens
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # The interleaved text is still scanned below, in case the user inserted
    # image placeholders themselves but forgot to disable 'interleave_strings'.
    missing_placeholders: list[str] = []
    for placeholder, count in expected.items():
        # Placeholders already present in the prompt are left where they are.
        remaining = count - text_prompt.count(placeholder)
        if remaining < 0:
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
                "when manually placing image placeholders.",
                interleave_strings,
            )
            logger.debug("Input prompt: %s", text_prompt)
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items."
            )
        missing_placeholders.extend([placeholder] * remaining)

    # NOTE: Default behaviour: missing placeholders always go at the front of
    # the prompt.
    if not text_prompt:
        return "\n".join(missing_placeholders)
    return "\n".join([*missing_placeholders, text_prompt])


# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Define a mapping from part types to their corresponding parsing functions.
# Each parser extracts the part's payload with `.get(...)` and therefore
# yields None when the expected field is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart | None],
] = {
    "text": lambda part: _TextParser(part).get("text", None),
    "thinking": lambda part: _ThinkParser(part).get("thinking", None),
    "input_text": lambda part: _TextParser(part).get("text", None),
    "output_text": lambda part: _TextParser(part).get("text", None),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
    "refusal": lambda part: _RefusalParser(part).get("refusal", None),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}


def _parse_chat_message_content_mm_part(
1277
1278
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1279
    """
1280
    Parses a given multi-modal content part based on its type.
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1294
1295
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1296
    part_type = part.get("type", None)
1297
    uuid = part.get("uuid", None)
1298

1299
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
1300
1301
1302
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1303
1304
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1305
            logger.warning(
1306
                "'image_url.detail' is currently not supported and will be ignored."
1307
            )
1308
1309
1310
1311

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1312
    # 'type' is required field by pydantic
1313
1314
    if part_type is None or uuid is not None:
        if "image_url" in part:
1315
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
1316
1317
1318
1319
1320
1321
1322
1323
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
1324
            image_params = cast(  # type: ignore
1325
1326
1327
1328
1329
1330
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
1331
            image_params = cast(  # type: ignore
1332
1333
1334
1335
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
1336
1337
1338
1339
1340
1341
1342
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
1343
        if "audio_url" in part:
1344
1345
1346
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
1347
1348
1349
1350
1351
1352
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1353
        if part.get("input_audio") is not None:
1354
            input_audio_params = cast(dict[str, str], part)
1355
            return "input_audio", input_audio_params
1356
        if "video_url" in part:
1357
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
1358
1359
1360
1361
1362
1363
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1364
1365
1366
1367
1368
1369
1370
1371
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1372
# Text-like part types whose parsed content may be None (the expected field is
# missing). Such parts are logged and skipped rather than emitted, for example
# as {"type": "text", "text": None}.
PART_TYPES_TO_SKIP_NONE_CONTENT = (
    "text",
    "input_text",
    "output_text",
    "refusal",
    "thinking",
)


def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]:
    """Parse the content parts of one message into a one-element message list.

    Media parts are registered through a parser obtained from ``mm_tracker``.
    Falsy parse results are dropped. With ``wrap_dicts`` the remaining parts
    are kept as a list. Otherwise they are joined with newlines, and
    multimodal placeholders are merged in when the parser recorded any.
    """
    mm_parser = mm_tracker.create_parser(mm_processor_kwargs=mm_processor_kwargs)

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(storage, texts, interleave_strings)
        if storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1417
1418
1419
1420
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1421
    interleave_strings: bool,
1422
) -> _ContentPart | None:
1423
1424
1425
1426
1427
1428
1429
1430
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1431
        return part
1432
1433
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1434
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1435
    # content is None, log a warning and skip
1436
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1437
        logger.warning(
1438
            "Skipping multimodal part '%s' (type: '%s') "
1439
1440
1441
1442
            "with empty / unparsable content.",
            part,
            part_type,
        )
1443
1444
        return None

1445
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1446
1447
        str_content = cast(str, content)
        if wrap_dicts:
1448
            return {"type": "text", "text": str_content}
1449
1450
        else:
            return str_content
1451

1452
1453
1454
1455
1456
1457
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1458
    modality = None
1459
    if part_type == "image_pil":
1460
        image_content = cast(Image.Image, content) if content is not None else None
1461
        mm_parser.parse_image_pil(image_content, uuid)
1462
        modality = "image"
1463
    elif part_type in ("image_url", "input_image"):
1464
        str_content = cast(str, content)
1465
        mm_parser.parse_image(str_content, uuid)
1466
1467
        modality = "image"
    elif part_type == "image_embeds":
1468
        content = cast(str | dict[str, str], content) if content is not None else None
1469
        mm_parser.parse_image_embeds(content, uuid)
1470
        modality = "image"
1471
1472
1473
1474
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1475
    elif part_type == "audio_url":
1476
        str_content = cast(str, content)
1477
        mm_parser.parse_audio(str_content, uuid)
1478
1479
        modality = "audio"
    elif part_type == "input_audio":
1480
        dict_content = cast(InputAudio, content)
1481
        mm_parser.parse_input_audio(dict_content, uuid)
1482
1483
        modality = "audio"
    elif part_type == "video_url":
1484
        str_content = cast(str, content)
1485
        mm_parser.parse_video(str_content, uuid)
1486
1487
1488
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1489

1490
1491
1492
    return (
        {"type": modality}
        if wrap_dicts
1493
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1494
    )
1495
1496


1497
1498
1499
1500
1501
# Messages are not re-validated with Pydantic here; these helpers only narrow
# the static type via `typing.cast`.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]:
    """Convert one incoming chat message into conversation message(s).

    ``None`` content is treated as no parts, and string content as a single
    text part. Parts are wrapped as dicts when ``content_format`` is
    ``"openai"``. Role-specific fields are then copied onto each result:
    ``tool_calls`` (when not None), ``reasoning`` and ``reasoning_content``
    for assistants; ``tool_call_id`` for tools; ``tools`` for developers; and
    ``name`` for any role when it is a string.
    """
    role = message["role"]
    content = message.get("content")
    reasoning = message.get("reasoning")

    if content is None:
        parts = []
    elif isinstance(content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=content)]
    else:
        parts = content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    name = message.get("name")
    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)
            tool_calls = assistant_msg.get("tool_calls")
            # The None check keeps compatibility with callers that do not
            # strictly follow the OpenAI spec.
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                # keep compatibility
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result


def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1555
1556
1557
1558
1559
1560
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
        if message["role"] == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if not isinstance(tool_calls, list):
                continue

            if len(tool_calls) == 0:
                # Drop empty tool_calls to keep templates on the normal assistant path.
                message.pop("tool_calls", None)
                continue

            for item in tool_calls:
1572
1573
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
1574
1575
                    if not isinstance(content, (dict, list)):
                        item["function"]["arguments"] = json.loads(content)
1576
1577
                else:
                    item["function"]["arguments"] = {}
1578
1579


1580
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages and resolve their multimodal items synchronously.

    Returns:
        The post-processed conversation, followed by the multimodal data and
        UUIDs returned by ``MultiModalItemTracker.resolve_items``.
    """
    mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)

    # Placeholder interleaving depends only on the configuration, so the flag
    # is computed once for all messages.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
                mm_processor_kwargs=mm_processor_kwargs,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids


async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages and resolve their multimodal items asynchronously.

    Returns:
        The post-processed conversation, followed by the multimodal data and
        UUIDs produced by awaiting
        ``AsyncMultiModalItemTracker.resolve_items``.
    """
    mm_tracker = AsyncMultiModalItemTracker(
        model_config, media_io_kwargs=media_io_kwargs
    )

    # Placeholder interleaving depends only on the configuration, so the flag
    # is computed once for all messages.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
                mm_processor_kwargs=mm_processor_kwargs,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids


def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    idx = 0
    for msg in conversation:
1657
1658
1659
        if msg["role"] == "assistant":
            tool_calls = msg.get("tool_calls")
            idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
1660
1661
1662
    return idx


1663
1664
1665
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a tool call.

    ``id_type == "kimi_k2"`` yields ``functions.<func_name>:<idx>``. Any other
    value yields ``chatcmpl-tool-`` followed by ``random_uuid()``.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"