chat_utils.py 58.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from dataclasses import dataclass
11
from functools import cached_property, lru_cache, partial
12
from itertools import accumulate
13
from pathlib import Path
14
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
15

16
from openai.types.chat import (
17
18
19
20
21
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
22
    ChatCompletionFunctionToolParam,
23
24
25
26
27
28
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
29
from openai.types.chat import (
30
31
32
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
33
from openai.types.responses import ResponseInputImageParam
34
from openai_harmony import Message as OpenAIHarmonyMessage
35
36
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
37

38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypedDict
40

41
from vllm import envs
42
from vllm.config import ModelConfig
43
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
44
from vllm.logger import init_logger
45
from vllm.model_executor.models import SupportsMultiModal
46
from vllm.multimodal import MULTIMODAL_REGISTRY
47
48
49
50
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
51
52
53
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
54
)
55
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
56
from vllm.multimodal.processing import BaseMultiModalProcessor
57
from vllm.utils import random_uuid
58
59
60
61
62
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
63
    import transformers
64
else:
65
    transformers = LazyLoader("transformers", globals(), "transformers")
66
    torch = LazyLoader("torch", globals(), "torch")
67
68
69

logger = init_logger(__name__)

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __getattr__(name: str):
    """Module-level attribute hook that serves a deprecated alias.

    ``resolve_hf_chat_template`` is forwarded to
    ``vllm.renderers.hf.resolve_chat_template`` after emitting a
    ``DeprecationWarning``. Any other missing name raises ``AttributeError``.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported here so the target module is only loaded when the alias is used.
    from vllm.renderers.hf import resolve_chat_template

    message = (
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return resolve_chat_template


88
89
90
91
92
93
94
95
class ChatTemplateResolutionError(ValueError):
    """Signals that a chat template could not be resolved.

    Derives from ``ValueError`` so that handlers already catching
    ``ValueError`` keep working unchanged.
    """


96
97
98
99
100
101
# Generic placeholder token associated with each modality name.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

102

103
104
105
106
107
108
109
110
111
112
113
114
115
116
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


117
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
118
    image_embeds: str | dict[str, str] | None
119
120
121
122
123
124
125
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
126
    uuid: str | None
127
128
129
130
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
131
132


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


149
150
151
152
153
154
155
156
157
158
159
160
161
162
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


163
164
165
166
class PILImage(BaseModel):
    """Pydantic model wrapping a ``PIL.Image.Image`` object."""

    # Lets the non-pydantic ``Image.Image`` type be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Simplified content part whose only payload is a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None

    uuid: str | None
    """
    Caller-supplied UUID of the media. The caller is responsible for
    generating it properly and keeping it unique across different medias.
    """
187
188


189
190
191
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
192

193
194
195
196
197
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
198

199
200
    image_url: str | None
    uuid: str | None
201
202
203
204
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
205
206
207
208


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
209

210
211
212
213
214
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
215

216
    audio_url: str | None
217
218


219
220
221
222
223
224
225
226
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
227

228
229
    video_url: str | None
    uuid: str | None
230
231
232
233
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
234
235


Julien Denize's avatar
Julien Denize committed
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


257
258
259
260
261
262
263
264
265
# Union of every content-part shape accepted inside a chat message: the
# OpenAI-defined parts, the custom parts declared in this module, and plain
# strings. The member order is kept exactly as originally declared.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
272
273
274
275


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
276

277
278
279
    role: Required[str]
    """The role of the message's author."""

280
    content: str | list[ChatCompletionContentPartParam]
281
282
283
284
285
286
287
288
289
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

290
    tool_call_id: str | None
291
292
    """Tool call that this message is responding to."""

293
    tool_calls: list[ChatCompletionMessageToolCallParam] | None
294
295
    """The tool calls generated by the model, such as function calls."""

296
297
298
    reasoning: str | None
    """The reasoning content for interleaved thinking."""

299
300
301
    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

302

303
304
305
306
307
# Any message shape accepted by this module: an OpenAI message param, the
# custom-role message param declared above, or an OpenAI Harmony message.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
308
309


310
# TODO: Make fields ReadOnly once mypy supports it
311
312
313
314
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

315
    content: str | None | list[dict[str, str]]
316
317
    """The contents of the message"""

318
    tool_call_id: str | None
319
320
    """Tool call that this message is responding to."""

321
    name: str | None
322
323
    """The name of the function to call"""

324
    tool_calls: list[ChatCompletionMessageToolCallParam] | None
325
    """The tool calls generated by the model, such as function calls."""
326

327
328
329
330
331
332
    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: The reasoning content for interleaved thinking."""

333
334
335
    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

336

337
338
339
# Content-format option as passed in by the user; includes "auto" in
# addition to the two concrete formats.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

340
341
# Concrete content format left once "auto" has been resolved.
ChatTemplateContentFormat = Literal["string", "openai"]
342

343

Roger Wang's avatar
Roger Wang committed
344
345
346
# Modality keys accepted by the item trackers and content parsers below.
ModalityStr = Literal[
    "image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]
347
348
349
# Type of the items stored by a multi-modal item tracker.
_T = TypeVar("_T")


350
351
352
# Backward compatibility for single item input
class _BatchedSingleItemField(MultiModalSharedField):
    """Marker field type for the deprecated batched single-item embedding input."""
353
354


355
356
357
358
359
360
def _detect_field(
    tensors: list[torch.Tensor],
    mm_processor: BaseMultiModalProcessor,
):
    """Pick the multi-modal field type used to merge ``tensors``.

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` for the deprecated input
          form: exactly one 3-D tensor whose leading dimension is 1 and whose
          trailing dimension equals the model's input-embedding size.
        - ``MultiModalBatchedField()`` when every tensor has the same shape.
        - ``MultiModalFlatField`` otherwise, with one slice per tensor laid
          out back to back along the first dimension.
    """
    head = tensors[0]
    embed_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    is_legacy_batched = (
        len(tensors) == 1
        and head.ndim == 3
        and head.shape[0] == 1
        and head.shape[-1] == embed_size
    )
    if is_legacy_batched:
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    reference_shape = head.shape
    if all(t.shape == reference_shape for t in tensors):
        return MultiModalBatchedField()

    # Cumulative lengths give the start/stop offsets of each tensor in the
    # flattened layout.
    boundaries = [0, *accumulate(len(t) for t in tensors)]
    slices = [
        (slice(start, stop),) for start, stop in zip(boundaries, boundaries[1:])
    ]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )
398

399
400
401
402
403
404
405
406
    fields = {
        key: _detect_field([item[key] for item in data_items], mm_processor)
        for key in first_keys
    }
    data_merged = {
        key: field._reduce_data([item[key] for item in data_items], pin_memory=False)
        for key, field in fields.items()
    }
407

408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        keys_to_update = [
            key
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        ]
423

424
425
426
        for key in keys_to_update:
            data_merged[key] = parsed_fields[key]._reduce_data(
                [item[key] for item in data_items], pin_memory=False
427
            )
428
429
430
431
432
    except Exception:
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )
433

434
    return data_merged
435

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    if len(data_items) == 0:
        return data_items

    if all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    raise NotImplementedError(type(data_items))
457
458


459
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Collects the multi-modal items of one request and checks, as each item
    is added, that the configured per-prompt maximum is not exceeded.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    ):
        super().__init__()

        self._model_config = model_config
        self._media_io_kwargs = media_io_kwargs

        # Items grouped by the modality key they are stored under.
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # Original modality (image or video) of each vision_chunk item,
        # in insertion order.
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Whether the HF config sets ``use_unified_vision_chunk`` (default False)."""
        hf_config = self._model_config.hf_config
        return getattr(hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        """The model configuration this tracker was created with."""
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """Model class resolved from the model configuration."""
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
        """Media IO kwargs: the explicit (non-empty) ones passed to the
        constructor, else those of the multimodal config, else ``None``."""
        if self._media_io_kwargs:
            return self._media_io_kwargs

        mm_config = self._model_config.multimodal_config
        return mm_config.media_io_kwargs if mm_config else None

    @property
    def allowed_local_media_path(self):
        """``allowed_local_media_path`` taken from the model configuration."""
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        """``allowed_media_domains`` taken from the model configuration."""
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        """The multi-modal registry used to create the processor."""
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """Multi-modal processor created by the registry for this model."""
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Register one multi-modal item for the current prompt and return the
        placeholder string to use for it, if any.

        Unless validation is bypassed (see below), the item count is checked
        through the processor before the item is stored.
        """
        use_vision_chunk = (
            self.use_unified_vision_chunk_modality and modality in ["video", "image"]
        )

        if use_vision_chunk:
            # Models with use_unified_vision_chunk_modality=True only accept
            # the vision_chunk modality, so images/videos are stored and
            # validated under that key.
            storage_key = "vision_chunk"
            input_modality = "vision_chunk"
        else:
            storage_key = modality
            input_modality = modality.replace("_embeds", "")

        bucket = self._items_by_modality[storage_key]
        num_items = len(bucket) + 1

        mm_config = self.model_config.multimodal_config
        # Embeddings bypass the limit when enable_mm_embeds=True and the
        # configured limit for the modality is 0.
        skip_validation = (
            mm_config is not None
            and mm_config.enable_mm_embeds
            and mm_config.get_limit_per_prompt(input_modality) == 0
            and modality.endswith("_embeds")
        )
        if not skip_validation:
            self.mm_processor.info.validate_num_items(input_modality, num_items)

        bucket.append(item)  # type: ignore
        if use_vision_chunk:
            # Remember whether this vision_chunk item was an image or a video.
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        """Create the content parser that feeds this tracker."""
        raise NotImplementedError


574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    # Process vision_chunk items - extract from (data, modality) tuples
    # and convert to VisionChunk types with proper UUID handling
    vision_chunks_uuids = [uuid for data, uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    processed_chunks: list[VisionChunk] = []
    video_idx = 0
    for inner_modality, (data, uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if inner_modality == "image":
            # Cast data to proper type for image
            # Use .media (PIL.Image) directly to avoid redundant
            # bytes→PIL conversion in media_processor
            if hasattr(data, "media"):
                image_data = data.media  # type: ignore[union-attr]
                processed_chunks.append(
                    VisionChunkImage(type="image", image=image_data, uuid=uuid)
                )
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
        elif inner_modality == "video":
            # For video, we may need to split into chunks
            # if processor supports it
            # For now, just wrap as a video chunk placeholder
            if hasattr(mm_processor, "split_video_chunks") and data is not None:
                try:
                    video_uuid = uuid or random_uuid()
                    # video await result is (video_data, video_meta) tuple
                    if isinstance(data, tuple) and len(data) >= 1:
                        video_data = data[0]
                    else:
                        video_data = data
                    video_chunks = mm_processor.split_video_chunks(video_data)
                    for i, vc in enumerate(video_chunks):
                        processed_chunks.append(
                            VisionChunkVideo(
                                type="video_chunk",
                                video_chunk=vc["video_chunk"],
                                uuid=f"{video_uuid}-{i}",
                                video_idx=video_idx,
                                prompt=vc["prompt"],
                            )
                        )
                    video_idx += 1
                except Exception as e:
                    logger.warning("Failed to split video chunks: %s", e)
                    processed_chunks.append(data)  # type: ignore[arg-type]
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
    return processed_chunks, vision_chunks_uuids


636
637
638
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
639
    modality_order: dict[str, list[str]],
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}
    if "image_embeds" in items_by_modality:
        mm_data["image"] = _get_embeds_data(
            "image",
            [data for data, uuid in items_by_modality["image_embeds"]],
            mm_processor,
        )
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
    if "image" in items_by_modality:
        mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
    if "audio_embeds" in items_by_modality:
        mm_data["audio"] = _get_embeds_data(
            "audio",
            [data for data, uuid in items_by_modality["audio_embeds"]],
            mm_processor,
        )
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
    if "audio" in items_by_modality:
        mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
    if "video" in items_by_modality:
        mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
        mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
Roger Wang's avatar
Roger Wang committed
671
672
673
    if "vision_chunk" in items_by_modality:
        # Process vision_chunk items - extract from (data, modality) tuples
        # and convert to VisionChunk types with proper UUID handling
674
675
676
677
        processed_chunks, vision_chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
Roger Wang's avatar
Roger Wang committed
678
679
        )
        mm_data["vision_chunk"] = processed_chunks
680
        mm_uuids["vision_chunk"] = vision_chunk_uuids
681
682
683
684
685
686
687
688

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Tracker whose items are ``(data, uuid)`` tuples that are already available."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return the data and UUID dictionaries, or ``(None, None)`` when
        nothing has been tracked."""
        if self._items_by_modality:
            tracked = dict(self._items_by_modality)
            return _resolve_items(tracked, self.mm_processor, self._modality_order)

        return None, None

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        """Create a ``MultiModalContentParser`` bound to this tracker."""
        return MultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs)
700
701


702
703
704
705
706
707
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)`` tuples."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await every tracked item and return the data and UUID dictionaries,
        or ``(None, None)`` when nothing has been tracked.

        Modalities are awaited one after another; the items of a single
        modality are awaited together through ``asyncio.gather``.
        """
        if not self._items_by_modality:
            return None, None

        resolved_items_by_modality = {}
        for modality, pending in self._items_by_modality.items():
            resolved_items_by_modality[modality] = await asyncio.gather(*pending)

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        """Create an ``AsyncMultiModalContentParser`` bound to this tracker."""
        return AsyncMultiModalContentParser(
            self, mm_processor_kwargs=mm_processor_kwargs
        )
726
727
728
729
730
731


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn multi-modal content parts into tracked
    items, recording the model placeholder returned for each of them."""

    def __init__(self) -> None:
        super().__init__()

        # Model placeholders, grouped under the general MM placeholder of
        # their modality, e.g.:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record ``placeholder`` under the general placeholder of ``modality``.

        Empty or missing placeholders are ignored; the modality lookup is
        still performed first.
        """
        general_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[general_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return the recorded placeholders as a plain ``dict``."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Handle an image given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle image embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Handle an image given as a PIL object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Handle audio given by URL."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Handle audio given as an ``input_audio`` object."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Handle audio embeddings."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Handle a video given by URL."""
        raise NotImplementedError

788
789

class MultiModalContentParser(BaseMultiModalContentParser):
    """
    Content parser that fetches media through the media connector while
    parsing and registers every item with a ``MultiModalItemTracker``.
    """

    def __init__(
        self,
        tracker: MultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Args:
            tracker: Tracker receiving the parsed ``(data, uuid)`` items.
            mm_processor_kwargs: Optional processor kwargs; only the
                ``use_audio_in_video`` key is read by this parser.
        """
        super().__init__()

        self._tracker = tracker

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

        self._mm_processor_kwargs = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image (if a URL is given) and track it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Fetch image embeddings and track them.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the multimodal
                config.
            TypeError: If ``image_embeds`` is not a ``str``, ``dict`` or
                ``None``.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # Exactly one branch must run; the previous independent `if`s left
        # `placeholder` unbound for unexpected input types.
        embeds: Any
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
        elif isinstance(image_embeds, str):
            embeds = self._connector.fetch_image_embedding(image_embeds)
        elif image_embeds is None:
            embeds = None
        else:
            raise TypeError(
                "`image_embeds` must be a str, a dict or None, "
                f"got {type(image_embeds).__name__}"
            )

        placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Fetch audio embeddings and track them.

        Any value that is neither a ``dict`` nor a ``str`` is tracked as
        ``None``.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the multimodal
                config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        else:
            placeholder = self._tracker.add("audio_embeds", (None, uuid))

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Track an image that is already a PIL object."""
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio (if a URL is given) and track it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an ``input_audio`` object into a base64 data URL and
        delegate to ``parse_audio``."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video (if a URL is given) and track it.

        When ``use_audio_in_video`` is truthy in the processor kwargs and a
        URL is given, the same URL is also fetched as audio and tracked as an
        audio item with the same UUID.
        """
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            # video_url is known to be truthy here, so no further guard.
            audio = self._connector.fetch_audio(video_url)
            audio_placeholder = self._tracker.add("audio", (audio, uuid))
            self._add_placeholder("audio", audio_placeholder)

914
915

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Multimodal content parser whose media items are resolved asynchronously.

    Every ``parse_*`` method hands the tracker either a coroutine (for media
    fetched through the connector's ``*_async`` methods) or an
    already-completed ``asyncio.Future`` (for data that is available
    immediately), and forwards the placeholder returned by the tracker to
    ``_add_placeholder``.
    """

    def __init__(
        self,
        tracker: AsyncMultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Bind the parser to ``tracker`` and load the media connector.

        Args:
            tracker: Tracker that receives every parsed item; it also
                supplies the media I/O kwargs and the allowed local path /
                domains used to configure the connector.
            mm_processor_kwargs: Optional processor kwargs; only the
                ``use_audio_in_video`` entry is read by this class.
        """
        super().__init__()

        self._tracker = tracker
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )
        self._mm_processor_kwargs: dict[str, Any] | None = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        """Model configuration of the underlying tracker."""
        return self._tracker.model_config

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _require_mm_embeds_enabled(self, field_name: str) -> None:
        """Raise ``ValueError`` unless multimodal embedding inputs are enabled."""
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                f"You must set `--enable-mm-embeds` to input `{field_name}`"
            )

    @staticmethod
    def _load_embeds(
        embeds: str | dict[str, str] | None,
        fetch_embedding: Callable[[str], Any],
        field_name: str,
    ) -> Any:
        """Decode ``embeds`` with ``fetch_embedding``.

        ``None`` is passed through, a string is decoded directly and a dict
        is decoded value by value.

        Raises:
            TypeError: For any other input type. Previously such a value
                left the future unresolved, so awaiting it never finished.
        """
        if embeds is None:
            return None
        if isinstance(embeds, str):
            return fetch_embedding(embeds)
        if isinstance(embeds, dict):
            return {key: fetch_embedding(value) for key, value in embeds.items()}
        raise TypeError(
            f"`{field_name}` must be a string, a dict of strings or None, "
            f"but got {type(embeds).__name__}"
        )

    def _add_embeds(
        self,
        modality: str,
        field_name: str,
        embeds: str | dict[str, str] | None,
        uuid: str | None,
        fetch_embedding: Callable[[str], Any],
    ) -> None:
        """Shared implementation of ``parse_image_embeds``/``parse_audio_embeds``.

        The embeddings are decoded synchronously and wrapped in an
        already-completed future before being handed to the tracker under
        ``field_name``; the placeholder is recorded under ``modality``.
        """
        self._require_mm_embeds_enabled(field_name)
        loaded = self._load_embeds(embeds, fetch_embedding, field_name)

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()
        future.set_result((loaded, uuid))

        placeholder = self._tracker.add(field_name, future)
        self._add_placeholder(modality, placeholder)

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch the image (``None`` for a falsy URL) and pair it with ``uuid``."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch the audio (``None`` for a falsy URL) and pair it with ``uuid``."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch the video (``None`` for a falsy URL) and pair it with ``uuid``."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    # ------------------------------------------------------------------
    # Images
    # ------------------------------------------------------------------

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register a coroutine that fetches the image behind ``image_url``."""
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register pre-computed image embeddings.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
            TypeError: If ``image_embeds`` is not a string, a dict or None.
        """
        self._add_embeds(
            "image",
            "image_embeds",
            image_embeds,
            uuid,
            self._connector.fetch_image_embedding,
        )

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an in-memory PIL image through an already-completed future."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        future.set_result((image_pil if image_pil else None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    # ------------------------------------------------------------------
    # Audio
    # ------------------------------------------------------------------

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register pre-computed audio embeddings.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
            TypeError: If ``audio_embeds`` is not a string, a dict or None.
        """
        self._add_embeds(
            "audio",
            "audio_embeds",
            audio_embeds,
            uuid,
            self._connector.fetch_audio_embedding,
        )

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register a coroutine that fetches the audio behind ``audio_url``."""
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Convert an ``input_audio`` payload to a data URL and parse it.

        If a UUID is provided, the audio data may be empty; in that case
        (or when the payload itself is missing/empty) ``None`` is parsed.
        """
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    # ------------------------------------------------------------------
    # Video
    # ------------------------------------------------------------------

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a coroutine that fetches the video behind ``video_url``.

        When ``use_audio_in_video`` is truthy in the processor kwargs and a
        URL was given, a second coroutine fetching the same URL as audio is
        registered under the ``"audio"`` modality with the same ``uuid``.
        """
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            audio_coro = self._audio_with_uuid_async(video_url, uuid)
            audio_placeholder = self._tracker.add("audio", audio_coro)
            self._add_placeholder("audio", audio_placeholder)

1076

1077
1078
1079
1080
1081
1082
1083
@dataclass
class ChatTemplateConfig:
    """Groups the chat-template related options declared below."""

    # Chat template value; ``None`` (the default) means none was supplied.
    # NOTE(review): presumably the template text or a path to it - confirm
    # against the callers.
    chat_template: str | None = None
    # Content format option for the chat template; defaults to "auto".
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    # Defaults to False.
    # NOTE(review): by its name this controls whether a chat template
    # supplied with a request is trusted - confirm where it is consumed.
    trust_request_chat_template: bool = False


1084
def validate_chat_template(chat_template: Path | str | None):
1085
1086
1087
1088
1089
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1090
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1091
1092
1093

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1094
1095
1096
1097
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1098
1099
1100
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1101
            )
1102

1103
1104
1105
1106
1107
1108
1109
1110
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1111
    else:
1112
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1113
1114


1115
def _load_chat_template(
1116
    chat_template: Path | str | None,
1117
1118
    *,
    is_literal: bool = False,
1119
) -> str | None:
1120
1121
    if chat_template is None:
        return None
1122
1123
1124

    if is_literal:
        if isinstance(chat_template, Path):
1125
1126
1127
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1128

1129
        return chat_template
1130

1131
    try:
1132
        with open(chat_template) as f:
1133
            return f.read()
1134
    except OSError as e:
1135
1136
1137
        if isinstance(chat_template, Path):
            raise

1138
1139
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1140
1141
1142
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1143
            )
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1157

1158
1159
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1160
1161
1162
1163
1164
1165
1166
        return _load_chat_template(chat_template, is_literal=True)


# Memoised variant of ``_load_chat_template`` (default ``lru_cache`` settings).
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Cached front end for ``_load_chat_template``.

    Repeated calls with equal arguments are answered from the cache, so a
    template file that was already read is not read again.
    """
    template_text = _cached_load_chat_template(chat_template, is_literal=is_literal)
    return template_text
1172
1173


1174
1175
1176
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
    """Swap placeholder keys found in ``texts`` for their stored placeholders.

    Each element of ``texts`` that is a key of ``placeholder_storage`` is
    replaced by the next (first remaining) entry of that key's list. Both
    ``texts`` and the lists in ``placeholder_storage`` are modified in
    place. The resulting elements are joined with newlines.
    """
    for slot, fragment in enumerate(texts):
        if fragment not in placeholder_storage:
            continue
        texts[slot] = placeholder_storage[fragment].pop(0)

    return "\n".join(texts)


1184
# TODO: Let user specify how to insert multimodal tokens into prompt
1185
# (similar to chat template)
1186
1187
1188
1189
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
    multimodal_content_part_separator: str = "\n",
) -> str:
    """Combine multimodal prompts for a multimodal language model.

    The text parts are joined with newlines (after interleaving the stored
    placeholders when ``interleave_strings`` is set). Every placeholder
    that is still missing from that text is then prepended, using
    ``multimodal_content_part_separator`` between the prepended
    placeholders and the text.

    Raises:
        ValueError: If the text contains more occurrences of a placeholder
            than there are multimodal items for it.
    """
    # Tally the stored placeholders *before* interleaving (which consumes
    # the storage), e.g. {"<|image|>": 2, "<|audio|>": 1}.
    expected_counts = Counter(
        placeholder
        for group in placeholder_storage.values()
        for placeholder in group
    )

    # The interleaved text is checked as well, in case the user placed
    # image placeholders manually but forgot to disable 'interleave_strings'.
    text_prompt = (
        _get_interleaved_text_prompt(placeholder_storage, texts)
        if interleave_strings
        else "\n".join(texts)
    )

    # Work out which placeholders the text prompt does not contain yet;
    # placeholders already present in the text are left where they are.
    missing_placeholders: list[str] = []
    for placeholder, expected in expected_counts.items():
        deficit = expected - text_prompt.count(placeholder)

        if deficit < 0:
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
                "when manually placing image placeholders.",
                interleave_strings,
            )
            logger.debug("Input prompt: %s", text_prompt)
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items."
            )

        missing_placeholders.extend([placeholder] * deficit)

    # NOTE: Default behaviour: missing placeholders always go to the front
    # of the prompt.
    pieces = missing_placeholders + [text_prompt] if text_prompt else missing_placeholders
    return multimodal_content_part_separator.join(pieces)
1241
1242


1243
1244
# Part types below are only re-labelled for the type checker via
# ``typing.cast``; there is no need to validate them using Pydantic again.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Parts that carry url objects do need validation, so they go through a
# Pydantic ``TypeAdapter``.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python

# Value produced by parsing a single content part.
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage
1258

1259
# Maps each supported part type to a function that extracts the payload of
# such a part: the text, the (nested) URL, or the raw media/embedding value.
# A missing field yields ``None``.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text"),
    "thinking": lambda part: _ThinkParser(part).get("thinking"),
    "input_text": lambda part: _TextParser(part).get("text"),
    "output_text": lambda part: _TextParser(part).get("text"),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url"),
    "image_url": lambda part: (
        _ImageParser(part).get("image_url", {}).get("url")
    ),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds"),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds"),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil"),
    "audio_url": lambda part: (
        _AudioParser(part).get("audio_url", {}).get("url")
    ),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio"),
    "refusal": lambda part: _RefusalParser(part).get("refusal"),
    "video_url": lambda part: (
        _VideoParser(part).get("video_url", {}).get("url")
    ),
}


def _unwrap_url_field(value):
    """Return ``value["url"]`` for dict values, otherwise ``value`` itself.

    A dict can potentially show up when the user provides a uuid together
    with the url wrapped as ``{"url": url}``.
    """
    if isinstance(value, dict):
        return value.get("url", None)
    return value


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    A part with a string 'type' known to ``MM_PARSER_MAP`` and without a
    'uuid' is parsed by the mapped parser. A part without a 'type', or
    with a 'uuid', is instead matched against the recognised media fields
    (checked in this order: image_url, image_pil, image_embeds,
    audio_embeds, audio_url, input_audio, video_url).

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For a string
          'type' that is not handled by either path above, a fixed marker
          string is returned as content.

    Raises:
        ValueError: If the part has no 'type' (or has a 'uuid') and none of
            the recognised media fields is present, or if 'type' is
            present but is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default
        # NOTE(review): this reads a top-level 'detail' key of the part; a
        # 'detail' nested inside the 'image_url' dict is not inspected -
        # confirm which location is intended.
        if part_type == "image_url" and part.get("detail", "auto") != "auto":
            logger.warning(
                "'image_url.detail' is currently not supported and will be ignored."
            )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            return "image_url", _unwrap_url_field(image_params.get("image_url", None))
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            return "image_pil", image_params.get("image_pil", None)
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            return "image_embeds", image_params.get("image_embeds", None)
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            return "audio_embeds", audio_params.get("audio_embeds", None)
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            return "audio_url", _unwrap_url_field(audio_params.get("audio_url", None))
        input_audio = part.get("input_audio")
        if input_audio is not None:
            # Return the nested payload (as the MM_PARSER_MAP path does),
            # not the enclosing part: the consumer reads 'data'/'format'
            # from this value, so passing the whole part dropped the audio.
            return "input_audio", cast(dict[str, str], input_audio)
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            return "video_url", _unwrap_url_field(video_params.get("video_url", None))
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1376
# Part types that are skipped (with a warning) when their parsed content
# is None.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1380

1381

1382
1383
1384
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
    multimodal_content_part_separator="\n",
) -> list[ConversationMessage]:
    """Parse all content parts of one message into a single-message list.

    Each part is parsed with a parser created from ``mm_tracker``; parts
    whose parse result is falsy are dropped. With ``wrap_dicts`` the parsed
    parts become the message content directly. Otherwise the text parts are
    merged into one prompt string, combined with the multimodal
    placeholders when the parser recorded any.
    """
    mm_parser = mm_tracker.create_parser(mm_processor_kwargs=mm_processor_kwargs)

    parsed_parts: list[_ContentPart] = [
        parsed
        for part in parts
        if (
            parsed := _parse_chat_message_content_part(
                part,
                mm_parser,
                wrap_dicts=wrap_dicts,
                interleave_strings=interleave_strings,
            )
        )
    ]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=parsed_parts)]  # type: ignore

    texts = cast(list[str], parsed_parts)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if not mm_placeholder_storage:
        text_prompt = "\n".join(texts)
    else:
        text_prompt = _get_full_multimodal_text_prompt(
            mm_placeholder_storage,
            texts,
            interleave_strings,
            multimodal_content_part_separator=multimodal_content_part_separator,
        )

    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1425
1426
1427
1428
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1429
    interleave_strings: bool,
1430
) -> _ContentPart | None:
1431
1432
1433
1434
1435
1436
1437
1438
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1439
1440
        if wrap_dicts:
            return {"type": "text", "text": part}
1441
        return part
1442
1443
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1444
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1445
    # content is None, log a warning and skip
1446
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1447
        logger.warning(
1448
            "Skipping multimodal part '%s' (type: '%s') "
1449
1450
1451
1452
            "with empty / unparsable content.",
            part,
            part_type,
        )
1453
1454
        return None

1455
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1456
1457
        str_content = cast(str, content)
        if wrap_dicts:
1458
            return {"type": "text", "text": str_content}
1459
1460
        else:
            return str_content
1461

1462
1463
1464
1465
1466
1467
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1468
    modality = None
1469
    if part_type == "image_pil":
1470
        image_content = cast(Image.Image, content) if content is not None else None
1471
        mm_parser.parse_image_pil(image_content, uuid)
1472
        modality = "image"
1473
    elif part_type in ("image_url", "input_image"):
1474
        str_content = cast(str, content)
1475
        mm_parser.parse_image(str_content, uuid)
1476
1477
        modality = "image"
    elif part_type == "image_embeds":
1478
        content = cast(str | dict[str, str], content) if content is not None else None
1479
        mm_parser.parse_image_embeds(content, uuid)
1480
        modality = "image"
1481
1482
1483
1484
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1485
    elif part_type == "audio_url":
1486
        str_content = cast(str, content)
1487
        mm_parser.parse_audio(str_content, uuid)
1488
1489
        modality = "audio"
    elif part_type == "input_audio":
1490
        dict_content = cast(InputAudio, content)
1491
        mm_parser.parse_input_audio(dict_content, uuid)
1492
1493
        modality = "audio"
    elif part_type == "video_url":
1494
        str_content = cast(str, content)
1495
        mm_parser.parse_video(str_content, uuid)
1496
1497
1498
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1499

1500
1501
1502
    if wrap_dicts:
        return {"type": modality}
    return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
1503
1504


1505
1506
1507
1508
1509
# No need to validate using Pydantic again: these are plain ``typing.cast``
# wrappers that only re-label a message as an assistant / tool message for
# the type checker and return it unchanged at runtime.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1510
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]:
    """Convert one chat message into conversation message(s).

    The content (missing, a plain string, or a list of parts) is parsed
    into parts; dictionaries are kept for the "openai" content format.
    Role-specific fields are then copied onto each result: ``tool_calls``
    and reasoning for assistant messages, ``tool_call_id`` for tool
    messages, ``tools`` for developer messages, plus ``name`` for any role
    when it is a string.
    """
    role = message["role"]
    content = message.get("content")
    reasoning = message.get("reasoning")

    # Normalise the content into a list of parts.
    if content is None:
        parts = []
    elif isinstance(content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=content)]
    else:
        parts = content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=(content_format == "openai"),
        interleave_strings=interleave_strings,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The 'tool_calls' is not None check ensures compatibility.
            # It's needed only if downstream code doesn't strictly
            # follow the OpenAI spec.
            if (
                "tool_calls" in assistant_msg
                and assistant_msg["tool_calls"] is not None
            ):
                result_msg["tool_calls"] = list(assistant_msg["tool_calls"])

            # Include reasoning if present for interleaved thinking;
            # 'reasoning_content' is also set to keep compatibility.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1561

1562
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalise ``tool_calls`` of assistant messages in place.

    Per the Transformers docs & maintainers, tool call arguments in
    assistant-role messages need to be dicts rather than JSON strings -
    this is how tool-use chat templates expect them. Hence, for assistant
    messages whose ``tool_calls`` is a list:

    - an empty list is dropped from the message, keeping templates on the
      normal assistant path;
    - arguments that are missing, ``None`` or empty become ``{}``;
    - arguments that are neither dict nor list are parsed with
      ``json.loads``.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if not tool_calls:
            message.pop("tool_calls", None)
            continue

        for call in tool_calls:
            raw_arguments = call["function"].get("arguments")
            if not raw_arguments:
                call["function"]["arguments"] = {}
            elif not isinstance(raw_arguments, (dict, list)):
                call["function"]["arguments"] = json.loads(raw_arguments)
1586
1587


1588
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages into a conversation plus resolved multimodal items.

    Returns:
        A tuple ``(conversation, mm_data, mm_uuids)``: the parsed and
        post-processed conversation messages, followed by the two values
        produced by the tracker's ``resolve_items()``.
    """
    mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)

    # Interleaving only applies to the "string" content format and only
    # when the multimodal config asks for it; the flag is identical for
    # every message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
                mm_processor_kwargs=mm_processor_kwargs,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1622
1623


1624
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Asynchronous counterpart of ``parse_chat_messages``.

    Messages are parsed with an ``AsyncMultiModalItemTracker`` whose
    ``resolve_items()`` is awaited once all messages have been parsed.

    Returns:
        A tuple ``(conversation, mm_data, mm_uuids)``: the parsed and
        post-processed conversation messages, followed by the two values
        produced by awaiting the tracker's ``resolve_items()``.
    """
    mm_tracker = AsyncMultiModalItemTracker(
        model_config, media_io_kwargs=media_io_kwargs
    )

    # Interleaving only applies to the "string" content format and only
    # when the multimodal config asks for it; the flag is identical for
    # every message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
                mm_processor_kwargs=mm_processor_kwargs,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1660

1661

1662
1663
1664
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Count the tool calls made by assistant messages in ``conversation``.

    Only messages with role "assistant" whose ``tool_calls`` is not
    ``None`` contribute; each contributes the number of its tool calls.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is None:
            continue
        # ``tool_calls`` is materialised before counting, so iterables
        # without ``len()`` are supported as well.
        total += len(list(tool_calls))
    return total


1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
# Model types that use the "kimi_k2" tool-call ID type.
_KIMI_MODEL_TYPES = ("kimi_k2", "kimi_k25")


def get_tool_call_id_type(model_config: ModelConfig) -> str:
    """Return the tool-call ID type for a given model configuration.

    The result is "kimi_k2" when the model type of the HF text config, or
    the "model_type" entry of dict-valued ``hf_overrides``, is one of
    ``_KIMI_MODEL_TYPES``; otherwise it is "random".
    """
    hf_overrides = getattr(model_config, "hf_overrides", None)

    if model_config.hf_text_config.model_type in _KIMI_MODEL_TYPES:
        return "kimi_k2"

    # ``hf_overrides`` is only consulted when it is a dict.
    if (
        isinstance(hf_overrides, dict)
        and hf_overrides.get("model_type") in _KIMI_MODEL_TYPES
    ):
        return "kimi_k2"

    return "random"


1685
1686
1687
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool-call ID of the requested type.

    For ``id_type == "kimi_k2"`` the ID is ``functions.<func_name>:<idx>``;
    any other ``id_type`` yields a random ``chatcmpl-tool-...`` ID.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"