chat_utils.py 55.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from functools import cached_property, lru_cache, partial
11
from itertools import accumulate
12
from pathlib import Path
13
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
14

15
from openai.types.chat import (
16
17
18
19
20
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
21
    ChatCompletionFunctionToolParam,
22
23
24
25
26
27
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
28
from openai.types.chat import (
29
30
31
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
32
from openai.types.responses import ResponseInputImageParam
33
from openai_harmony import Message as OpenAIHarmonyMessage
34
35
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
36

37
# pydantic needs the TypedDict from typing_extensions
38
from typing_extensions import Required, TypedDict
39

40
from vllm import envs
41
from vllm.config import ModelConfig
42
from vllm.logger import init_logger
43
from vllm.model_executor.models import SupportsMultiModal
44
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
49
50
51
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
52
)
53
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
54
from vllm.multimodal.processing import BaseMultiModalProcessor
55
from vllm.utils import random_uuid
56
57
58
59
60
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    # Real imports are only needed by static type checkers.
    import torch
    import transformers
else:
    # At runtime both names are bound to LazyLoader objects instead of the
    # real packages (presumably deferring the import until first use).
    transformers = LazyLoader("transformers", globals(), "transformers")
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)

68

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def __getattr__(name: str):
    """Module-level attribute hook that serves a deprecated alias.

    Looking up ``resolve_hf_chat_template`` on this module emits a
    ``DeprecationWarning`` and returns
    ``vllm.renderers.hf.resolve_chat_template``.

    Raises:
        AttributeError: For any other attribute name.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported on demand, only when the deprecated name is requested.
    from vllm.renderers.hf import resolve_chat_template

    message = (
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return resolve_chat_template


86
87
88
89
90
91
92
93
class ChatTemplateResolutionError(ValueError):
    """Error signalling that a chat template could not be resolved.

    Derives from ValueError so that exception handlers written against
    ValueError keep catching it.
    """


94
95
96
97
98
99
# Maps a modality name to the general multi-modal placeholder string used as
# the key under which the model-specific placeholders are stored.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through an ``audio_url`` object."""

    audio_url: Required[AudioURL]
    """The audio reference."""

    type: Required[Literal["audio_url"]]
    """Discriminator of the content part; always ``"audio_url"``."""


115
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
116
    image_embeds: str | dict[str, str] | None
117
118
119
120
121
122
123
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
124
    uuid: str | None
125
126
127
128
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
129
130


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


147
148
149
150
151
152
153
154
155
156
157
158
159
160
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a ``video_url`` object."""

    video_url: Required[VideoURL]
    """The video reference."""

    type: Required[Literal["video_url"]]
    """Discriminator of the content part; always ``"video_url"``."""


161
162
163
164
class PILImage(BaseModel):
    """
    A PIL.Image.Image object.
    """

    image_pil: Image.Image
    # PIL's Image class is not a type pydantic knows natively, so arbitrary
    # types have to be allowed for the field above to be accepted.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    """The PIL image payload."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
185
186


187
188
189
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
190

191
192
193
194
195
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
196

197
198
    image_url: str | None
    uuid: str | None
199
200
201
202
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
203
204
205
206


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
207

208
209
210
211
212
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
213

214
    audio_url: str | None
215
216


217
218
219
220
221
222
223
224
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
225

226
227
    video_url: str | None
    uuid: str | None
228
229
230
231
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
232
233


Julien Denize's avatar
Julien Denize committed
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


255
256
257
258
259
260
261
262
263
# Union of every content-part shape: the OpenAI SDK part params, the
# audio/video/embeds/simple/thinking variants declared in this module, and a
# bare string. Member order is kept exactly as originally declared.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
270
271
272
273


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    role: Required[str]
    """The role of the message's author."""

    content: str | list[ChatCompletionContentPartParam]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

300

301
302
303
304
305
# Union of the message shapes: the OpenAI SDK message params, the custom-role
# message param declared in this module, and an OpenAI Harmony message.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
306
307


308
# TODO: Make fields ReadOnly once mypy supports it
309
310
311
312
class ConversationMessage(TypedDict, total=False):
    """A single conversation message; only ``role`` is required."""

    role: Required[str]
    """The role of the message's author."""

    content: str | None | list[dict[str, str]]
    """The contents of the message"""

    tool_call_id: str | None
    """Tool call that this message is responding to."""

    name: str | None
    """The name of the function to call"""

    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
    """The tool calls generated by the model, such as function calls."""

    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: The reasoning content for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

334

335
336
337
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]

# Modality tags used when adding items to a multi-modal item tracker
ModalityStr = Literal[
    "image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]

# Type of the items stored by a multi-modal item tracker
_T = TypeVar("_T")


348
349
350
class _BatchedSingleItemField(MultiModalSharedField):
    """Shared-field subclass kept for backward compatibility with single
    item input; it adds nothing to ``MultiModalSharedField``."""
351
352


353
354
355
356
357
358
def _detect_field(
    tensors: list["torch.Tensor"],
    mm_processor: BaseMultiModalProcessor,
):
    """Pick the multi-modal field type that matches a list of tensors.

    Args:
        tensors: Non-empty list with one tensor per multi-modal item.
        mm_processor: Processor whose model config supplies the input
            embedding size used to recognise the legacy batched layout.

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` when there is exactly one
          3-D tensor whose first dimension is 1 and whose last dimension
          equals the embedding size (deprecated batched input; a warning is
          logged).
        - ``MultiModalBatchedField()`` when all tensors share one shape.
        - ``MultiModalFlatField`` otherwise, with one slice per tensor sized
          by the tensor's length along its first dimension.
    """
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Cumulative offsets: item i occupies [slice_idxs[i], slice_idxs[i + 1]).
    slice_idxs = [0, *accumulate(len(tensor) for tensor in tensors)]
    slices = [
        (slice(start, stop),) for start, stop in zip(slice_idxs, slice_idxs[1:])
    ]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dictionaries into one dictionary.

    Args:
        data_items: One dictionary per multi-modal item; all dictionaries
            must have the same keys.
        mm_processor: Processor used to auto-detect the field of each key and
            to look up the fields configured for the merged data.

    Returns:
        A dictionary with the same keys whose values are the per-item tensors
        reduced by the field chosen for that key; ``{}`` for empty input.

    Raises:
        ValueError: If the dictionaries do not all share the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    # Collect the per-item tensors of every key once; they are reused below.
    tensors_by_key = {
        key: [item[key] for item in data_items] for key in first_keys
    }

    fields = {
        key: _detect_field(tensors, mm_processor)
        for key, tensors in tensors_by_key.items()
    }
    data_merged = {
        key: field._reduce_data(tensors_by_key[key], pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        # Re-reduce only the keys where the processor's field differs from the
        # auto-detected one; the legacy batched single-item field is kept.
        keys_to_update = [
            key
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        ]

        for key in keys_to_update:
            data_merged[key] = parsed_fields[key]._reduce_data(
                tensors_by_key[key], pin_memory=False
            )
    except Exception:
        # Deliberate best effort: keep the auto-detected merge on any failure.
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )

    return data_merged
433

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """Turn the embedding items collected for one modality into merged data.

    An empty list, or a list containing only ``None``, is handed back
    unchanged. A list of tensors is merged under the key
    ``f"{modality}_embeds"`` and the merged value for that key is returned.
    A list of dictionaries is merged as-is and the merged dictionary is
    returned.

    Raises:
        NotImplementedError: For any other kind of item list.
    """
    nothing_to_merge = len(data_items) == 0 or all(
        entry is None for entry in data_items
    )
    if nothing_to_merge:
        return data_items

    if is_list_of(data_items, torch.Tensor):
        key = f"{modality}_embeds"
        wrapped = [{key: tensor} for tensor in data_items]
        merged = _merge_embeds(wrapped, mm_processor)
        return merged[key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    raise NotImplementedError(type(data_items))
455
456


457
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Items added so far, grouped by the modality they are stored under
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # Track original modality for each vision_chunk item (image or video)
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Check if model uses unified vision_chunk modality for images/videos."""
        return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        """The model config this tracker was created with."""
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """The model class resolved from the model config (cached)."""
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        """``allowed_local_media_path`` taken from the model config."""
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        """``allowed_media_domains`` taken from the model config."""
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        """The multi-modal registry used to create the processor."""
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """The multi-modal processor created for the model config (cached)."""
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Add a multi-modal item to the current prompt and returns the
        placeholder string to use, if any.

        An optional uuid can be added which serves as a unique identifier of the
        media.
        """
        # Embedding inputs are validated against their base modality
        input_modality = modality.replace("_embeds", "")
        original_modality = modality
        use_vision_chunk = (
            self.use_unified_vision_chunk_modality
            and original_modality in ["video", "image"]
        )

        # If use_unified_vision_chunk_modality is enabled,
        # map image/video to vision_chunk
        if use_vision_chunk:
            # To avoid validation fail
            # because models with use_unified_vision_chunk_modality=True
            # will only accept vision_chunk modality.
            input_modality = "vision_chunk"
            num_items = len(self._items_by_modality[input_modality]) + 1
        else:
            num_items = len(self._items_by_modality[original_modality]) + 1

        mm_config = self.model_config.multimodal_config
        if (
            mm_config is not None
            and mm_config.enable_mm_embeds
            and mm_config.get_limit_per_prompt(input_modality) == 0
            and original_modality.endswith("_embeds")
        ):
            # Skip validation: embeddings bypass limit when enable_mm_embeds=True
            pass
        else:
            self.mm_processor.info.validate_num_items(input_modality, num_items)

        # Track original modality for vision_chunk items
        if use_vision_chunk:
            self._items_by_modality[input_modality].append(item)  # type: ignore
            self._modality_order["vision_chunk"].append(original_modality)
        else:
            self._items_by_modality[original_modality].append(item)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds items into this tracker."""
        raise NotImplementedError


557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    """Convert tracked vision_chunk items into the list of chunks to process.

    Args:
        vision_chunk_items: ``(data, uuid)`` tuples, one per tracked item.
        mm_processor: Processor; its ``split_video_chunks`` method is used
            for videos when it exists.
        vision_chunks_modality_order: Original modality (``"image"`` or
            ``"video"``) of each item, in the same order as the items.

    Returns:
        ``(chunks, uuids)``. ``uuids`` holds the uuid of every input item.
        ``chunks`` holds, per item: a ``VisionChunkImage`` for an image that
        has a ``media`` attribute, one ``VisionChunkVideo`` per piece for a
        video that was split, or the item's data unchanged otherwise. Items
        whose modality is neither image nor video contribute no chunk.
    """
    item_uuids = [item_uuid for _, item_uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    chunks: list[VisionChunk] = []
    next_video_idx = 0
    for kind, (payload, item_uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if kind == "image":
            if hasattr(payload, "media"):
                # Hand over .media directly so it is not converted again later
                chunks.append(
                    VisionChunkImage(
                        type="image",
                        image=payload.media,  # type: ignore[union-attr]
                        uuid=item_uuid,
                    )
                )
            else:
                chunks.append(payload)  # type: ignore[arg-type]
            continue

        if kind != "video":
            continue

        if not hasattr(mm_processor, "split_video_chunks") or payload is None:
            # Nothing to split with (or nothing to split): pass data through
            chunks.append(payload)  # type: ignore[arg-type]
            continue

        try:
            base_uuid = item_uuid or random_uuid()
            # A resolved video may be a (video_data, video_meta) tuple
            if isinstance(payload, tuple) and len(payload) >= 1:
                frames = payload[0]
            else:
                frames = payload
            pieces = mm_processor.split_video_chunks(frames)
            for piece_idx, piece in enumerate(pieces):
                chunks.append(
                    VisionChunkVideo(
                        type="video_chunk",
                        video_chunk=piece["video_chunk"],
                        uuid=f"{base_uuid}-{piece_idx}",
                        video_idx=next_video_idx,
                        prompt=piece["prompt"],
                    )
                )
            # Only videos that were split completely advance the video index
            next_video_idx += 1
        except Exception as e:
            logger.warning("Failed to split video chunks: %s", e)
            chunks.append(payload)  # type: ignore[arg-type]
    return chunks, item_uuids


619
620
621
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
    modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    """Build the multi-modal data and uuid dictionaries from tracked items.

    Args:
        items_by_modality: ``(data, uuid)`` tuples grouped by the modality
            they were tracked under.
        mm_processor: Processor used to merge embedding inputs and to resolve
            vision_chunk items.
        modality_order: Original modality of each tracked item per modality
            key; only the ``"vision_chunk"`` entry is read.

    Returns:
        ``(mm_data, mm_uuids)``. Embedding inputs are stored under their base
        modality (``image_embeds`` under ``"image"``, ``audio_embeds`` under
        ``"audio"``).

    Raises:
        ValueError: If raw and embedding inputs of the same modality are
            both present.
    """
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}
    if "image_embeds" in items_by_modality:
        mm_data["image"] = _get_embeds_data(
            "image",
            [data for data, uuid in items_by_modality["image_embeds"]],
            mm_processor,
        )
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
    if "image" in items_by_modality:
        mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
    if "audio_embeds" in items_by_modality:
        mm_data["audio"] = _get_embeds_data(
            "audio",
            [data for data, uuid in items_by_modality["audio_embeds"]],
            mm_processor,
        )
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
    if "audio" in items_by_modality:
        mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
    if "video" in items_by_modality:
        mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
        mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
    if "vision_chunk" in items_by_modality:
        # Process vision_chunk items - extract from (data, uuid) tuples
        # and convert to VisionChunk types with proper UUID handling
        processed_chunks, vision_chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
        )
        mm_data["vision_chunk"] = processed_chunks
        mm_uuids["vision_chunk"] = vision_chunk_uuids

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Tracker whose items are already-resolved ``(data, uuid)`` tuples."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return the multi-modal data and uuids, or ``(None, None)`` when
        no modality has been recorded."""
        if not self._items_by_modality:
            return None, None

        return _resolve_items(
            dict(self._items_by_modality), self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)


683
684
685
686
687
688
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables yielding ``(data, uuid)`` tuples."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await all tracked items and return the multi-modal data and uuids,
        or ``(None, None)`` when no modality has been recorded."""
        if not self._items_by_modality:
            return None, None

        # Items of one modality are gathered together; the modalities
        # themselves are awaited one after another.
        resolved_items_by_modality = {
            modality: await asyncio.gather(*coros)
            for modality, coros in self._items_by_modality.items()
        }

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Create the asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Abstract parser for multi-modal content parts.

    Holds the storage for the placeholders collected while parsing;
    subclasses implement the ``parse_*`` methods for each kind of part.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record ``placeholder`` under the general placeholder of
        ``modality``; a falsy placeholder is not stored."""
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return the collected placeholders as a plain ``dict`` (shallow copy)."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Handle an image given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle image embeddings given as a string or a dict of strings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Handle an image given as a PIL image."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Handle audio given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Handle audio given as an ``InputAudio`` object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle audio embeddings given as a string or a dict of strings."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

765
766
767
768
769
770

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous content parser.

    Each ``parse_*`` method fetches the media through the media connector
    (when a source is given), adds a ``(data, uuid)`` tuple to the tracker and
    records the placeholder returned by the tracker.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker
        multimodal_config = self._tracker.model_config.multimodal_config
        # multimodal_config may be None, hence getattr with a default
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """The model config of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image (``None`` when no URL is given) and track it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Fetch the image embeddings and track them as ``image_embeds``.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the
                multi-modal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # Same branch structure as parse_audio_embeds, so that `placeholder`
        # is always bound before it is used below.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("image_embeds", (None, uuid))

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Fetch the audio embeddings and track them as ``audio_embeds``.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the
                multi-modal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Track an image that is already a PIL image."""
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio (``None`` when no URL is given) and track it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert ``input_audio`` to a base64 data URL and parse it as audio."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video (``None`` when no URL is given) and track it."""
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)
876

877
878
879
880
881
882

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that hands awaitables to an async item tracker.

    Media that has to be fetched (image / audio / video URLs) is registered as
    a coroutine; data that is available immediately (embeddings, PIL images)
    is registered as an already-resolved future. In both cases the awaitable
    yields a ``(data, uuid)`` pair and the placeholder returned by the tracker
    is recorded via ``_add_placeholder``.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        # `multimodal_config` may lack the attribute (or be None); fall back
        # to None in that case.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """The model configuration held by the underlying tracker."""
        return self._tracker.model_config

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch the image (if a URL is given) and pair it with ``uuid``."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register an image URL; the fetch happens when the coroutine runs."""
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register image embeddings given as a string or a dict of strings.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # Exactly one branch must resolve the future; otherwise whoever awaits
        # it would wait forever.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result((embedding, uuid))
        else:
            # No embedding data (e.g. only a UUID was provided).
            future.set_result((None, uuid))

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register audio embeddings given as a string or a dict of strings.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        # Exactly one branch must resolve the future; otherwise whoever awaits
        # it would wait forever.
        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            future.set_result((embedding, uuid))
        else:
            # No embedding data (e.g. only a UUID was provided).
            future.set_result((None, uuid))

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an in-memory PIL image as an already-resolved future."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        future.set_result((image_pil if image_pil else None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch the audio (if a URL is given) and pair it with ``uuid``."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register an audio URL; the fetch happens when the coroutine runs."""
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an ``input_audio`` payload to a data URL and parse it."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch the video (if a URL is given) and pair it with ``uuid``."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a video URL; the fetch happens when the coroutine runs."""
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)
1025

1026

1027
def validate_chat_template(chat_template: Path | str | None):
1028
1029
1030
1031
1032
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1033
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1034
1035
1036

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1037
1038
1039
1040
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1041
1042
1043
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1044
            )
1045

1046
1047
1048
1049
1050
1051
1052
1053
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1054
    else:
1055
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1056
1057


1058
def _load_chat_template(
1059
    chat_template: Path | str | None,
1060
1061
    *,
    is_literal: bool = False,
1062
) -> str | None:
1063
1064
    if chat_template is None:
        return None
1065
1066
1067

    if is_literal:
        if isinstance(chat_template, Path):
1068
1069
1070
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1071

1072
        return chat_template
1073

1074
    try:
1075
        with open(chat_template) as f:
1076
            return f.read()
1077
    except OSError as e:
1078
1079
1080
        if isinstance(chat_template, Path):
            raise

1081
1082
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1083
1084
1085
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1086
            )
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1100

1101
1102
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1103
1104
1105
1106
1107
1108
1109
        return _load_chat_template(chat_template, is_literal=True)


# Memoized variant of `_load_chat_template`, keyed on its arguments.
# `maxsize=128` is the `lru_cache` default, spelled out here.
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Load a chat template through the memoized loader.

    Takes the same arguments and returns the same value as
    `_load_chat_template`; repeated calls with equal arguments are served
    from the cache.
    """
    return _cached_load_chat_template(chat_template, is_literal=is_literal)
1115
1116


1117
1118
1119
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1120
1121
1122
1123
1124
1125
1126
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1127
# TODO: Let user specify how to insert multimodal tokens into prompt
1128
# (similar to chat template)
1129
1130
1131
1132
1133
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1134
    """Combine multimodal prompts for a multimodal language model."""
1135

1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1153
    # Look through the text prompt to check for missing placeholders
1154
    missing_placeholders: list[str] = []
1155
1156
1157
1158
1159
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1160
1161
1162
1163
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1164
1165
                "when manually placing image placeholders.",
                interleave_strings,
1166
1167
            )
            logger.debug("Input prompt: %s", text_prompt)
1168
1169
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1170
1171
                "actual multimodal data items."
            )
1172

1173
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1174

1175
1176
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1177
1178
1179
1180
    if text_prompt:
        return "\n".join(missing_placeholders + [text_prompt])
    else:
        return "\n".join(missing_placeholders)
1181
1182


1183
1184
# These part types only need a typing cast; no Pydantic validation is repeated.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts that carry URL objects are validated with Pydantic.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Maps each supported part type to a function that extracts that part's
# content; each function returns None when the content key is absent.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    # Textual parts
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "output_text": lambda part: _TextParser(part).get("text"),
    # Image parts
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url"),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    # Audio parts
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url"),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    # Video parts
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url"),
}


def _parse_chat_message_content_mm_part(
1221
1222
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
1223
    """
1224
    Parses a given multi-modal content part based on its type.
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL).

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is found.
    """
    assert isinstance(
1238
1239
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
1240
    part_type = part.get("type", None)
1241
    uuid = part.get("uuid", None)
1242

1243
    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
1244
1245
1246
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
1247
1248
        # We only support 'auto', which is the default
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
1249
            logger.warning(
1250
                "'image_url.detail' is currently not supported and will be ignored."
1251
            )
1252
1253
1254
1255

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
1256
    # 'type' is required field by pydantic
1257
1258
    if part_type is None or uuid is not None:
        if "image_url" in part:
1259
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
1260
1261
1262
1263
1264
1265
1266
1267
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
1268
            image_params = cast(  # type: ignore
1269
1270
1271
1272
1273
1274
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
1275
            image_params = cast(  # type: ignore
1276
1277
1278
1279
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
1280
1281
1282
1283
1284
1285
1286
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
1287
        if "audio_url" in part:
1288
1289
1290
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
1291
1292
1293
1294
1295
1296
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
1297
        if part.get("input_audio") is not None:
1298
            input_audio_params = cast(dict[str, str], part)
1299
            return "input_audio", input_audio_params
1300
        if "video_url" in part:
1301
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
1302
1303
1304
1305
1306
1307
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
1308
1309
1310
1311
1312
1313
1314
1315
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1316
# Part types that are skipped (with a warning) when their parsed content is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1320

1321

1322
1323
1324
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a single-message list.

    With ``wrap_dicts`` the parsed parts are kept as a list of dictionaries.
    Otherwise the text pieces are joined into one string, combined with the
    multimodal placeholders recorded by the parser (if any).
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Falsy results (e.g. skipped parts) are dropped.
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    placeholder_storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(placeholder_storage, texts, interleave_strings)
        if placeholder_storage
        else "\n".join(texts)
    )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1360
1361
1362
1363
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1364
    interleave_strings: bool,
1365
) -> _ContentPart | None:
1366
1367
1368
1369
1370
1371
1372
1373
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1374
        return part
1375
1376
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1377
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1378
    # content is None, log a warning and skip
1379
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1380
        logger.warning(
1381
            "Skipping multimodal part '%s' (type: '%s') "
1382
1383
1384
1385
            "with empty / unparsable content.",
            part,
            part_type,
        )
1386
1387
        return None

1388
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1389
1390
        str_content = cast(str, content)
        if wrap_dicts:
1391
            return {"type": "text", "text": str_content}
1392
1393
        else:
            return str_content
1394

1395
1396
1397
1398
1399
1400
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1401
    modality = None
1402
    if part_type == "image_pil":
1403
        image_content = cast(Image.Image, content) if content is not None else None
1404
        mm_parser.parse_image_pil(image_content, uuid)
1405
        modality = "image"
1406
    elif part_type in ("image_url", "input_image"):
1407
        str_content = cast(str, content)
1408
        mm_parser.parse_image(str_content, uuid)
1409
1410
        modality = "image"
    elif part_type == "image_embeds":
1411
        content = cast(str | dict[str, str], content) if content is not None else None
1412
        mm_parser.parse_image_embeds(content, uuid)
1413
        modality = "image"
1414
1415
1416
1417
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1418
    elif part_type == "audio_url":
1419
        str_content = cast(str, content)
1420
        mm_parser.parse_audio(str_content, uuid)
1421
1422
        modality = "audio"
    elif part_type == "input_audio":
1423
        dict_content = cast(InputAudio, content)
1424
        mm_parser.parse_input_audio(dict_content, uuid)
1425
1426
        modality = "audio"
    elif part_type == "video_url":
1427
        str_content = cast(str, content)
1428
        mm_parser.parse_video(str_content, uuid)
1429
1430
1431
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1432

1433
1434
1435
    return (
        {"type": modality}
        if wrap_dicts
1436
        else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
1437
    )
1438
1439


1440
1441
1442
1443
1444
# These message types only need a typing cast; no Pydantic validation is repeated.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse one chat message into conversation messages.

    The message content is normalised to a list of parts and parsed; then
    role-specific fields are copied onto each resulting message:
    ``tool_calls`` and ``reasoning`` / ``reasoning_content`` for assistant
    messages, ``tool_call_id`` for tool messages, ``name`` when it is a
    string, and ``tools`` for developer messages.
    """
    role = message["role"]
    raw_content = message.get("content")
    reasoning = message.get("reasoning")

    # Normalise the content to a list of parts.
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # 'tool_calls' may be present but None when the caller does not
            # strictly follow the OpenAI spec.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                # keep compatibility
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1494

1495
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
1496
1497
1498
1499
1500
1501
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
        if message["role"] == "assistant" and "tool_calls" in message:
            tool_calls = message.get("tool_calls")
            if not isinstance(tool_calls, list):
                continue

            if len(tool_calls) == 0:
                # Drop empty tool_calls to keep templates on the normal assistant path.
                message.pop("tool_calls", None)
                continue

            for item in tool_calls:
1513
1514
                # if arguments is None or empty string, set to {}
                if content := item["function"].get("arguments"):
1515
1516
                    if not isinstance(content, (dict, list)):
                        item["function"]["arguments"] = json.loads(content)
1517
1518
                else:
                    item["function"]["arguments"] = {}
1519
1520


1521
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages synchronously.

    Returns:
        A tuple of the parsed conversation and the multimodal data and UUIDs
        resolved by the item tracker.
    """
    mm_tracker = MultiModalItemTracker(model_config)

    # Interleaving applies only to the "string" content format, and only when
    # the multimodal config enables it.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for message in messages:
        conversation.extend(
            _parse_chat_message_content(
                message,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1552
1553


1554
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages, awaiting the tracker to resolve multimodal items.

    Returns:
        A tuple of the parsed conversation and the multimodal data and UUIDs
        resolved by the async item tracker.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # Interleaving applies only to the "string" content format, and only when
    # the multimodal config enables it.
    mm_config = model_config.multimodal_config
    interleave_strings = (
        content_format == "string"
        and mm_config is not None
        and mm_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for message in messages:
        conversation.extend(
            _parse_chat_message_content(
                message,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1585

1586

1587
1588
1589
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls made by assistant messages.

    Messages of other roles, and assistant messages whose ``tool_calls`` is
    missing or ``None``, contribute nothing.
    """
    total = 0
    for message in conversation:
        if message["role"] != "assistant":
            continue
        calls = message.get("tool_calls")
        if calls is not None:
            total += len(list(calls))
    return total


1596
1597
1598
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool call id.

    For ``id_type == "kimi_k2"`` the id is ``functions.<func_name>:<idx>``;
    any other ``id_type`` yields ``chatcmpl-tool-`` followed by a random UUID.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"