chat_utils.py 59.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from dataclasses import dataclass
11
from functools import cached_property, lru_cache, partial
12
from itertools import accumulate
13
from pathlib import Path
14
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
15

16
from openai.types.chat import (
17
18
19
20
21
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
22
    ChatCompletionFunctionToolParam,
23
24
25
26
27
28
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
29
from openai.types.chat import (
30
31
32
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
33
from openai.types.responses import ResponseInputImageParam
34
from openai_harmony import Message as OpenAIHarmonyMessage
35
36
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
37

38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypedDict
40

41
from vllm import envs
42
from vllm.config import ModelConfig
43
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
44
from vllm.logger import init_logger
45
from vllm.model_executor.models import SupportsMultiModal
46
from vllm.multimodal import MULTIMODAL_REGISTRY
47
48
49
50
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
51
52
53
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
54
)
55
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
56
from vllm.multimodal.processing import BaseMultiModalProcessor
57
from vllm.utils import random_uuid
58
59
60
61
62
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
63
    import transformers
64
else:
65
    transformers = LazyLoader("transformers", globals(), "transformers")
66
    torch = LazyLoader("torch", globals(), "torch")
67
68
69

# Module-level logger named after this module.
logger = init_logger(__name__)

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) serving deprecated aliases.

    Only ``resolve_hf_chat_template`` is handled: it emits a
    ``DeprecationWarning`` and returns the relocated function. Any other name
    raises ``AttributeError`` as a normal missing attribute would.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported lazily so the renderer module is only loaded on demand.
    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template


88
89
90
91
92
93
94
95
class ChatTemplateResolutionError(ValueError):
    """Signals that resolving a chat template did not succeed.

    It derives from ``ValueError`` so that pre-existing ``except ValueError``
    handlers continue to catch it.
    """


96
97
98
99
100
101
# Generic (model-agnostic) placeholder string for each raw media modality.
MODALITY_PLACEHOLDERS_MAP = dict(
    image="<##IMAGE##>",
    audio="<##AUDIO##>",
    video="<##VIDEO##>",
)

102

103
104
105
106
107
108
109
110
111
112
113
114
115
116
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


117
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
118
    image_embeds: str | dict[str, str] | None
119
120
121
122
123
124
125
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
126
    uuid: str | None
127
128
129
130
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
131
132


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


149
150
151
152
153
154
155
156
157
158
159
160
161
162
class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


163
164
165
166
class PILImage(BaseModel):
    """Pydantic model that wraps a single ``PIL.Image.Image`` object."""

    # PIL images are not pydantic-native types, so allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image_pil: Image.Image


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """Shorthand image content part that carries only a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    Optional caller-supplied identifier of the media. The caller is responsible
    for generating it properly and keeping it unique across distinct media.
    """
187
188


189
190
191
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
192

193
194
195
196
197
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
198

199
200
    image_url: str | None
    uuid: str | None
201
202
203
204
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
205
206
207
208


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
209

210
211
212
213
214
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
215

216
    audio_url: str | None
217
218


219
220
221
222
223
224
225
226
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
227

228
229
    video_url: str | None
    uuid: str | None
230
231
232
233
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
234
235


Julien Denize's avatar
Julien Denize committed
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """A Think Completion Content Param that accepts a plain text and a boolean.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


257
258
259
260
261
262
263
264
265
# Union of every content-part shape defined or imported above: the OpenAI
# content parts plus vLLM's extensions (audio/video URLs, embeddings, PIL
# images, URL-only shorthands, plain strings and thinking segments).
# NOTE(review): member order is kept unchanged on purpose, since union
# validation may depend on it.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
272
273
274
275


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Chat Completion message that allows arbitrary (custom) roles."""

    role: Required[str]
    """Role of the author of this message."""

    content: str | list[ChatCompletionContentPartParam]
    """Body of the message: plain text or a list of content parts."""

    name: str
    """Optional participant name.

    Lets the model tell apart participants that share the same role.
    """

    tool_call_id: str | None
    """Identifier of the tool call this message answers."""

    tool_calls: list[ChatCompletionMessageToolCallParam] | None
    """Tool calls (e.g. function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions supplied through the developer role."""

    task: str | None
    """Model-specific task marker; currently only passed through for DeepSeek V4."""

305

306
307
308
309
310
# Union of the message types handled by this module: OpenAI chat messages,
# vLLM's custom-role message, and openai_harmony messages.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
311
312


313
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    """Internal representation of a single chat conversation message."""

    role: Required[str]
    """Role of the author of this message."""

    content: str | None | list[dict[str, str]]
    """Body of the message."""

    tool_call_id: str | None
    """Identifier of the tool call this message answers."""

    name: str | None
    """Name of the function to call."""

    tool_calls: list[ChatCompletionMessageToolCallParam] | None
    """Tool calls (e.g. function calls) produced by the model."""

    reasoning: str | None
    """Reasoning text used for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: reasoning text used for interleaved thinking."""

    tools: list[ChatCompletionFunctionToolParam] | None
    """Tool definitions supplied through the developer role."""

    task: str | None
    """Model-specific task marker; currently only passed through for DeepSeek V4."""

342

343
344
345
# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]

# Passed in by user: one of the concrete formats above, or "auto".
# (Nested Literal types flatten, so this equals Literal["auto", "string", "openai"].)
ChatTemplateContentFormatOption = Literal["auto", ChatTemplateContentFormat]
348

349

Roger Wang's avatar
Roger Wang committed
350
351
352
# Modality keys accepted by the multi-modal item trackers; the "*_embeds"
# variants carry precomputed embeddings rather than raw media.
ModalityStr = Literal[
    "image",
    "audio",
    "video",
    "image_embeds",
    "audio_embeds",
    "vision_chunk",
]

# Item type stored by a multi-modal item tracker.
_T = TypeVar("_T")


356
357
358
class _BatchedSingleItemField(MultiModalSharedField):
    """Backward-compatibility marker for a single embedding input that was
    passed with a leading batch dimension of size 1."""
359
360


361
362
363
364
365
366
def _detect_field(
    tensors: list[torch.Tensor],
    mm_processor: BaseMultiModalProcessor,
):
    """Choose how a list of per-item embedding tensors should be batched.

    Returns:
        - ``_BatchedSingleItemField`` for the deprecated single-item input that
          still has a leading batch dimension of 1 (3-D, last dim == hidden size);
        - ``MultiModalBatchedField`` when every tensor has the same shape;
        - otherwise a ``MultiModalFlatField`` that slices the concatenation of
          the tensors along their first dimension.
    """
    head = tensors[0]
    embed_dim = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    is_legacy_batched_item = (
        len(tensors) == 1
        and head.ndim == 3
        and head.shape[0] == 1
        and head.shape[-1] == embed_dim
    )
    if is_legacy_batched_item:
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    reference_shape = head.shape
    if all(t.shape == reference_shape for t in tensors):
        return MultiModalBatchedField()

    # Cumulative offsets of each item inside the flattened first dimension.
    bounds = [0, *accumulate(len(t) for t in tensors)]
    return MultiModalFlatField(
        slices=[(slice(lo, hi),) for lo, hi in zip(bounds, bounds[1:])]
    )


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dicts into one dict of batched tensors.

    The batching layout of each key is first auto-detected with
    ``_detect_field``. The model processor's own field config is then
    consulted. Keys whose processor-declared field differs from the detected
    one (other than the legacy single-item case) are re-reduced with the
    processor's field.

    Args:
        data_items: One dict per multi-modal item; all must share the same keys.
        mm_processor: Processor of the target model.

    Returns:
        A dict mapping each key to its merged data (``{}`` for no items).

    Raises:
        ValueError: If the dicts do not all have the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    def _values(key: str) -> list:
        # Per-item values for one key, in item order.
        return [item[key] for item in data_items]

    fields = {key: _detect_field(_values(key), mm_processor) for key in first_keys}
    data_merged = {
        key: field._reduce_data(_values(key), pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        # Build every update before touching data_merged, so a failure on a
        # later key cannot leave a mix of re-reduced and auto-detected data
        # behind while the log claims a full fallback.
        updates = {
            key: parsed_fields[key]._reduce_data(_values(key), pin_memory=False)
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        }
    except Exception:
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )
    else:
        data_merged.update(updates)

    return data_merged
441

442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge the embedding payloads tracked for one modality.

    Args:
        modality: Base modality name (e.g. ``"image"``). Bare tensors are
            merged under the ``f"{modality}_embeds"`` key.
        data_items: One entry per tracked item — presumably a tensor, a dict
            of tensors, or ``None`` for uuid-only items (verify against the
            content parsers).
        mm_processor: Processor used to decide how tensors are batched.

    Returns:
        ``data_items`` itself when it is empty or contains only ``None``;
        otherwise the merged data produced by ``_merge_embeds``.

    Raises:
        NotImplementedError: For any other combination of item types.
    """
    if len(data_items) == 0:
        return data_items

    if all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    # Report the element types: the container type is always ``list`` here,
    # so reporting it (as before) gave no diagnostic value.
    item_types = sorted({type(item).__name__ for item in data_items})
    raise NotImplementedError(
        f"Unsupported {modality} embedding item types: {item_types}"
    )
463
464


465
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """Collects the multi-modal items of one request, grouped by modality.

    Every :meth:`add` checks the resulting item count for the modality against
    the processor's per-prompt limit before the item is stored. Concrete
    subclasses pick the stored item type ``_T`` and provide ``create_parser``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    ):
        super().__init__()

        self._model_config = model_config
        self._media_io_kwargs = media_io_kwargs

        # Stored items, keyed by the modality they are filed under.
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # For the unified "vision_chunk" modality: the original modality
        # ("image" or "video") of each stored item, in insertion order.
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Whether the HF config opts into the unified ``vision_chunk`` modality."""
        hf_config = self._model_config.hf_config
        return getattr(hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        from vllm.model_executor.model_loader import get_model_cls

        return cast(type[SupportsMultiModal], get_model_cls(self.model_config))

    @property
    def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
        """Explicit media IO kwargs when given (and non-empty), otherwise the
        ones from the model's multimodal config, otherwise ``None``."""
        if self._media_io_kwargs:
            return self._media_io_kwargs
        mm_config = self._model_config.multimodal_config
        return mm_config.media_io_kwargs if mm_config else None

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """Store ``item`` under ``modality`` and return its placeholder string.

        ``item`` is whatever the concrete tracker stores (e.g. a
        ``(data, uuid)`` pair), where the optional uuid identifies the media.
        The return value is the model's placeholder for this item, if any.
        """
        use_vision_chunk = (
            self.use_unified_vision_chunk_modality
            and modality in ["video", "image"]
        )

        # Models with the unified modality only accept "vision_chunk", so
        # both validation and storage use that key for images and videos.
        if use_vision_chunk:
            input_modality = "vision_chunk"
            bucket = "vision_chunk"
        else:
            input_modality = modality.replace("_embeds", "")
            bucket = modality
        num_items = len(self._items_by_modality[bucket]) + 1

        # With --enable-mm-embeds, embedding inputs may bypass a zero limit.
        mm_config = self.model_config.multimodal_config
        embeds_bypass_limit = (
            mm_config is not None
            and mm_config.enable_mm_embeds
            and mm_config.get_limit_per_prompt(input_modality) == 0
            and modality.endswith("_embeds")
        )
        if not embeds_bypass_limit:
            self.mm_processor.info.validate_num_items(input_modality, num_items)

        self._items_by_modality[bucket].append(item)
        if use_vision_chunk:
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        raise NotImplementedError


580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    """Convert tracked ``vision_chunk`` items into ``VisionChunk`` entries.

    Args:
        vision_chunk_items: ``(data, uuid)`` pairs in insertion order.
        mm_processor: Processor of the target model. If it provides
            ``split_video_chunks``, each video is expanded into per-chunk
            ``VisionChunkVideo`` entries.
        vision_chunks_modality_order: Original modality ("image" or "video")
            of each item; must be parallel to ``vision_chunk_items``.

    Returns:
        ``(processed_chunks, uuids)``. ``uuids`` has one entry per input item,
        not per generated chunk.
        NOTE(review): a split video makes ``processed_chunks`` longer than
        ``uuids``; confirm downstream expects this.
    """
    vision_chunks_uuids = [uuid for _, uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    processed_chunks: list[VisionChunk] = []
    video_idx = 0
    for inner_modality, (data, uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if inner_modality == "image":
            # Use .media (PIL.Image) directly to avoid a redundant
            # bytes -> PIL conversion in the media processor.
            if hasattr(data, "media"):
                processed_chunks.append(
                    VisionChunkImage(type="image", image=data.media, uuid=uuid)  # type: ignore[union-attr]
                )
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
        elif inner_modality == "video":
            if not hasattr(mm_processor, "split_video_chunks") or data is None:
                processed_chunks.append(data)  # type: ignore[arg-type]
                continue
            try:
                video_uuid = uuid or random_uuid()
                # An awaited video is a (video_data, video_meta) tuple.
                if isinstance(data, tuple) and len(data) >= 1:
                    video_data = data[0]
                else:
                    video_data = data
                # Build all chunks first so that a failure part-way through
                # cannot leave partial chunks next to the raw fallback below.
                video_chunks = [
                    VisionChunkVideo(
                        type="video_chunk",
                        video_chunk=vc["video_chunk"],
                        uuid=f"{video_uuid}-{i}",
                        video_idx=video_idx,
                        prompt=vc["prompt"],
                    )
                    for i, vc in enumerate(mm_processor.split_video_chunks(video_data))
                ]
            except Exception as e:
                logger.warning("Failed to split video chunks: %s", e)
                processed_chunks.append(data)  # type: ignore[arg-type]
            else:
                processed_chunks.extend(video_chunks)
                video_idx += 1
    return processed_chunks, vision_chunks_uuids


642
643
644
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
    modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    """Split tracked ``(data, uuid)`` pairs into parallel data and uuid dicts.

    ``*_embeds`` items are merged with ``_get_embeds_data`` and stored under
    the base modality key. ``vision_chunk`` items go through
    ``_resolve_vision_chunk_items``.

    Raises:
        ValueError: If raw and embedding inputs are mixed for image or audio.
    """
    for base in ("image", "audio"):
        if base in items_by_modality and f"{base}_embeds" in items_by_modality:
            raise ValueError(
                f"Mixing raw {base} and embedding inputs is not allowed"
            )

    mm_data = {}
    mm_uuids = {}

    for base in ("image", "audio"):
        embeds_key = f"{base}_embeds"
        if embeds_key in items_by_modality:
            pairs = items_by_modality[embeds_key]
            mm_data[base] = _get_embeds_data(
                base, [data for data, _ in pairs], mm_processor
            )
            mm_uuids[base] = [uuid for _, uuid in pairs]
        if base in items_by_modality:
            pairs = items_by_modality[base]
            mm_data[base] = [data for data, _ in pairs]
            mm_uuids[base] = [uuid for _, uuid in pairs]

    if "video" in items_by_modality:
        pairs = items_by_modality["video"]
        mm_data["video"] = [data for data, _ in pairs]
        mm_uuids["video"] = [uuid for _, uuid in pairs]

    if "vision_chunk" in items_by_modality:
        # The items are (data, uuid) pairs; the per-item original modality
        # comes from modality_order.
        chunks, chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
        )
        mm_data["vision_chunk"] = chunks
        mm_uuids["vision_chunk"] = chunk_uuids

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Tracker whose items are already-loaded ``(data, uuid)`` pairs."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return the tracked data/uuid dicts, or ``(None, None)`` if empty."""
        if self._items_by_modality:
            return _resolve_items(
                dict(self._items_by_modality), self.mm_processor, self._modality_order
            )
        return None, None

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs)
706
707


708
709
710
711
712
713
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables that resolve to ``(data, uuid)``."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await all tracked items and return the data/uuid dicts.

        Returns ``(None, None)`` when nothing was tracked.
        """
        if not self._items_by_modality:
            return None, None

        # Await every modality's items concurrently. Before, each modality
        # was awaited one after another, so e.g. audio fetches only began
        # after all image fetches had finished.
        modalities = list(self._items_by_modality)
        resolved_per_modality = await asyncio.gather(
            *(asyncio.gather(*self._items_by_modality[m]) for m in modalities)
        )
        resolved_items_by_modality = dict(zip(modalities, resolved_per_modality))

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(
        self, mm_processor_kwargs: dict[str, Any] | None = None
    ) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(
            self, mm_processor_kwargs=mm_processor_kwargs
        )
732
733
734
735
736
737


class BaseMultiModalContentParser(ABC):
    """Interface for parsers that turn multi-modal content parts into tracked
    items, while recording the model placeholder emitted for each item."""

    def __init__(self) -> None:
        super().__init__()

        # Generic placeholder -> model placeholders recorded so far, e.g.
        # {"<##IMAGE##>": ["<image>", "<image>"], "<##AUDIO##>": ["<audio>"]}
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        generic_key = MODALITY_PLACEHOLDERS_MAP[modality]
        if not placeholder:
            return
        self._placeholder_storage[generic_key].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        raise NotImplementedError

794
795

class MultiModalContentParser(BaseMultiModalContentParser):
    """Synchronous parser: fetches media eagerly through a ``MediaConnector``
    and registers ``(data, uuid)`` items with a ``MultiModalItemTracker``."""

    def __init__(
        self,
        tracker: MultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()

        self._tracker = tracker

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

        # In this class only "use_audio_in_video" is read (see parse_video).
        self._mm_processor_kwargs = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        # Without a URL, only the uuid (if any) is recorded.
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings and track them as ``image_embeds``.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set, or if
                ``image_embeds`` is not a str, a dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        elif image_embeds is None:
            # No payload; only the uuid (if any) is recorded.
            placeholder = self._tracker.add("image_embeds", (None, uuid))
        else:
            # Reject explicitly instead of leaving `placeholder` unbound.
            raise ValueError(
                "`image_embeds` must be a base64 string, a dict of base64 "
                f"strings, or None; got {type(image_embeds).__name__}"
            )

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings and track them as ``audio_embeds``.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not set, or if
                ``audio_embeds`` is not a str, a dict, or None.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        elif audio_embeds is None:
            # No payload; only the uuid (if any) is recorded.
            placeholder = self._tracker.add("audio_embeds", (None, uuid))
        else:
            # Same contract as parse_image_embeds: unsupported types are
            # rejected instead of being silently tracked as None.
            raise ValueError(
                "`audio_embeds` must be a base64 string, a dict of base64 "
                f"strings, or None; got {type(audio_embeds).__name__}"
            )

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an OpenAI ``input_audio`` object into a base64 data URL."""
        audio_url = None
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            # If a UUID is provided, audio data may be empty.
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            # video_url is known to be truthy here.
            audio = self._connector.fetch_audio(video_url)
            audio_placeholder = self._tracker.add("audio", (audio, uuid))
            self._add_placeholder("audio", audio_placeholder)

920
921

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Multimodal content parser that defers media loading to asyncio.

    URL-based media is handed to the tracker as not-yet-awaited coroutines
    that fetch through the connector's ``*_async`` methods. Inline inputs
    (embeddings, PIL images) are wrapped in futures whose results are already
    set, so every item given to the tracker is awaitable.
    """

    def __init__(
        self,
        tracker: AsyncMultiModalItemTracker,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()

        self._tracker = tracker
        # Connector selected by VLLM_MEDIA_CONNECTOR, configured with the
        # tracker's media I/O kwargs and allowed local path / domains.
        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=tracker.media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )
        self._mm_processor_kwargs: dict[str, Any] | None = mm_processor_kwargs

    @property
    def model_config(self) -> ModelConfig:
        """Model config of the underlying tracker."""
        return self._tracker.model_config

    @staticmethod
    def _resolve_embeds(
        embeds: str | dict[str, str] | None,
        fetch: Callable[[str], torch.Tensor],
        field_name: str,
    ) -> torch.Tensor | dict[str, torch.Tensor] | None:
        """Decode an ``*_embeds`` payload using *fetch*.

        Args:
            embeds: A single encoded embedding, a dict of named encoded
                embeddings, or None (allowed when a UUID identifies the item).
            fetch: Connector method that decodes one encoded embedding.
            field_name: Request field name, used in the error message.

        Returns:
            The decoded tensor, a dict of decoded tensors, or None.

        Raises:
            TypeError: If ``embeds`` has any other type. Failing here matters:
                otherwise the future would be handed to the tracker without a
                result and awaiting it would never complete.
        """
        if embeds is None:
            return None
        if isinstance(embeds, dict):
            return {key: fetch(value) for key, value in embeds.items()}
        if isinstance(embeds, str):
            return fetch(embeds)
        raise TypeError(
            f"`{field_name}` must be a string, a dict of strings, or None; "
            f"got {type(embeds).__name__}"
        )

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch an image asynchronously (if a URL is given); pair it with uuid."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register a deferred image fetch with the tracker."""
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode image embeddings eagerly and register them as a done future.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
            TypeError: If ``image_embeds`` has an unsupported type.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        embeds = self._resolve_embeds(
            image_embeds, self._connector.fetch_image_embedding, "image_embeds"
        )
        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()
        future.set_result((embeds, uuid))

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode audio embeddings eagerly and register them as a done future.

        Raises:
            ValueError: If multimodal embedding inputs are not enabled.
            TypeError: If ``audio_embeds`` has an unsupported type.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        embeds = self._resolve_embeds(
            audio_embeds, self._connector.fetch_audio_embedding, "audio_embeds"
        )
        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()
        future.set_result((embeds, uuid))

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an in-memory PIL image as an already-completed future."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        # Falsy inputs (e.g. None when only a UUID identifies the image) are
        # normalized to None.
        future.set_result((image_pil if image_pil else None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch audio asynchronously (if a URL is given); pair it with uuid."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register a deferred audio fetch with the tracker."""
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an OpenAI ``input_audio`` payload into a data URL and parse it.

        When the payload is missing or its data is empty (allowed when a UUID
        identifies the audio), no URL is built and only the UUID is recorded.
        """
        audio_url: str | None = None
        if input_audio and (audio_data := input_audio.get("data", "")):
            audio_format = input_audio.get("format", "")
            audio_url = f"data:audio/{audio_format};base64,{audio_data}"

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch a video asynchronously (if a URL is given); pair it with uuid."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a deferred video fetch with the tracker.

        When ``mm_processor_kwargs`` enables ``use_audio_in_video`` and a URL
        is present, a deferred audio fetch from the same URL is registered
        too, sharing the same ``uuid``.
        """
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)

        # Extract audio from video if use_audio_in_video is True
        if (
            video_url
            and self._mm_processor_kwargs
            and self._mm_processor_kwargs.get("use_audio_in_video", False)
        ):
            audio_coro = self._audio_with_uuid_async(video_url, uuid)
            audio_placeholder = self._tracker.add("audio", audio_coro)
            self._add_placeholder("audio", audio_placeholder)

1082

1083
1084
1085
1086
1087
1088
1089
@dataclass
class ChatTemplateConfig:
    """Bundle of chat-template options passed around as one object.

    NOTE(review): the exact interpretation of each field lives in the code
    that consumes this config; the comments below are descriptive only.
    """

    # Template text or a path-like reference; None means no explicit template.
    chat_template: str | None = None
    # Content-format option for rendering messages; defaults to "auto".
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    # Presumably whether a template supplied in a request may be honored.
    trust_request_chat_template: bool = False


1090
def validate_chat_template(chat_template: Path | str | None):
1091
1092
1093
1094
1095
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1096
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1097
1098
1099

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1100
1101
1102
1103
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1104
1105
1106
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1107
            )
1108

1109
1110
1111
1112
1113
1114
1115
1116
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1117
    else:
1118
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1119
1120


1121
def _load_chat_template(
1122
    chat_template: Path | str | None,
1123
1124
    *,
    is_literal: bool = False,
1125
) -> str | None:
1126
1127
    if chat_template is None:
        return None
1128
1129
1130

    if is_literal:
        if isinstance(chat_template, Path):
1131
1132
1133
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1134

1135
        return chat_template
1136

1137
    try:
1138
        with open(chat_template) as f:
1139
            return f.read()
1140
    except OSError as e:
1141
1142
1143
        if isinstance(chat_template, Path):
            raise

1144
1145
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1146
1147
1148
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1149
            )
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1163

1164
1165
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1166
1167
1168
1169
1170
1171
1172
        return _load_chat_template(chat_template, is_literal=True)


# Memoized loader keyed on (chat_template, is_literal); repeated requests for
# the same template skip re-reading the file. 128 is lru_cache's default size.
_cached_load_chat_template = lru_cache(maxsize=128)(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Load chat template text, reusing previously loaded results.

    Accepts the same arguments as ``_load_chat_template`` and returns its
    (cached) result.
    """
    cached_loader = _cached_load_chat_template
    return cached_loader(chat_template, is_literal=is_literal)
1178
1179


1180
1181
1182
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1183
1184
1185
1186
1187
1188
1189
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1190
# TODO: Let user specify how to insert multimodal tokens into prompt
1191
# (similar to chat template)
1192
1193
1194
1195
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
1196
    multimodal_content_part_separator: str = "\n",
1197
) -> str:
1198
    """Combine multimodal prompts for a multimodal language model."""
1199

1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1217
    # Look through the text prompt to check for missing placeholders
1218
    missing_placeholders: list[str] = []
1219
1220
1221
1222
1223
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1224
1225
1226
1227
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1228
1229
                "when manually placing image placeholders.",
                interleave_strings,
1230
1231
            )
            logger.debug("Input prompt: %s", text_prompt)
1232
1233
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1234
1235
                "actual multimodal data items."
            )
1236

1237
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1238

1239
1240
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1241
    if text_prompt:
1242
1243
1244
        return multimodal_content_part_separator.join(
            missing_placeholders + [text_prompt]
        )
1245
    else:
1246
        return multimodal_content_part_separator.join(missing_placeholders)
1247
1248


1249
1250
# No need to validate using Pydantic again: these are plain ``cast`` views
# (no runtime checks) over parts that were already validated upstream.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Need to validate url objects: these run pydantic ``validate_python``.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Map each content-part ``type`` to a function extracting its payload; a
# missing payload field yields None.
MM_PARSER_MAP: dict[
    str,
    Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
    "text": lambda part: _TextParser(part).get("text", None),
    "thinking": lambda part: _ThinkParser(part).get("thinking", None),
    "input_text": lambda part: _TextParser(part).get("text", None),
    "output_text": lambda part: _TextParser(part).get("text", None),
    "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
    "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
    "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
    "audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
    "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
    "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
    "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
    "refusal": lambda part: _RefusalParser(part).get("refusal", None),
    "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Parts with a recognised ``type`` and no ``uuid`` are dispatched through
    ``MM_PARSER_MAP``. Parts without a ``type``, or carrying a ``uuid``, are
    resolved by looking for a known direct media field (``image_url``,
    ``image_pil``, ``image_embeds``, ``audio_embeds``, ``audio_url``,
    ``input_audio``, ``video_url``). On that path the media value may be None,
    because the UUID alone can identify the item.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For an unrecognised
          string type without a uuid, a placeholder string is returned and the
          caller decides how to handle it.

    Raises:
        ValueError: If no known direct field is found on the fallback path,
            or if 'type' is present but not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if (
        isinstance(part_type, str)
        and part_type in MM_PARSER_MAP
        and uuid is None
    ):
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default. Per the OpenAI schema,
        # 'detail' lives inside the nested 'image_url' object.
        if part_type == "image_url":
            image_url_obj = cast(dict[str, Any], part).get("image_url")
            if (
                isinstance(image_url_obj, dict)
                and image_url_obj.get("detail", "auto") != "auto"
            ):
                logger.warning(
                    "'image_url.detail' is currently not supported and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        if part.get("input_audio") is not None:
            # Return the nested {"data": ..., "format": ...} payload, exactly
            # as MM_PARSER_MAP["input_audio"] does on the typed path. Returning
            # the whole part would hide the audio data from parse_input_audio.
            input_audio = cast(dict[str, Any], part)["input_audio"]
            return "input_audio", input_audio
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1382
# Part types that are dropped with a warning when their parsed content is
# None, instead of being forwarded as empty text.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1386

1387

1388
1389
1390
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
    multimodal_content_part_separator="\n",
) -> list[ConversationMessage]:
    """Parse a message's content parts into a single conversation message.

    Each part is parsed in order and empty results are dropped. With
    ``wrap_dicts`` the parsed dict pieces become the message content as-is.
    Otherwise the text pieces are joined into one string; if the parser
    recorded multimodal placeholders they are merged in via
    ``_get_full_multimodal_text_prompt``.
    """
    mm_parser = mm_tracker.create_parser(mm_processor_kwargs=mm_processor_kwargs)

    parsed_parts: list[_ContentPart] = [
        parsed
        for part in parts
        if (
            parsed := _parse_chat_message_content_part(
                part,
                mm_parser,
                wrap_dicts=wrap_dicts,
                interleave_strings=interleave_strings,
            )
        )
    ]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=parsed_parts)]  # type: ignore

    texts = cast(list[str], parsed_parts)
    storage = mm_parser.mm_placeholder_storage()
    text_prompt = (
        _get_full_multimodal_text_prompt(
            storage,
            texts,
            interleave_strings,
            multimodal_content_part_separator=multimodal_content_part_separator,
        )
        if storage
        else "\n".join(texts)
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
1431
1432
1433
1434
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
1435
    interleave_strings: bool,
1436
) -> _ContentPart | None:
1437
1438
1439
1440
1441
1442
1443
1444
    """Parses a single part of a conversation. If wrap_dicts is True,
    structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
    {"type": "image"}, respectively. Otherwise multimodal data will be
    handled by mm_parser, and texts will be returned as strings to be joined
    with multimodal placeholders.
    """
    if isinstance(part, str):  # Handle plain text parts
1445
1446
        if wrap_dicts:
            return {"type": "text", "text": part}
1447
        return part
1448
1449
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)
1450
    # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
1451
    # content is None, log a warning and skip
1452
    if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
1453
        logger.warning(
1454
            "Skipping multimodal part '%s' (type: '%s') "
1455
1456
1457
1458
            "with empty / unparsable content.",
            part,
            part_type,
        )
1459
1460
        return None

1461
    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
1462
1463
        str_content = cast(str, content)
        if wrap_dicts:
1464
            return {"type": "text", "text": str_content}
1465
1466
        else:
            return str_content
1467

1468
1469
1470
1471
1472
1473
    # For media items, if a user has provided one, use it. Otherwise, insert
    # a placeholder empty uuid.
    uuid = part.get("uuid", None)
    if uuid is not None:
        uuid = str(uuid)

1474
    modality = None
1475
    if part_type == "image_pil":
1476
        image_content = cast(Image.Image, content) if content is not None else None
1477
        mm_parser.parse_image_pil(image_content, uuid)
1478
        modality = "image"
1479
    elif part_type in ("image_url", "input_image"):
1480
        str_content = cast(str, content)
1481
        mm_parser.parse_image(str_content, uuid)
1482
1483
        modality = "image"
    elif part_type == "image_embeds":
1484
        content = cast(str | dict[str, str], content) if content is not None else None
1485
        mm_parser.parse_image_embeds(content, uuid)
1486
        modality = "image"
1487
1488
1489
1490
    elif part_type == "audio_embeds":
        content = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(content, uuid)
        modality = "audio"
1491
    elif part_type == "audio_url":
1492
        str_content = cast(str, content)
1493
        mm_parser.parse_audio(str_content, uuid)
1494
1495
        modality = "audio"
    elif part_type == "input_audio":
1496
        dict_content = cast(InputAudio, content)
1497
        mm_parser.parse_input_audio(dict_content, uuid)
1498
1499
        modality = "audio"
    elif part_type == "video_url":
1500
        str_content = cast(str, content)
1501
        mm_parser.parse_video(str_content, uuid)
1502
1503
1504
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")
1505

1506
1507
1508
    if wrap_dicts:
        return {"type": modality}
    return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
1509
1510


1511
1512
1513
1514
1515
# Typed views over messages that were already validated upstream, so there is
# no need to validate using Pydantic again; ``cast`` performs no runtime check.
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)


1516
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]:
    """Convert one request message into conversation message(s).

    The content (None, a plain string, or a list of parts) is parsed with
    ``_parse_chat_message_content_parts``. Role-specific fields are then
    copied onto every resulting message:

    - assistant: ``tool_calls`` (when not None) and ``reasoning`` (mirrored
      to ``reasoning_content`` for compatibility);
    - tool: ``tool_call_id``, with list content flattened to a string;
    - developer: ``tools``;
    - any role: ``name`` and ``task`` when they are strings.
    """
    role = message["role"]
    raw_content = message.get("content")
    reasoning = message.get("reasoning")

    parts: Iterable[Any]
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=content_format == "openai",
        interleave_strings=interleave_strings,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The explicit None check tolerates clients that send
            # "tool_calls": null, which strict OpenAI payloads would omit.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                # keep compatibility
                result_msg["reasoning_content"] = cast(str, reasoning)
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

            # Clients may send tool results in OpenAI array form
            # ([{"type": "text", "text": "..."}]), while most chat templates
            # expect a plain string for tool messages; flatten text items.
            msg_content = result_msg.get("content")
            if isinstance(msg_content, list):
                result_msg["content"] = "\n".join(
                    item.get("text", "")
                    for item in msg_content
                    if isinstance(item, dict) and item.get("type") == "text"
                )

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

        if "task" in message and isinstance(message["task"], str):
            result_msg["task"] = message["task"]

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1582

1583
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalize assistant tool calls in place for chat templates.

    Per the Transformers docs & maintainers, tool-use chat templates expect
    tool call arguments as dicts rather than JSON strings. The OpenAI format
    provides strings, so they are decoded here. Missing or empty arguments
    become ``{}``, and an empty ``tool_calls`` list is removed so templates
    take the normal assistant path.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if not tool_calls:
            message.pop("tool_calls", None)
            continue

        for tool_call in tool_calls:
            function = tool_call["function"]
            arguments = function.get("arguments")
            if not arguments:
                # None / empty string (or other empty value) -> {}
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1607
1608


1609
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse request messages synchronously, loading media eagerly.

    Returns:
        The post-processed conversation plus the multimodal data and UUIDs
        resolved by the tracker.
    """
    mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)

    # Same for every message, so compute it once.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = [
        sub_message
        for message in messages
        for sub_message in _parse_chat_message_content(
            message,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
            mm_processor_kwargs=mm_processor_kwargs,
        )
    ]

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1643
1644


async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
    media_io_kwargs: dict[str, dict[str, Any]] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Async variant of ``parse_chat_messages``.

    Behaves like the synchronous version, with two differences: it uses an
    ``AsyncMultiModalItemTracker``, and it awaits the tracker's
    ``resolve_items()``. Each message is handed to
    ``_parse_chat_message_content`` with the shared tracker, and the returned
    sub-messages are appended to the conversation. The conversation is then
    passed through ``_postprocess_messages``.

    Args:
        messages: The chat messages to parse.
        model_config: Model configuration; its ``multimodal_config`` decides
            whether multimodal strings are interleaved.
        content_format: The chat-template content format in use.
        media_io_kwargs: Forwarded to ``AsyncMultiModalItemTracker``.
        mm_processor_kwargs: Forwarded to ``_parse_chat_message_content`` for
            every message.

    Returns:
        ``(conversation, mm_data, mm_uuids)``, where ``mm_data`` and
        ``mm_uuids`` come from ``await mm_tracker.resolve_items()``.
    """
    conversation: list[ConversationMessage] = []
    mm_tracker = AsyncMultiModalItemTracker(
        model_config, media_io_kwargs=media_io_kwargs
    )

    # The flag only depends on the content format and the model config, so
    # compute it once instead of once per message.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    for msg in messages:
        sub_messages = _parse_chat_message_content(
            msg,
            mm_tracker,
            content_format,
            interleave_strings=interleave_strings,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        conversation.extend(sub_messages)

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()

    return conversation, mm_data, mm_uuids


def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls made by assistant messages.

    Messages whose ``"role"`` is not ``"assistant"`` are skipped. So are
    assistant messages with no ``"tool_calls"`` key or with ``tool_calls``
    set to ``None``; they contribute nothing to the count.
    """
    count = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is None:
            continue
        # list() lets this work for any iterable of tool calls, not only
        # objects that support len() directly.
        count += len(list(tool_calls))
    return count


_KIMI_MODEL_TYPES = ("kimi_k2", "kimi_k25")


def get_tool_call_id_type(model_config: ModelConfig) -> str:
    """Return the tool-call ID type for a given model configuration.

    Returns ``"kimi_k2"`` in either of two cases:

    * ``model_config.hf_text_config.model_type`` is in ``_KIMI_MODEL_TYPES``.
    * ``model_config`` has a dict-valued ``hf_overrides`` whose
      ``"model_type"`` entry is in ``_KIMI_MODEL_TYPES``.

    Otherwise returns ``"random"``. These are the ``id_type`` values that
    ``make_tool_call_id`` distinguishes.
    """
    hf_overrides = getattr(model_config, "hf_overrides", None)
    if model_config.hf_text_config.model_type in _KIMI_MODEL_TYPES:
        return "kimi_k2"

    # Only a dict-valued hf_overrides is inspected; any other value
    # (including a missing attribute, i.e. None) is ignored.
    overrides_name_kimi = (
        isinstance(hf_overrides, dict)
        and hf_overrides.get("model_type") in _KIMI_MODEL_TYPES
    )
    return "kimi_k2" if overrides_name_kimi else "random"


def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a tool call.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<func_name>:<idx>``
            format. Any other value falls back to the random format.
        func_name: Function name embedded in the ``"kimi_k2"`` format.
        idx: Index embedded in the ``"kimi_k2"`` format.

    Returns:
        The tool-call ID string. For non-``"kimi_k2"`` types it is
        ``chatcmpl-tool-`` followed by ``random_uuid()``.
    """
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    # By default, return a random ID.
    return f"chatcmpl-tool-{random_uuid()}"