chat_utils.py 55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from itertools import accumulate
12
from pathlib import Path
13
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
14

15
from openai.types.chat import (
16
17
18
19
20
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
21
    ChatCompletionFunctionToolParam,
22
23
24
25
26
27
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
28
from openai.types.chat import (
29
30
31
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
32
from openai.types.responses import ResponseInputImageParam
33
from openai_harmony import Message as OpenAIHarmonyMessage
34
35
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
36

37
# pydantic needs the TypedDict from typing_extensions
38
from typing_extensions import Required, TypedDict
39

40
from vllm import envs
41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
49
50
51
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
52
)
53
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
54
from vllm.multimodal.processing import BaseMultiModalProcessor
55
from vllm.utils import random_uuid
56
57
58
59
60
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
61
    import transformers
62
else:
63
    transformers = LazyLoader("transformers", globals(), "transformers")
64
    torch = LazyLoader("torch", globals(), "torch")
65
66
67

logger = init_logger(__name__)

68

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def __getattr__(name: str):
    """Resolve deprecated module attributes lazily (PEP 562 hook).

    ``resolve_hf_chat_template`` used to be defined in this module. Accessing
    it now emits a ``DeprecationWarning`` and returns
    ``vllm.renderers.hf.resolve_chat_template``. Every other unknown name
    raises ``AttributeError`` exactly like a normal module lookup.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from vllm.renderers.hf import resolve_chat_template as relocated_impl

    notice = (
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16."
    )
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return relocated_impl


86
87
88
89
90
91
92
93
class ChatTemplateResolutionError(ValueError):
    """Signals that a chat template could not be resolved.

    It derives from ``ValueError`` so that pre-existing ``except ValueError``
    handlers keep catching it.
    """


94
95
96
97
98
99
# Generic, model-agnostic placeholder used for each raw modality; model-specific
# placeholders are grouped under these keys by the content parsers.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an ``audio_url`` object."""

    # Location of the audio.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """Content-part discriminator; always ``"audio_url"``."""


115
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
116
    image_embeds: str | dict[str, str] | None
117
118
119
120
121
122
123
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
124
    uuid: str | None
125
126
127
128
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
129
130


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


147
148
149
150
151
152
153
154
155
156
157
158
159
160
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a ``video_url`` object."""

    # Location of the video.
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """Content-part discriminator; always ``"video_url"``."""


161
162
163
164
class PILImage(BaseModel):
    """
    Pydantic model wrapping an in-memory ``PIL.Image.Image``.
    """

    # PIL images are not pydantic-native, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Simplified content part that carries only a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    Optional caller-supplied identifier for the media. The caller is
    responsible for generating it properly and keeping it unique per media.
    """
185
186


187
188
189
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
190

191
192
193
194
195
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
196

197
198
    image_url: str | None
    uuid: str | None
199
200
201
202
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
203
204
205
206


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
207

208
209
210
211
212
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
213

214
    audio_url: str | None
215
216


217
218
219
220
221
222
223
224
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
225

226
227
    video_url: str | None
    uuid: str | None
228
229
230
231
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
232
233


Julien Denize's avatar
Julien Denize committed
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


255
256
257
258
259
260
261
262
263
# Every content-part shape accepted inside a chat message: the OpenAI-defined
# parts, vLLM's media/embedding/PIL extensions, bare strings, and thinking
# segments.
ChatCompletionContentPartParam: TypeAlias = (
    # OpenAI-defined parts
    OpenAIChatCompletionContentPartParam
    # audio / video by URL
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    # simplified image forms
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    # precomputed embeddings
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    # simplified audio / video forms
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    # plain text and thinking segments
    | str
    | CustomThinkCompletionContentParam
)
270
271
272
273


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat message whose ``role`` may be any string (custom roles allowed)."""

    role: Required[str]
    """Author role of the message."""

    content: str | list[ChatCompletionContentPartParam]
    """Message body: plain text or a list of content parts."""

    name: str
    """Optional participant name.

    Lets the model tell apart several participants that share the same role.
    """

    tool_call_id: str | None
    """Identifier of the tool call this message answers."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (e.g. function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions attached to a developer-role message."""

300

301
302
303
304
305
# Any message accepted by the chat entrypoints: an OpenAI message, a message
# with a custom role, or an openai_harmony message.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam  # standard OpenAI roles
    | CustomChatCompletionMessageParam  # arbitrary role string
    | OpenAIHarmonyMessage  # harmony-format message
)
306
307


308
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """Normalized chat message; only ``role`` is required."""

    role: Required[str]
    """Author role of the message."""

    content: str | None | list[dict[str, str]]
    """Message body: text, nothing, or a list of string-keyed parts."""

    tool_call_id: str | None
    """Identifier of the tool call this message answers."""

    name: str | None
    """Name of the function being called."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """Tool calls (e.g. function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated alias of ``reasoning``, kept for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions attached to a developer-role message."""

334

335
336
337
# Content format requested by the user; "auto" still has to be resolved.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Content format once "auto" has been resolved to a concrete choice.
ChatTemplateContentFormat = Literal["string", "openai"]

# Modality names accepted by the multi-modal item trackers.
ModalityStr = Literal[
    "image",
    "audio",
    "video",
    "image_embeds",
    "audio_embeds",
    "vision_chunk",
]

_T = TypeVar("_T")


348
349
350
# Kept for backward compatibility with pre-batched, single-item embedding input.
class _BatchedSingleItemField(MultiModalSharedField):
    """Marker subclass of ``MultiModalSharedField`` for a legacy batched item.

    It behaves like its base class. The distinct type lets ``_merge_embeds``
    recognise it with ``isinstance`` and skip overriding it with the
    processor's field config.
    """
351
352


353
354
355
356
357
358
def _detect_field(
    tensors: list["torch.Tensor"],
    mm_processor: BaseMultiModalProcessor,
):
    """Pick the multi-modal field layout for a list of per-item tensors.

    The ``torch.Tensor`` annotation is quoted (as in ``_merge_embeds``) so that
    defining this function does not force the lazily-loaded ``torch`` module
    to be imported.

    Args:
        tensors: One embedding tensor per multi-modal item. Must be non-empty.
        mm_processor: Processor whose model config supplies the expected
            embedding size (``get_inputs_embeds_size()``).

    Returns:
        - ``_BatchedSingleItemField`` for the deprecated layout of a single
          3-D tensor shaped ``(1, *, hidden_size)``;
        - ``MultiModalBatchedField`` when all tensors share one shape;
        - otherwise a ``MultiModalFlatField`` with one slice per item over the
          first dimension, each of length ``len(tensor)``.
    """
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    # Legacy input: a single pre-batched tensor that holds exactly one item.
    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Variable-sized items: consecutive, non-overlapping slices along dim 0.
    size_per_item = [len(tensor) for tensor in tensors]
    slice_idxs = [0, *accumulate(size_per_item)]
    slices = [(slice(start, end),) for start, end in zip(slice_idxs, slice_idxs[1:])]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dicts into a single dict of reduced data.

    Each key's field layout is first auto-detected with ``_detect_field``.
    Then the processor's own field config (``_get_mm_fields_config``) is
    consulted and overrides the auto-detected field wherever they differ. The
    legacy ``_BatchedSingleItemField`` is never overridden. If consulting the
    processor fails for any reason, the error is logged and every key keeps
    its auto-detected layout.

    Args:
        data_items: One dict per multi-modal item; all must share the same keys.
        mm_processor: Processor used for field detection and field config.

    Returns:
        Dict mapping each key, in the first item's key order, to its merged data.
        An empty input yields ``{}``.

    Raises:
        ValueError: If the dicts do not all have the same keys.
    """
    if not data_items:
        return {}

    # Iterate keys in the first item's insertion order. Iterating a set of
    # strings would make the merged dict's order depend on hash randomisation.
    keys = list(data_items[0])
    key_set = set(keys)
    if any(set(item) != key_set for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    fields = {
        key: _detect_field([item[key] for item in data_items], mm_processor)
        for key in keys
    }
    data_merged = {
        key: field._reduce_data([item[key] for item in data_items], pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in keys}
        # Build the overrides separately so that a failure part-way through
        # leaves every key on its auto-detected field, as the log says.
        overrides = {
            key: parsed_fields[key]._reduce_data(
                [item[key] for item in data_items], pin_memory=False
            )
            for key in keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        }
    except Exception:
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )
    else:
        data_merged.update(overrides)

    return data_merged
433

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """Turn fetched embedding inputs of one modality into model input data.

    Args:
        modality: Base modality name (e.g. ``"image"``). Bare tensors are
            merged under the ``f"{modality}_embeds"`` key.
        data_items: Fetched embeddings, one per item: tensors, dicts of
            tensors, or ``None`` (an item given without embedding data, e.g.
            only by uuid).
        mm_processor: Processor used to merge the embeddings.

    Returns:
        ``data_items`` unchanged when it is empty or all ``None``; the merged
        tensor data for a list of tensors; the merged dict for a list of dicts.

    Raises:
        NotImplementedError: For any other combination of item types.
    """
    if not data_items or all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    # Report the element types; the container type is always ``list`` and
    # would not help diagnose the problem.
    item_types = sorted({type(item).__name__ for item in data_items})
    raise NotImplementedError(
        f"Unsupported {modality} embedding inputs: expected a list of tensors "
        f"or a list of dicts, got item types {item_types}"
    )
455
456


457
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Collects the multi-modal items of one request, grouped by modality, and
    validates after every addition that the per-prompt limit is respected.

    When the model's HF config sets ``use_unified_vision_chunk``, images and
    videos share one ``"vision_chunk"`` bucket. The original modality of each
    item is then recorded in ``_modality_order["vision_chunk"]``.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # modality -> tracked items, in insertion order
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # "vision_chunk" -> original modality ("image"/"video") of each item
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Whether images and videos are tracked as one vision_chunk modality."""
        hf_config = self._model_config.hf_config
        return getattr(hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Track *item* under *modality* and return the model's placeholder
        string for it, if the model defines one.

        In the concrete trackers of this module, *item* is a ``(data, uuid)``
        pair (or an awaitable resolving to one). The uuid is an optional
        user-provided identifier of the media.

        ``validate_num_items`` is called with the would-be item count before
        the item is stored (presumably raising when the limit is exceeded).
        """
        is_vision_chunk = (
            self.use_unified_vision_chunk_modality
            and modality in ["video", "image"]
        )

        if is_vision_chunk:
            # Models using the unified modality only accept "vision_chunk",
            # so images and videos are validated and stored under that name.
            bucket = "vision_chunk"
            validated_modality = "vision_chunk"
        else:
            bucket = modality
            validated_modality = modality.replace("_embeds", "")

        num_items = len(self._items_by_modality[bucket]) + 1
        self.mm_processor.info.validate_num_items(validated_modality, num_items)

        self._items_by_modality[bucket].append(item)  # type: ignore
        if is_vision_chunk:
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError


547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    """Convert resolved ``vision_chunk`` items into ``VisionChunk`` entries.

    Args:
        vision_chunk_items: ``(data, uuid)`` pairs in insertion order.
        mm_processor: Processor; if it has ``split_video_chunks``, each video
            is split into ``VisionChunkVideo`` chunks.
        vision_chunks_modality_order: Original modality (``"image"`` or
            ``"video"``) of each item, parallel to ``vision_chunk_items``.

    Returns:
        ``(processed_chunks, uuids)``. ``uuids`` has one entry per input item.
        NOTE(review): a split video contributes several chunks, so the two
        lists can differ in length; confirm consumers expect this.
    """
    vision_chunks_uuids = [uuid for _, uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    processed_chunks: list[VisionChunk] = []
    video_idx = 0
    for inner_modality, (data, uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if inner_modality == "image":
            # Use .media (PIL.Image) directly to avoid a redundant
            # bytes->PIL conversion in the media processor.
            if hasattr(data, "media"):
                image_data = data.media  # type: ignore[union-attr]
                processed_chunks.append(
                    VisionChunkImage(type="image", image=image_data, uuid=uuid)
                )
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
        elif inner_modality == "video":
            if hasattr(mm_processor, "split_video_chunks") and data is not None:
                try:
                    video_uuid = uuid or random_uuid()
                    # A fetched video is a (video_data, video_meta) tuple.
                    if isinstance(data, tuple) and len(data) >= 1:
                        video_data = data[0]
                    else:
                        video_data = data
                    # Build all chunks before publishing any of them, so that a
                    # failure part-way through cannot leave partial chunks
                    # next to the raw-video fallback below.
                    video_chunks = [
                        VisionChunkVideo(
                            type="video_chunk",
                            video_chunk=vc["video_chunk"],
                            uuid=f"{video_uuid}-{i}",
                            video_idx=video_idx,
                            prompt=vc["prompt"],
                        )
                        for i, vc in enumerate(
                            mm_processor.split_video_chunks(video_data)
                        )
                    ]
                except Exception as e:
                    # Best effort: fall back to passing the video through whole.
                    logger.warning("Failed to split video chunks: %s", e)
                    processed_chunks.append(data)  # type: ignore[arg-type]
                else:
                    processed_chunks.extend(video_chunks)
                    video_idx += 1
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
    return processed_chunks, vision_chunks_uuids


609
610
611
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
    modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    """Split resolved ``(data, uuid)`` pairs into multi-modal data and UUIDs.

    ``*_embeds`` buckets are merged with ``_get_embeds_data`` and stored under
    their base modality. ``vision_chunk`` items go through
    ``_resolve_vision_chunk_items``.

    Raises:
        ValueError: If raw and embedding inputs are mixed for image or audio.
    """
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    def _unzip(entries):
        # Split (data, uuid) pairs into a data list and a parallel uuid list.
        return [data for data, _ in entries], [uuid for _, uuid in entries]

    mm_data: dict[str, Any] = {}
    mm_uuids: dict[str, Any] = {}

    for modality, embeds_modality in (
        ("image", "image_embeds"),
        ("audio", "audio_embeds"),
    ):
        if embeds_modality in items_by_modality:
            embeds, embeds_uuids = _unzip(items_by_modality[embeds_modality])
            mm_data[modality] = _get_embeds_data(modality, embeds, mm_processor)
            mm_uuids[modality] = embeds_uuids
        if modality in items_by_modality:
            mm_data[modality], mm_uuids[modality] = _unzip(
                items_by_modality[modality]
            )

    if "video" in items_by_modality:
        mm_data["video"], mm_uuids["video"] = _unzip(items_by_modality["video"])

    if "vision_chunk" in items_by_modality:
        chunks, chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
        )
        mm_data["vision_chunk"] = chunks
        mm_uuids["vision_chunk"] = chunk_uuids

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Synchronous tracker whose items are already-fetched ``(data, uuid)`` pairs."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return the merged multi-modal data and UUIDs.

        Returns ``(None, None)`` when no modality bucket exists.
        """
        if self._items_by_modality:
            return _resolve_items(
                dict(self._items_by_modality), self.mm_processor, self._modality_order
            )
        return None, None

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)


673
674
675
676
677
678
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)`` pairs."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await every tracked item and merge the results.

        Returns ``(None, None)`` when no modality bucket exists.
        """
        if not self._items_by_modality:
            return None, None

        modalities = list(self._items_by_modality)
        # Await all modalities together instead of one after another, so that
        # e.g. image and audio fetches overlap, and no coroutine is left
        # un-awaited when an earlier modality fails. Each inner gather keeps
        # its modality's item order.
        per_modality_results = await asyncio.gather(
            *(asyncio.gather(*self._items_by_modality[m]) for m in modalities)
        )
        resolved_items_by_modality = dict(zip(modalities, per_modality_results))

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface that turns multi-modal content parts into tracked items.

    Apart from the abstract ``parse_*`` hooks, the base class keeps the
    model-specific placeholder of every item under its modality's generic
    placeholder, for example::

        {
            "<##IMAGE##>": ["<image>", "<image>", "<image>"],
            "<##AUDIO##>": ["<audio>", "<audio>"],
        }
    """

    def __init__(self) -> None:
        super().__init__()

        # generic modality placeholder -> model placeholders, in arrival order
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record *placeholder* under the generic placeholder of *modality*.

        The generic placeholder is looked up first, so an unmapped modality
        raises ``KeyError`` even when *placeholder* is empty.
        """
        generic_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the recorded placeholders as a plain dict."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Track an image referenced by URL (or only by uuid)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Track precomputed image embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Track an in-memory PIL image."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Track audio referenced by URL (or only by uuid)."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Track audio given as an OpenAI ``input_audio`` payload."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Track precomputed audio embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Track a video referenced by URL (or only by uuid)."""
        raise NotImplementedError

755
756
757
758
759
760

class MultiModalContentParser(BaseMultiModalContentParser):
    """Parser that fetches media synchronously and records ``(data, uuid)``
    pairs on a ``MultiModalItemTracker``."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # getattr tolerates a multimodal config without media_io_kwargs.
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image at *image_url* (if given) and track it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode and track image embeddings.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # A single if/elif/else chain (mirroring parse_audio_embeds) ensures
        # ``placeholder`` is always bound. Any input that is neither a dict
        # nor a str is tracked as a missing embedding.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("image_embeds", (None, uuid))

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode and track audio embeddings.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Track an in-memory PIL image as-is."""
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio at *audio_url* (if given) and track it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an ``input_audio`` payload to a data URL and track it."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video at *video_url* (if given) and track it."""
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)
866

867
868
869
870
871
872

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Multimodal content parser used by the async chat-parsing path.

    Each ``parse_*`` method registers an awaitable with the
    ``AsyncMultiModalItemTracker`` instead of a loaded item:

    * URL-based media register a coroutine that fetches the data.
    * Data available immediately (embeddings, PIL images) registers an
      already-resolved ``asyncio.Future``.

    Every awaitable yields a ``(data, uuid)`` tuple. The placeholder returned
    by the tracker is recorded for the item's modality.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch the image (if a URL is given) and pair it with ``uuid``."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register image embeddings (a single encoded tensor or a dict of them).

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # Every branch must resolve the future: an unresolved future would
        # never complete when awaited. Anything that is not a dict or str
        # (including None, the UUID-only case) resolves to None, matching the
        # synchronous parser.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result((embedding, uuid))
        else:
            future.set_result((None, uuid))

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register audio embeddings (a single encoded tensor or a dict of them).

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # As in parse_image_embeds: always resolve the future, falling back
        # to None for anything that is not a dict or str.
        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            future.set_result((embedding, uuid))
        else:
            future.set_result((None, uuid))

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an already-decoded PIL image (falsy values become None)."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        future.set_result((image_pil if image_pil else None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch the audio (if a URL is given) and pair it with ``uuid``."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an ``input_audio`` payload into a base64 data URL and parse it.

        When the payload is missing or its ``data`` is empty (permitted when a
        UUID identifies the item), ``parse_audio`` is called with ``None``.
        """
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch the video (if a URL is given) and pair it with ``uuid``."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)
1015

1016

1017
def validate_chat_template(chat_template: Path | str | None):
1018
1019
1020
1021
1022
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1023
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1024
1025
1026

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1027
1028
1029
1030
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1031
1032
1033
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1034
            )
1035

1036
1037
1038
1039
1040
1041
1042
1043
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1044
    else:
1045
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1046
1047


1048
def _load_chat_template(
1049
    chat_template: Path | str | None,
1050
1051
    *,
    is_literal: bool = False,
1052
) -> str | None:
1053
1054
    if chat_template is None:
        return None
1055
1056
1057

    if is_literal:
        if isinstance(chat_template, Path):
1058
1059
1060
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1061

1062
        return chat_template
1063

1064
    try:
1065
        with open(chat_template) as f:
1066
            return f.read()
1067
    except OSError as e:
1068
1069
1070
        if isinstance(chat_template, Path):
            raise

1071
1072
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1073
1074
1075
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1076
            )
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1090

1091
1092
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1093
1094
1095
1096
1097
1098
1099
        return _load_chat_template(chat_template, is_literal=True)


_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Load a chat template, memoizing results across calls.

    Public entry point over the LRU-cached ``_load_chat_template``. It
    accepts a template path, a built-in template name, literal template text,
    or ``None``.
    """
    cached_loader = _cached_load_chat_template
    return cached_loader(chat_template, is_literal=is_literal)
1105
1106


1107
1108
1109
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
    """Swap modality markers in ``texts`` for model placeholders, in order.

    Each element of ``texts`` that is a key of ``placeholder_storage`` is
    replaced by the next (front) entry of that key's list. Both ``texts`` and
    the storage lists are mutated in place. The result is newline-joined.
    """
    for position, piece in enumerate(texts):
        if piece not in placeholder_storage:
            continue
        texts[position] = placeholder_storage[piece].pop(0)

    return "\n".join(texts)


1117
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
    """Combine multimodal prompts for a multimodal language model.

    Placeholders already present in the text (typed by the user, or inserted
    by interleaving) are kept. Any still-missing placeholders are prepended,
    one per line, in first-seen order.

    Raises:
        ValueError: The text contains more copies of a placeholder than there
            are multimodal items for it.
    """
    # Required occurrences of each placeholder, e.g. {"<|image|>": 2, ...}.
    # Counted before interleaving, which consumes the storage lists.
    required = Counter(
        token for tokens in placeholder_storage.values() for token in tokens
    )

    # The interleaved text is still scanned below, in case the user placed
    # placeholders manually while leaving 'interleave_strings' enabled.
    text_prompt = (
        _get_interleaved_text_prompt(placeholder_storage, texts)
        if interleave_strings
        else "\n".join(texts)
    )

    missing_placeholders: list[str] = []
    for placeholder, needed in required.items():
        remaining = needed - text_prompt.count(placeholder)
        if remaining < 0:
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
                "when manually placing image placeholders.",
                interleave_strings,
            )
            logger.debug("Input prompt: %s", text_prompt)
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items."
            )
        missing_placeholders += [placeholder] * remaining

    # Missing placeholders always go in front of the text prompt.
    prompt_lines = (
        missing_placeholders + [text_prompt] if text_prompt else missing_placeholders
    )
    return "\n".join(prompt_lines)
1171
1172


1173
1174
# Plain typing casts: these payloads need no further Pydantic validation.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# URL-bearing parts go through Pydantic so their url objects are validated.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage


def _mm_field_getter(
    parser: Callable[[Any], Any], field: str
) -> Callable[[ChatCompletionContentPartParam], _ContentPart]:
    """Build an extractor returning ``parser(part).get(field)``."""
    return lambda part: parser(part).get(field, None)


def _mm_url_getter(
    parser: Callable[[Any], Any], field: str
) -> Callable[[ChatCompletionContentPartParam], _ContentPart]:
    """Build an extractor returning ``parser(part)[field]["url"]`` (or None)."""
    return lambda part: parser(part).get(field, {}).get("url", None)


# Maps each content-part 'type' to the function that extracts its payload.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": _mm_field_getter(_TextParser, "text"),
    "thinking": _mm_field_getter(_ThinkParser, "thinking"),
    "input_text": _mm_field_getter(_TextParser, "text"),
    "output_text": _mm_field_getter(_TextParser, "text"),
    "input_image": _mm_field_getter(_ResponsesInputImageParser, "image_url"),
    "image_url": _mm_url_getter(_ImageParser, "image_url"),
    "image_embeds": _mm_field_getter(_ImageEmbedsParser, "image_embeds"),
    "audio_embeds": _mm_field_getter(_AudioEmbedsParser, "audio_embeds"),
    "image_pil": _mm_field_getter(_PILImageParser, "image_pil"),
    "audio_url": _mm_url_getter(_AudioParser, "audio_url"),
    "input_audio": _mm_field_getter(_InputAudioParser, "input_audio"),
    "refusal": _mm_field_getter(_RefusalParser, "refusal"),
    "video_url": _mm_url_getter(_VideoParser, "video_url"),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a recognized string 'type' and no 'uuid' are dispatched
    through ``MM_PARSER_MAP``. Parts without a 'type', or with a 'uuid', are
    classified by the media key they contain. Their media value may then be
    ``None``, since the UUID alone can identify the item.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, an image URL, or the nested
          ``{"data": ..., "format": ...}`` dict for 'input_audio').
          An unrecognized string 'type' yields a placeholder content string;
          the caller decides how to handle it.

    Raises:
        ValueError: If no usable 'type' is present (or a 'uuid' is given) and
            no known media key is found, or if 'type' is not a string.
    """
    assert isinstance(part, dict)  # narrows for mypy: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'; only 'auto' (the default) is
        # supported. The OpenAI schema nests 'detail' inside the 'image_url'
        # object; a top-level 'detail' is still honored for compatibility.
        if part_type == "image_url":
            image_url_obj = part.get("image_url")
            nested_detail = (
                image_url_obj.get("detail") if isinstance(image_url_obj, dict) else None
            )
            detail = nested_detail or part.get("detail")
            if detail not in (None, "auto"):
                logger.warning(
                    "'image_url.detail' is currently not supported and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' (or a supplied 'uuid') by inferring the modality
    # from which media key is present. 'type' is required field by pydantic.
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can happen when a uuid is given with url as {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            pil_params = cast(CustomChatCompletionContentPILImageParam, part)
            return "image_pil", pil_params.get("image_pil", None)
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_embeds_params = cast(ChatCompletionContentPartImageEmbedsParam, part)
            return "image_embeds", image_embeds_params.get("image_embeds", None)
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_embeds_params = cast(ChatCompletionContentPartAudioEmbedsParam, part)
            return "audio_embeds", audio_embeds_params.get("audio_embeds", None)
        if "audio_url" in part:
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can happen when a uuid is given with url as {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        input_audio = part.get("input_audio")
        if input_audio is not None:
            # Return the nested {"data", "format"} payload, as the typed path
            # does. Returning the whole part made parse_input_audio find no
            # "data" and silently drop the audio whenever a uuid was given.
            return "input_audio", cast(dict[str, str], input_audio)
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can happen when a uuid is given with url as {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1306
# Part types that are dropped (with a warning) when their content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1310

1311

1312
1313
1314
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse every content part of one message into a conversation message.

    Parts whose parsed value is falsy (``None`` or an empty string) are
    dropped.

    * With ``wrap_dicts`` the parsed list is used directly as ``content``.
    * Otherwise the pieces are strings. They are newline-joined, and model
      placeholders are merged in when the parser recorded any.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Interleaved {"type": ...} dicts for text and media.
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            mm_placeholder_storage, texts, interleave_strings
        )
        if mm_placeholder_storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1350
1351
1352
1353
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1354
    interleave_strings: bool,
1355
) -> _ContentPart | None:
1356
1357
1358
1359
1360
1361
1362
1363
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1364
        return part
1365
1366
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1367
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1368
    # content is None, log a warning and skip
1369
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1370
        logger.warning(
1371
            "Skipping multimodal part '%s' (type: '%s') "
1372
1373
1374
1375
            "with empty / unparsable content.",
            part,
            part_type,
        )
1376
1377
        return None

1378
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1379
1380
        str_content = cast(str, content)
        if wrap_dicts:
1381
            return {"type": "text", "text": str_content}
1382
1383
        else:
            return str_content
1384

1385
1386
1387
1388
1389
1390
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1391
    modality = None
1392
    if part_type == "image_pil":
1393
        image_content = cast(Image.Image, content) if content is not None else None
1394
        mm_parser.parse_image_pil(image_content, uuid)
1395
        modality = "image"
1396
    elif part_type in ("image_url", "input_image"):
1397
        str_content = cast(str, content)
1398
        mm_parser.parse_image(str_content, uuid)
1399
1400
        modality = "image"
    elif part_type == "image_embeds":
1401
        content = cast(str | dict[str, str], content) if content is not None else None
1402
        mm_parser.parse_image_embeds(content, uuid)
1403
        modality = "image"
1404
1405
1406
1407
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1408
    elif part_type == "audio_url":
1409
        str_content = cast(str, content)
1410
        mm_parser.parse_audio(str_content, uuid)
1411
1412
        modality = "audio"
    elif part_type == "input_audio":
1413
        dict_content = cast(InputAudio, content)
1414
        mm_parser.parse_input_audio(dict_content, uuid)
1415
1416
        modality = "audio"
    elif part_type == "video_url":
1417
        str_content = cast(str, content)
1418
        mm_parser.parse_video(str_content, uuid)
1419
1420
1421
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1422

1423
1424
1425
    return (
        {"type": modality}
        if wrap_dicts
1426
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1427
    )
1428
1429


1430
1431
1432
1433
1434
# Typing-only casts for role-specific message fields (no re-validation).
_AssistantParser, _ToolParser = (
    partial(cast, ChatCompletionAssistantMessageParam),
    partial(cast, ChatCompletionToolMessageParam),
)


1435
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation message(s).

    String content becomes a single text part and ``None`` becomes no parts.
    Role-specific fields are then copied onto each result:

    * assistant: ``tool_calls`` (when not None), plus ``reasoning`` mirrored
      into ``reasoning_content``.
    * tool: ``tool_call_id``.
    * developer: ``tools``.

    A string ``name`` is copied for any role.
    """
    role = message["role"]
    raw_content = message.get("content")
    reasoning = message.get("reasoning")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=content_format == "openai",
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)
            # Checking for None keeps compatibility with clients that send
            # "tool_calls": null instead of omitting the key.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Carry reasoning through for interleaved-thinking templates.
            if reasoning is not None:
                reasoning_text = cast(str, reasoning)
                result_msg["reasoning"] = reasoning_text
                result_msg["reasoning_content"] = reasoning_text  # legacy key
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1484

1485
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalize assistant tool calls in place for HF chat templates.

    Transformers tool-use templates expect ``function.arguments`` to be a
    dict, whereas OpenAI-format requests carry a JSON string. For each
    assistant message with a list of ``tool_calls``:

    * An empty list is removed, so templates take the plain assistant path.
    * Missing or empty arguments become ``{}``.
    * Arguments that are not already a dict or list are parsed with
      ``json.loads``.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if not tool_calls:
            message.pop("tool_calls", None)
            continue

        for call in tool_calls:
            function = call["function"]
            arguments = function.get("arguments")
            if not arguments:
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1509
1510


1511
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages, loading multimodal items synchronously.

    Returns the post-processed conversation together with the multimodal data
    and UUIDs resolved by a ``MultiModalItemTracker``.
    """
    mm_tracker = MultiModalItemTracker(model_config)

    # The interleave flag is identical for every message, so compute it once.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1542
1543


1544
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Async variant of ``parse_chat_messages``.

    Multimodal items are registered as awaitables on an
    ``AsyncMultiModalItemTracker`` and awaited via ``resolve_items()``.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # The interleave flag is identical for every message, so compute it once.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1575

1576

1577
1578
1579
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls made by assistant messages."""
    assistant_tool_calls = (
        msg.get("tool_calls") for msg in conversation if msg["role"] == "assistant"
    )
    return sum(
        len(list(calls)) for calls in assistant_tool_calls if calls is not None
    )


1586
1587
1588
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool-call id.

    ``"kimi_k2"`` produces ``functions.<func_name>:<idx>``. Any other
    ``id_type`` produces a random ``chatcmpl-tool-<uuid>`` id.
    """
    if id_type != "kimi_k2":
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"