"vscode:/vscode.git/clone" did not exist on "f1bc989074efe707d858a16f055a693ccbf75726"
chat_utils.py 55.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import json
6
import warnings
7
from abc import ABC, abstractmethod
8
from collections import Counter, defaultdict
9
from collections.abc import Awaitable, Callable, Iterable
10
from dataclasses import dataclass
11
from functools import cached_property, lru_cache, partial
12
from itertools import accumulate
13
from pathlib import Path
14
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
15

16
from openai.types.chat import (
17
18
19
20
21
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartTextParam,
22
    ChatCompletionFunctionToolParam,
23
24
25
26
27
28
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
29
from openai.types.chat import (
30
31
32
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
33
from openai.types.responses import ResponseInputImageParam
34
from openai_harmony import Message as OpenAIHarmonyMessage
35
36
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
37

38
# pydantic needs the TypedDict from typing_extensions
39
from typing_extensions import Required, TypedDict
40

41
from vllm import envs
42
from vllm.config import ModelConfig
43
from vllm.logger import init_logger
44
from vllm.model_executor.models import SupportsMultiModal
45
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFlatField,
    MultiModalSharedField,
Roger Wang's avatar
Roger Wang committed
50
51
52
    VisionChunk,
    VisionChunkImage,
    VisionChunkVideo,
53
)
54
from vllm.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector
55
from vllm.multimodal.processing import BaseMultiModalProcessor
56
from vllm.utils import random_uuid
57
58
59
60
61
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch
62
    import transformers
63
else:
64
    transformers = LazyLoader("transformers", globals(), "transformers")
65
    torch = LazyLoader("torch", globals(), "torch")
66
67
68

logger = init_logger(__name__)

69

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def __getattr__(name: str):
    """Resolve deprecated module attributes lazily (PEP 562 hook).

    ``resolve_hf_chat_template`` is forwarded to
    ``vllm.renderers.hf.resolve_chat_template`` with a ``DeprecationWarning``;
    every other unknown attribute raises ``AttributeError``.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template


87
88
89
90
91
92
93
94
class ChatTemplateResolutionError(ValueError):
    """Signals that no usable chat template could be resolved.

    It inherits from ``ValueError`` on purpose: code written before this
    class existed catches ``ValueError`` and keeps working unchanged.
    """


95
96
97
98
99
100
# Generic, model-agnostic placeholder for each raw media modality. Content
# parsers group the model-specific placeholder strings they collect under
# these keys. Only "image", "audio" and "video" have entries, so indexing with
# any other modality raises KeyError.
MODALITY_PLACEHOLDERS_MAP = {
    "image": "<##IMAGE##>",
    "audio": "<##AUDIO##>",
    "video": "<##VIDEO##>",
}

101

102
103
104
105
106
107
108
109
110
111
112
113
114
115
class AudioURL(TypedDict, total=False):
    """Audio location used as the ``audio_url`` value of an audio content part."""

    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """Content part that references audio through a nested ``AudioURL``.

    Example:
    {
        "type": "audio_url",
        "audio_url": {"url": "https://example.com/audio.mp3"}
    }
    """

    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


116
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
117
    image_embeds: str | dict[str, str] | None
118
119
120
121
122
123
124
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
125
    uuid: str | None
126
127
128
129
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
    """Content part that carries audio embeddings instead of an audio clip.

    The content parsers in this module reject it unless ``--enable-mm-embeds``
    is set.
    """

    audio_embeds: str | dict[str, str] | None
    """
    The audio embeddings. It can be either:
    - A single base64 string representing a serialized torch tensor.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["audio_embeds"]]
    """The type of the content part."""
    uuid: str | None
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """


148
149
150
151
152
153
154
155
156
157
158
159
160
161
class VideoURL(TypedDict, total=False):
    """Video location used as the ``video_url`` value of a video content part."""

    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Content part that references a video through a nested ``VideoURL``.

    Example:
    {
        "type": "video_url",
        "video_url": {"url": "https://example.com/video.mp4"}
    }
    """

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


162
163
164
165
class PILImage(BaseModel):
    """
    A PIL.Image.Image object wrapped in a pydantic model.

    ``arbitrary_types_allowed`` is enabled so that pydantic accepts the
    non-pydantic ``Image.Image`` field type.
    """

    image_pil: Image.Image
    model_config = ConfigDict(arbitrary_types_allowed=True)


class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a PIL image.

    Example:
    {
        "image_pil": ImageAsset('cherry_blossom').pil_image
    }
    """

    image_pil: PILImage | None
    uuid: str | None
    """
    User-provided UUID identifying this media item. The caller must ensure it
    is properly generated and unique across different media items.
    """
186
187


188
189
190
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
191

192
193
194
195
196
    Example:
    {
        "image_url": "https://example.com/image.jpg"
    }
    """
197

198
199
    image_url: str | None
    uuid: str | None
200
201
202
203
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
204
205
206
207


class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.
208

209
210
211
212
213
    Example:
    {
        "audio_url": "https://example.com/audio.mp3"
    }
    """
214

215
    audio_url: str | None
216
217


218
219
220
221
222
223
224
225
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain audio_url.

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
226

227
228
    video_url: str | None
    uuid: str | None
229
230
231
232
    """
    User-provided UUID of a media. User must guarantee that it is properly
    generated and unique for different medias.
    """
233
234


Julien Denize's avatar
Julien Denize committed
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class CustomThinkCompletionContentParam(TypedDict, total=False):
    """Content part that carries reasoning ("thinking") text.

    ``thinking`` and ``type`` are required; ``closed`` optionally marks
    whether the thinking segment is closed.

    Example:
    {
        "thinking": "I am thinking about the answer",
        "closed": True,
        "type": "thinking"
    }
    """

    thinking: Required[str]
    """The thinking content."""

    closed: bool
    """Whether the thinking is closed."""

    type: Required[Literal["thinking"]]
    """The thinking type."""


256
257
258
259
260
261
262
263
264
# Every content-part shape accepted inside a chat message: the OpenAI content
# parts, vLLM's audio/video/embedding/PIL/think extensions, and bare strings.
ChatCompletionContentPartParam: TypeAlias = (
    OpenAIChatCompletionContentPartParam
    | ChatCompletionContentPartAudioParam
    | ChatCompletionContentPartInputAudioParam
    | ChatCompletionContentPartVideoParam
    | ChatCompletionContentPartRefusalParam
    | CustomChatCompletionContentPILImageParam
    | CustomChatCompletionContentSimpleImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartAudioEmbedsParam
    | CustomChatCompletionContentSimpleAudioParam
    | CustomChatCompletionContentSimpleVideoParam
    | str
    | CustomThinkCompletionContentParam
)
271
272
273
274


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
275

276
277
278
    role: Required[str]
    """The role of the message's author."""

279
    content: str | list[ChatCompletionContentPartParam]
280
281
282
283
284
285
286
287
288
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

289
    tool_call_id: str | None
290
291
    """Tool call that this message is responding to."""

292
    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
293
294
    """The tool calls generated by the model, such as function calls."""

295
296
297
    reasoning: str | None
    """The reasoning content for interleaved thinking."""

298
299
300
    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

301

302
303
304
305
306
# Any chat message accepted here: an OpenAI message, a custom-role message
# (`CustomChatCompletionMessageParam`), or an OpenAI Harmony message object.
ChatCompletionMessageParam: TypeAlias = (
    OpenAIChatCompletionMessageParam
    | CustomChatCompletionMessageParam
    | OpenAIHarmonyMessage
)
307
308


309
# TODO: Make fields ReadOnly once mypy supports it
310
311
312
313
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

314
    content: str | None | list[dict[str, str]]
315
316
    """The contents of the message"""

317
    tool_call_id: str | None
318
319
    """Tool call that this message is responding to."""

320
    name: str | None
321
322
    """The name of the function to call"""

323
    tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
324
    """The tool calls generated by the model, such as function calls."""
325

326
327
328
329
330
331
    reasoning: str | None
    """The reasoning content for interleaved thinking."""

    reasoning_content: str | None
    """Deprecated: The reasoning content for interleaved thinking."""

332
333
334
    tools: list[ChatCompletionFunctionToolParam] | None
    """The tools for developer role."""

335

336
337
338
# Passed in by user; "auto" is later resolved to one of the concrete
# `ChatTemplateContentFormat` values.
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

339
340
# After resolving "auto": the concrete content format, "string" or "openai".
ChatTemplateContentFormat = Literal["string", "openai"]
341

342

Roger Wang's avatar
Roger Wang committed
343
344
345
# Modalities an item can be added under in a multi-modal item tracker. The
# "*_embeds" variants carry embeddings rather than raw media; "vision_chunk" is
# the unified image/video modality used when the model's HF config sets
# `use_unified_vision_chunk`.
ModalityStr = Literal[
    "image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]
346
347
348
# Per-item payload type stored by `BaseMultiModalItemTracker` subclasses.
_T = TypeVar("_T")


349
350
351
# Backward compatibility for single item input
class _BatchedSingleItemField(MultiModalSharedField):
    # Returned by `_detect_field` for the deprecated case of a single 3-D
    # embedding tensor with a leading dimension of 1. `_merge_embeds` keeps this
    # field rather than replacing it with the processor-parsed field.
    pass
352
353


354
355
356
357
358
359
def _detect_field(
    tensors: list[torch.Tensor],
    mm_processor: BaseMultiModalProcessor,
):
    """Choose how per-item embedding tensors should be combined.

    Args:
        tensors: One tensor per multi-modal item. Must be non-empty (the caller
            guarantees this).
        mm_processor: Processor whose model config provides the expected
            embedding size (``get_inputs_embeds_size()``).

    Returns:
        - ``_BatchedSingleItemField(batch_size=1)`` for the deprecated input of
          exactly one 3-D tensor whose first dim is 1 and whose last dim equals
          the embedding size. A warning is logged.
        - ``MultiModalBatchedField()`` when all tensors share one shape.
        - ``MultiModalFlatField`` otherwise, with one slice per item covering
          that item's span of first-dimension rows.
    """
    first_item = tensors[0]
    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()

    if (
        len(tensors) == 1
        and first_item.ndim == 3
        and first_item.shape[0] == 1
        and first_item.shape[-1] == hidden_size
    ):
        logger.warning(
            "Batched multi-modal embedding inputs are deprecated for Chat API. "
            "Please pass a separate content part for each multi-modal item."
        )
        return _BatchedSingleItemField(batch_size=1)

    first_shape = first_item.shape
    if all(t.shape == first_shape for t in tensors):
        return MultiModalBatchedField()

    # Cumulative row offsets: item i spans rows [bounds[i], bounds[i + 1]).
    bounds = [0, *accumulate(len(tensor) for tensor in tensors)]
    slices = [(slice(start, end),) for start, end in zip(bounds, bounds[1:])]
    return MultiModalFlatField(slices=slices)


def _merge_embeds(
    data_items: list[dict[str, "torch.Tensor"]],
    mm_processor: BaseMultiModalProcessor,
):
    """Merge per-item embedding dicts into a single dict of combined tensors.

    Each key is first combined with the field auto-detected by
    ``_detect_field``. The processor's ``_get_mm_fields_config`` is then
    consulted. Any key whose processor field differs from the detected one is
    re-reduced with the processor's field, except keys detected as the
    deprecated ``_BatchedSingleItemField``. If that consultation fails, the
    error is logged and the auto-detected result is returned (best effort).

    Args:
        data_items: One dict of tensors per multi-modal item.
        mm_processor: Processor used for field detection and config parsing.

    Returns:
        ``{}`` for empty input, else a dict mapping each key to its merged data.

    Raises:
        ValueError: If the dicts do not all have the same keys.
    """
    if not data_items:
        return {}

    first_keys = set(data_items[0].keys())
    if any(set(item.keys()) != first_keys for item in data_items[1:]):
        raise ValueError(
            "All dictionaries in the list of embeddings must have the same keys."
        )

    # Per-key list of the items' tensors, reused by every reduction below.
    values_by_key = {key: [item[key] for item in data_items] for key in first_keys}

    fields = {
        key: _detect_field(values, mm_processor)
        for key, values in values_by_key.items()
    }
    data_merged = {
        key: field._reduce_data(values_by_key[key], pin_memory=False)
        for key, field in fields.items()
    }

    try:
        # TODO: Support per-request mm_processor_kwargs
        parsed_configs = mm_processor._get_mm_fields_config(
            transformers.BatchFeature(data_merged),
            {},
        )
        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
        keys_to_update = [
            key
            for key in first_keys
            if (
                fields[key] != parsed_fields[key]
                and not isinstance(fields[key], _BatchedSingleItemField)
            )
        ]

        for key in keys_to_update:
            data_merged[key] = parsed_fields[key]._reduce_data(
                values_by_key[key], pin_memory=False
            )
    except Exception:
        # Deliberate best-effort fallback: keep the auto-detected merge.
        logger.exception(
            "Error when parsing merged embeddings. "
            "Falling back to auto-detected fields."
        )

    return data_merged
434

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455

def _get_embeds_data(
    modality: str,
    data_items: list[Any],
    mm_processor: BaseMultiModalProcessor,
):
    """Combine the embedding payloads of one modality into model input data.

    Args:
        modality: Raw modality name (e.g. ``"image"``). A list of tensors is
            merged under the temporary key ``f"{modality}_embeds"``.
        data_items: One payload per item: tensors, dicts of tensors, or
            ``None`` for items that only carry a UUID.
        mm_processor: Processor forwarded to ``_merge_embeds``.

    Returns:
        ``data_items`` unchanged when it is empty or contains only ``None``;
        the merged data for a list of tensors; the merged dict for a list of
        dicts.

    Raises:
        NotImplementedError: If ``is_list_of`` recognizes the items as neither
            tensors nor dicts.
    """
    if not data_items:
        return data_items

    if all(item is None for item in data_items):
        return data_items

    if is_list_of(data_items, torch.Tensor):
        embeds_key = f"{modality}_embeds"
        dict_items = [{embeds_key: item} for item in data_items]
        return _merge_embeds(dict_items, mm_processor)[embeds_key]

    if is_list_of(data_items, dict):
        return _merge_embeds(data_items, mm_processor)

    # Report the element types: `type(data_items)` is always `list` and gives
    # the caller nothing actionable.
    item_types = sorted({type(item).__name__ for item in data_items})
    raise NotImplementedError(
        f"Unsupported {modality} embedding inputs: expected a list of tensors "
        f"or a list of dicts, got item types {item_types}"
    )
456
457


458
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Tracks multi-modal items in a given request and ensures that the number
    of multi-modal items in a given request does not exceed the configured
    maximum per prompt.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()

        self._model_config = model_config

        # Tracked items keyed by the modality they are stored under.
        self._items_by_modality = defaultdict[str, list[_T]](list)
        # Under the "vision_chunk" key: the original modality ("image" or
        # "video") of each vision_chunk item, in insertion order.
        self._modality_order = defaultdict[str, list[str]](list)

    @cached_property
    def use_unified_vision_chunk_modality(self) -> bool:
        """Whether the HF config sets ``use_unified_vision_chunk``.

        Defaults to False when the attribute is absent. Such models accept
        images and videos only as ``vision_chunk`` items.
        """
        return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)

    @property
    def model_config(self) -> ModelConfig:
        return self._model_config

    @cached_property
    def model_cls(self) -> type[SupportsMultiModal]:
        """The model class for ``model_config``, resolved once and cached."""
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsMultiModal], model_cls)

    @property
    def allowed_local_media_path(self):
        return self._model_config.allowed_local_media_path

    @property
    def allowed_media_domains(self):
        return self._model_config.allowed_media_domains

    @property
    def mm_registry(self):
        return MULTIMODAL_REGISTRY

    @cached_property
    def mm_processor(self):
        """Multi-modal processor for ``model_config``, created once and cached."""
        return self.mm_registry.create_processor(self.model_config)

    def add(self, modality: ModalityStr, item: _T) -> str | None:
        """
        Add a multi-modal item to the current prompt and return the
        placeholder string to use, if any.

        An optional uuid can be added, as part of ``item``, which serves as a
        unique identifier of the media.

        Args:
            modality: The modality the item was supplied as. When the model
                uses the unified vision_chunk modality, "image" and "video"
                items are stored under "vision_chunk".
            item: The tracked payload.

        Returns:
            ``model_cls.get_placeholder_str(modality, n)``, where ``n`` is the
            item's 1-based position among items stored under the same key.

        Any error raised by ``validate_num_items`` propagates; it presumably
        signals that the per-prompt limit was exceeded.
        """
        use_vision_chunk = (
            self.use_unified_vision_chunk_modality and modality in ["video", "image"]
        )

        if use_vision_chunk:
            # Models with use_unified_vision_chunk_modality=True only accept the
            # vision_chunk modality, so validate and store under that name.
            input_modality = "vision_chunk"
            storage_key = "vision_chunk"
        else:
            input_modality = modality.replace("_embeds", "")
            storage_key = modality

        num_items = len(self._items_by_modality[storage_key]) + 1

        # With --enable-mm-embeds, embedding inputs bypass a per-prompt limit
        # of 0 for their raw modality.
        mm_config = self.model_config.multimodal_config
        skip_validation = (
            mm_config is not None
            and mm_config.enable_mm_embeds
            and mm_config.get_limit_per_prompt(input_modality) == 0
            and modality.endswith("_embeds")
        )
        if not skip_validation:
            self.mm_processor.info.validate_num_items(input_modality, num_items)

        self._items_by_modality[storage_key].append(item)
        if use_vision_chunk:
            # Remember whether this vision_chunk item was an image or a video.
            self._modality_order["vision_chunk"].append(modality)

        return self.model_cls.get_placeholder_str(modality, num_items)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a content parser that records its items in this tracker."""
        raise NotImplementedError


558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
def _resolve_vision_chunk_items(
    vision_chunk_items: list[tuple[object, str | None]],
    mm_processor: BaseMultiModalProcessor,
    vision_chunks_modality_order: list[str],
):
    """Convert resolved ``vision_chunk`` items into ``VisionChunk`` entries.

    Args:
        vision_chunk_items: ``(data, uuid)`` pairs in insertion order.
        mm_processor: Processor used to split a video into chunks when it
            defines ``split_video_chunks``.
        vision_chunks_modality_order: Original modality ("image" or "video")
            of each entry of ``vision_chunk_items``.

    Returns:
        ``(processed_chunks, uuids)``. ``uuids`` has one entry per input item.
        ``processed_chunks`` holds a ``VisionChunkImage`` or the raw data for
        each image, and either the video's ``VisionChunkVideo`` chunks or its
        raw data for each video.
    """
    # NOTE(review): a split video contributes several entries to
    # processed_chunks but only one uuid here (each VisionChunkVideo carries
    # its own uuid). Confirm downstream consumers expect per-input uuids.
    vision_chunks_uuids = [uuid for _, uuid in vision_chunk_items]

    assert len(vision_chunk_items) == len(vision_chunks_modality_order), (
        f"vision_chunk items ({len(vision_chunk_items)}) and "
        f"modality_order ({len(vision_chunks_modality_order)}) must have same length"
    )

    processed_chunks: list[VisionChunk] = []
    video_idx = 0
    for inner_modality, (data, uuid) in zip(
        vision_chunks_modality_order, vision_chunk_items
    ):
        if inner_modality == "image":
            if hasattr(data, "media"):
                # Use .media (PIL.Image) directly to avoid a redundant
                # bytes->PIL conversion in the media processor.
                image_data = data.media  # type: ignore[union-attr]
                processed_chunks.append(
                    VisionChunkImage(type="image", image=image_data, uuid=uuid)
                )
            else:
                processed_chunks.append(data)  # type: ignore[arg-type]
        elif inner_modality == "video":
            if not hasattr(mm_processor, "split_video_chunks") or data is None:
                processed_chunks.append(data)  # type: ignore[arg-type]
                continue

            try:
                video_uuid = uuid or random_uuid()
                # An awaited video resolves to a (video_data, video_meta) tuple.
                if isinstance(data, tuple) and len(data) >= 1:
                    video_data = data[0]
                else:
                    video_data = data
                # Build this video's chunks locally: if splitting fails
                # part-way, nothing partial is left in processed_chunks next to
                # the raw-data fallback below.
                video_chunks = [
                    VisionChunkVideo(
                        type="video_chunk",
                        video_chunk=vc["video_chunk"],
                        uuid=f"{video_uuid}-{i}",
                        video_idx=video_idx,
                        prompt=vc["prompt"],
                    )
                    for i, vc in enumerate(mm_processor.split_video_chunks(video_data))
                ]
            except Exception as e:
                logger.warning("Failed to split video chunks: %s", e)
                processed_chunks.append(data)  # type: ignore[arg-type]
            else:
                processed_chunks.extend(video_chunks)
                video_idx += 1
    return processed_chunks, vision_chunks_uuids


620
621
622
def _resolve_items(
    items_by_modality: dict[str, list[tuple[object, str | None]]],
    mm_processor: BaseMultiModalProcessor,
    modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
    """Split tracked ``(data, uuid)`` pairs into multi-modal data and UUIDs.

    Embedding items (``image_embeds``/``audio_embeds``) are merged by
    ``_get_embeds_data`` and stored under the raw modality name. Raw
    image/audio/video items are stored as lists. ``vision_chunk`` items are
    expanded by ``_resolve_vision_chunk_items``.

    Args:
        items_by_modality: ``(data, uuid)`` pairs keyed by tracked modality.
        mm_processor: Processor used to merge embeddings and split videos.
        modality_order: Under ``"vision_chunk"``, the original modality of each
            tracked vision_chunk item.

    Returns:
        ``(mm_data, mm_uuids)``, both keyed by modality.

    Raises:
        ValueError: If raw and embedding inputs of one modality are mixed.
    """
    for media in ("image", "audio"):
        if media in items_by_modality and f"{media}_embeds" in items_by_modality:
            raise ValueError(f"Mixing raw {media} and embedding inputs is not allowed")

    mm_data = {}
    mm_uuids = {}
    for media in ("image", "audio"):
        # Raw and embedding inputs are mutually exclusive (checked above), so
        # at most one of these two branches fills `mm_data[media]`.
        embeds_key = f"{media}_embeds"
        if embeds_key in items_by_modality:
            embeds_items = items_by_modality[embeds_key]
            mm_data[media] = _get_embeds_data(
                media,
                [data for data, _ in embeds_items],
                mm_processor,
            )
            mm_uuids[media] = [uuid for _, uuid in embeds_items]
        if media in items_by_modality:
            raw_items = items_by_modality[media]
            mm_data[media] = [data for data, _ in raw_items]
            mm_uuids[media] = [uuid for _, uuid in raw_items]

    if "video" in items_by_modality:
        video_items = items_by_modality["video"]
        mm_data["video"] = [data for data, _ in video_items]
        mm_uuids["video"] = [uuid for _, uuid in video_items]

    if "vision_chunk" in items_by_modality:
        processed_chunks, vision_chunk_uuids = _resolve_vision_chunk_items(
            items_by_modality["vision_chunk"],
            mm_processor,
            modality_order.get("vision_chunk", []),
        )
        mm_data["vision_chunk"] = processed_chunks
        mm_uuids["vision_chunk"] = vision_chunk_uuids

    return mm_data, mm_uuids


class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
    """Tracker whose items are already-fetched ``(data, uuid)`` pairs."""

    def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Return ``(mm_data, mm_uuids)`` for the tracked items.

        Returns ``(None, None)`` when no modality has been recorded.
        """
        if not self._items_by_modality:
            return None, None

        return _resolve_items(
            dict(self._items_by_modality), self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a synchronous parser that records its items in this tracker."""
        return MultiModalContentParser(self)


684
685
686
687
688
689
class AsyncMultiModalItemTracker(
    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
):
    """Tracker whose items are awaitables resolving to ``(data, uuid)`` pairs."""

    async def resolve_items(
        self,
    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
        """Await all tracked items and return ``(mm_data, mm_uuids)``.

        Returns ``(None, None)`` when no modality has been recorded.
        """
        if not self._items_by_modality:
            return None, None

        modalities = list(self._items_by_modality)
        # Await all modalities concurrently instead of finishing one modality's
        # fetches before starting the next. Each inner gather keeps its
        # modality's items in insertion order.
        resolved_per_modality = await asyncio.gather(
            *(asyncio.gather(*self._items_by_modality[m]) for m in modalities)
        )
        resolved_items_by_modality = dict(zip(modalities, resolved_per_modality))

        return _resolve_items(
            resolved_items_by_modality, self.mm_processor, self._modality_order
        )

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an async parser that records its items in this tracker."""
        return AsyncMultiModalContentParser(self)


class BaseMultiModalContentParser(ABC):
    """Interface for parsing multi-modal content parts of a chat message.

    Implementations provide the ``parse_*`` methods and use
    ``_add_placeholder`` to record the model-specific placeholder for each
    parsed item.
    """

    def __init__(self) -> None:
        super().__init__()

        # stores model placeholders list with corresponding
        # general MM placeholder:
        # {
        #   "<##IMAGE##>": ["<image>", "<image>", "<image>"],
        #   "<##AUDIO##>": ["<audio>", "<audio>"]
        # }
        self._placeholder_storage: dict[str, list] = defaultdict(list)

    def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
        """Record ``placeholder`` under the generic placeholder of ``modality``.

        Empty or ``None`` placeholders are not recorded. ``modality`` is looked
        up in ``MODALITY_PLACEHOLDERS_MAP`` first, so an unmapped modality
        raises ``KeyError`` even when ``placeholder`` is empty.
        """
        mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
        if placeholder:
            self._placeholder_storage[mod_placeholder].append(placeholder)

    def mm_placeholder_storage(self) -> dict[str, list]:
        """Return a shallow copy of the recorded placeholders."""
        return dict(self._placeholder_storage)

    @abstractmethod
    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Parse an image given by URL (or only a ``uuid``)."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Parse image embeddings given as an encoded string or a dict of them."""
        raise NotImplementedError

    @abstractmethod
    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Parse an in-memory PIL image."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Parse an audio clip given by URL (or only a ``uuid``)."""
        raise NotImplementedError

    @abstractmethod
    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Parse OpenAI-style ``input_audio`` (inline data and format)."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Parse audio embeddings given as an encoded string or a dict of them."""
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Parse a video given by URL (or only a ``uuid``)."""
        raise NotImplementedError

766
767
768
769
770
771

class MultiModalContentParser(BaseMultiModalContentParser):
    """Parser that fetches media synchronously and records it in a tracker.

    Each ``parse_*`` method fetches media through the configured
    ``MediaConnector`` when a URL or payload is given, adds a ``(data, uuid)``
    pair to the ``MultiModalItemTracker`` and records the returned placeholder.
    """

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        # getattr with a default tolerates a None multimodal_config.
        multimodal_config = self._tracker.model_config.multimodal_config
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        return self._tracker.model_config

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Fetch the image at ``image_url`` (if non-empty) and register it."""
        image = self._connector.fetch_image(image_url) if image_url else None

        placeholder = self._tracker.add("image", (image, uuid))
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode and register image embeddings.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not enabled.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        # A single if/elif/else chain (like parse_audio_embeds) guarantees that
        # `placeholder` is bound; independent `if`s left it unbound, raising
        # UnboundLocalError, for any value that was not dict/str/None.
        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            placeholder = self._tracker.add("image_embeds", (embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            placeholder = self._tracker.add("image_embeds", (embedding, uuid))
        else:
            # No embedding payload: register only the uuid.
            placeholder = self._tracker.add("image_embeds", (None, uuid))

        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Decode and register audio embeddings.

        Raises:
            ValueError: If ``--enable-mm-embeds`` is not enabled.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
        else:
            # No embedding payload: register only the uuid.
            placeholder = self._tracker.add("audio_embeds", (None, uuid))

        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self, image_pil: Image.Image | None, uuid: str | None = None
    ) -> None:
        """Register an in-memory PIL image as-is."""
        placeholder = self._tracker.add("image", (image_pil, uuid))
        self._add_placeholder("image", placeholder)

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Fetch the audio at ``audio_url`` (if non-empty) and register it."""
        audio = self._connector.fetch_audio(audio_url) if audio_url else None

        placeholder = self._tracker.add("audio", (audio, uuid))
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Register OpenAI-style ``input_audio`` by turning it into a data URL.

        When ``input_audio`` is missing or its ``data`` is empty, only ``uuid``
        is registered.
        """
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Fetch the video at ``video_url`` (if non-empty) and register it."""
        video = self._connector.fetch_video(video_url=video_url) if video_url else None

        placeholder = self._tracker.add("video", (video, uuid))
        self._add_placeholder("video", placeholder)
877

878
879
880
881
882
883

class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Multimodal content parser that defers media loading to awaitables.

    Every ``parse_*`` method gives the tracker (via ``self._tracker.add``) an
    awaitable that yields an ``(item, uuid)`` pair. The awaitable is a
    coroutine or an ``asyncio.Future``. The placeholder returned by the
    tracker is recorded through ``self._add_placeholder``.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

        multimodal_config = self._tracker.model_config.multimodal_config
        # getattr with a default tolerates a config without this attribute.
        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)

        self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
            envs.VLLM_MEDIA_CONNECTOR,
            media_io_kwargs=media_io_kwargs,
            allowed_local_media_path=tracker.allowed_local_media_path,
            allowed_media_domains=tracker.allowed_media_domains,
        )

    @property
    def model_config(self) -> ModelConfig:
        """Model config of the underlying tracker."""
        return self._tracker.model_config

    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
        """Fetch the image at ``image_url`` (if any) and pair it with ``uuid``."""
        image = (
            await self._connector.fetch_image_async(image_url) if image_url else None
        )
        return image, uuid

    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
        """Register a lazily-fetched image; an empty URL yields a ``None`` item."""
        coro = self._image_with_uuid_async(image_url, uuid)

        placeholder = self._tracker.add("image", coro)
        self._add_placeholder("image", placeholder)

    def parse_image_embeds(
        self,
        image_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed image embeddings as an already-resolved future.

        Args:
            image_embeds: A single encoded embedding, a dict of named encoded
                embeddings, or ``None``.
            uuid: Optional user-supplied identifier for the item.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the model's
                multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `image_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        if isinstance(image_embeds, dict):
            embeds = {
                k: self._connector.fetch_image_embedding(v)
                for k, v in image_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
            future.set_result((embedding, uuid))
        else:
            # ``None`` (e.g. only a UUID was supplied) or any unexpected value.
            # The future must always be resolved. A pending future would make
            # the later ``await`` block forever. This matches the synchronous
            # parser's fallback of registering ``None``.
            future.set_result((None, uuid))

        placeholder = self._tracker.add("image_embeds", future)
        self._add_placeholder("image", placeholder)

    def parse_audio_embeds(
        self,
        audio_embeds: str | dict[str, str] | None,
        uuid: str | None = None,
    ) -> None:
        """Register precomputed audio embeddings as an already-resolved future.

        Args:
            audio_embeds: A single encoded embedding, a dict of named encoded
                embeddings, or ``None``.
            uuid: Optional user-supplied identifier for the item.

        Raises:
            ValueError: If ``enable_mm_embeds`` is not set in the model's
                multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        if not mm_config.enable_mm_embeds:
            raise ValueError(
                "You must set `--enable-mm-embeds` to input `audio_embeds`"
            )

        future = asyncio.Future[
            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
        ]()

        if isinstance(audio_embeds, dict):
            embeds = {
                k: self._connector.fetch_audio_embedding(v)
                for k, v in audio_embeds.items()
            }
            future.set_result((embeds, uuid))
        elif isinstance(audio_embeds, str):
            embedding = self._connector.fetch_audio_embedding(audio_embeds)
            future.set_result((embedding, uuid))
        else:
            # Always resolve the future (see ``parse_image_embeds``): leaving
            # it pending for an unexpected value would hang the later await.
            future.set_result((None, uuid))

        placeholder = self._tracker.add("audio_embeds", future)
        self._add_placeholder("audio", placeholder)

    def parse_image_pil(
        self,
        image_pil: Image.Image | None,
        uuid: str | None = None,
    ) -> None:
        """Register an already-decoded PIL image via a pre-resolved future."""
        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
        if image_pil:
            future.set_result((image_pil, uuid))
        else:
            future.set_result((None, uuid))

        placeholder = self._tracker.add("image", future)
        self._add_placeholder("image", placeholder)

    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
        """Fetch the audio at ``audio_url`` (if any) and pair it with ``uuid``."""
        audio = (
            await self._connector.fetch_audio_async(audio_url) if audio_url else None
        )
        return audio, uuid

    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
        """Register lazily-fetched audio; an empty URL yields a ``None`` item."""
        coro = self._audio_with_uuid_async(audio_url, uuid)

        placeholder = self._tracker.add("audio", coro)
        self._add_placeholder("audio", placeholder)

    def parse_input_audio(
        self, input_audio: InputAudio | None, uuid: str | None = None
    ) -> None:
        """Turn an OpenAI ``input_audio`` payload into a data URL and parse it."""
        if input_audio:
            audio_data = input_audio.get("data", "")
            audio_format = input_audio.get("format", "")
            if audio_data:
                audio_url = f"data:audio/{audio_format};base64,{audio_data}"
            else:
                # If a UUID is provided, audio data may be empty.
                audio_url = None
        else:
            audio_url = None

        return self.parse_audio(audio_url, uuid)

    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
        """Fetch the video at ``video_url`` (if any) and pair it with ``uuid``."""
        video = (
            await self._connector.fetch_video_async(video_url) if video_url else None
        )
        return video, uuid

    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
        """Register a lazily-fetched video; an empty URL yields a ``None`` item."""
        coro = self._video_with_uuid_async(video_url, uuid)

        placeholder = self._tracker.add("video", coro)
        self._add_placeholder("video", placeholder)
1026

1027

1028
1029
1030
1031
1032
1033
1034
@dataclass
class ChatTemplateConfig:
    """Bundle of chat-template options.

    Every field has a default, so ``ChatTemplateConfig()`` is valid.
    """

    # Chat template to use; ``None`` when not specified.
    # NOTE(review): presumably a literal template or a path/name accepted by
    # ``load_chat_template`` -- confirm against callers.
    chat_template: str | None = None
    # Content-format option for rendering messages; defaults to "auto".
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    # NOTE(review): presumably controls whether a chat template supplied in a
    # request may be used; defaults to False -- confirm against callers.
    trust_request_chat_template: bool = False


1035
def validate_chat_template(chat_template: Path | str | None):
1036
1037
1038
1039
1040
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
1041
        raise FileNotFoundError("the supplied chat template path doesn't exist")
1042
1043
1044

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
1045
1046
1047
1048
        if (
            not any(c in chat_template for c in JINJA_CHARS)
            and not Path(chat_template).exists()
        ):
1049
1050
1051
            # Try to find the template in the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1052
            )
1053

1054
1055
1056
1057
1058
1059
1060
1061
            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            if not builtin_template_path.exists():
                raise ValueError(
                    f"The supplied chat template string ({chat_template}) "
                    f"appears path-like, but doesn't exist! "
                    f"Tried: {chat_template} and {builtin_template_path}"
                )

1062
    else:
1063
        raise TypeError(f"{type(chat_template)} is not a valid chat template type")
1064
1065


1066
def _load_chat_template(
1067
    chat_template: Path | str | None,
1068
1069
    *,
    is_literal: bool = False,
1070
) -> str | None:
1071
1072
    if chat_template is None:
        return None
1073
1074
1075

    if is_literal:
        if isinstance(chat_template, Path):
1076
1077
1078
            raise TypeError(
                "chat_template is expected to be read directly from its value"
            )
1079

1080
        return chat_template
1081

1082
    try:
1083
        with open(chat_template) as f:
1084
            return f.read()
1085
    except OSError as e:
1086
1087
1088
        if isinstance(chat_template, Path):
            raise

1089
1090
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
1091
1092
1093
            # Try to load from the built-in templates directory
            from vllm.transformers_utils.chat_templates.registry import (
                CHAT_TEMPLATES_DIR,
1094
            )
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107

            builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
            try:
                with open(builtin_template_path) as f:
                    return f.read()
            except OSError:
                msg = (
                    f"The supplied chat template ({chat_template}) "
                    f"looks like a file path, but it failed to be opened. "
                    f"Tried: {chat_template} and {builtin_template_path}. "
                    f"Reason: {e}"
                )
                raise ValueError(msg) from e
1108

1109
1110
        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
1111
1112
1113
1114
1115
1116
1117
        return _load_chat_template(chat_template, is_literal=True)


# Memoised loader. ``lru_cache`` used without arguments keeps up to its default
# ``maxsize`` (128) results, keyed on the call arguments.
_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    """Return the chat template text, reusing previously loaded results.

    This is a thin wrapper over the memoised ``_load_chat_template``, which
    defines how ``chat_template`` and ``is_literal`` are interpreted.
    """
    cached_loader = _cached_load_chat_template
    return cached_loader(chat_template, is_literal=is_literal)
1123
1124


1125
1126
1127
def _get_interleaved_text_prompt(
    placeholder_storage: dict[str, list], texts: list[str]
) -> str:
1128
1129
1130
1131
1132
1133
1134
    for idx, elem in enumerate(texts):
        if elem in placeholder_storage:
            texts[idx] = placeholder_storage[elem].pop(0)

    return "\n".join(texts)


1135
# TODO: Let user specify how to insert multimodal tokens into prompt
1136
# (similar to chat template)
1137
1138
1139
1140
1141
def _get_full_multimodal_text_prompt(
    placeholder_storage: dict[str, list],
    texts: list[str],
    interleave_strings: bool,
) -> str:
1142
    """Combine multimodal prompts for a multimodal language model."""
1143

1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
    # flatten storage to make it looks like
    # {
    #   "<|image|>": 2,
    #   "<|audio|>": 1
    # }
    placeholder_counts = Counter(
        [v for elem in placeholder_storage.values() for v in elem]
    )

    if interleave_strings:
        text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
    else:
        text_prompt = "\n".join(texts)

    # Pass interleaved text further in case the user used image placeholders
    # himself, but forgot to disable the 'interleave_strings' flag

1161
    # Look through the text prompt to check for missing placeholders
1162
    missing_placeholders: list[str] = []
1163
1164
1165
1166
1167
    for placeholder in placeholder_counts:
        # For any existing placeholder in the text prompt, we leave it as is
        placeholder_counts[placeholder] -= text_prompt.count(placeholder)

        if placeholder_counts[placeholder] < 0:
1168
1169
1170
1171
            logger.error(
                "Placeholder count is negative! "
                "Ensure that the 'interleave_strings' flag is disabled "
                "(current value: %s) "
1172
1173
                "when manually placing image placeholders.",
                interleave_strings,
1174
1175
            )
            logger.debug("Input prompt: %s", text_prompt)
1176
1177
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
1178
1179
                "actual multimodal data items."
            )
1180

1181
        missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
1182

1183
1184
    # NOTE: Default behaviour: we always add missing placeholders
    # at the front of the prompt, if interleave_strings=False
1185
1186
1187
1188
    if text_prompt:
        return "\n".join(missing_placeholders + [text_prompt])
    else:
        return "\n".join(missing_placeholders)
1189
1190


1191
1192
# No need to validate using Pydantic again: ``cast`` only narrows the static
# type and performs no runtime check.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)

# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python

_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage

# Signature shared by every entry of MM_PARSER_MAP.
_PartExtractor: TypeAlias = Callable[[ChatCompletionContentPartParam], _ContentPart]


def _field_extractor(parser: Callable[[Any], Any], field: str) -> _PartExtractor:
    """Build an extractor returning ``parser(part)[field]``, or None if absent."""
    return lambda part: parser(part).get(field, None)


def _url_extractor(parser: Callable[[Any], Any], field: str) -> _PartExtractor:
    """Build an extractor returning ``parser(part)[field]["url"]``, or None."""
    return lambda part: parser(part).get(field, {}).get("url", None)


# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: dict[str, _PartExtractor] = {
    "text": _field_extractor(_TextParser, "text"),
    "thinking": _field_extractor(_ThinkParser, "thinking"),
    "input_text": _field_extractor(_TextParser, "text"),
    "output_text": _field_extractor(_TextParser, "text"),
    "input_image": _field_extractor(_ResponsesInputImageParser, "image_url"),
    "image_url": _url_extractor(_ImageParser, "image_url"),
    "image_embeds": _field_extractor(_ImageEmbedsParser, "image_embeds"),
    "audio_embeds": _field_extractor(_AudioEmbedsParser, "audio_embeds"),
    "image_pil": _field_extractor(_PILImageParser, "image_pil"),
    "audio_url": _url_extractor(_AudioParser, "audio_url"),
    "input_audio": _field_extractor(_InputAudioParser, "input_audio"),
    "refusal": _field_extractor(_RefusalParser, "refusal"),
    "video_url": _url_extractor(_VideoParser, "video_url"),
}


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
    """
    Parses a given multi-modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type' field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For 'input_audio'
          this is the nested ``{"data": ..., "format": ...}`` payload. It may
          be None for media parts that carry only a 'uuid'. For an
          unrecognised string type it is a fixed placeholder string.

    Raises:
        ValueError: If the 'type' field is missing (or a 'uuid' is given) and
            no known media field is found, or if 'type' is not a string.
    """
    assert isinstance(
        part, dict
    )  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)
    uuid = part.get("uuid", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None:  # noqa: E501
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'
        # We only support 'auto', which is the default. The OpenAI spec nests
        # the field as {"image_url": {"url": ..., "detail": ...}}. A top-level
        # "detail" key is still checked for compatibility.
        if part_type == "image_url":
            image_url_param = part.get("image_url")
            nested_detail = (
                image_url_param.get("detail", "auto")
                if isinstance(image_url_param, dict)
                else "auto"
            )
            if nested_detail != "auto" or part.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported "
                    "and will be ignored."
                )

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    # 'type' is required field by pydantic
    if part_type is None or uuid is not None:
        if "image_url" in part:
            image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
            image_url = image_params.get("image_url", None)
            if isinstance(image_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                image_url = image_url.get("url", None)
            return "image_url", image_url
        if "image_pil" in part:
            # "image_pil" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                CustomChatCompletionContentPILImageParam, part
            )
            image_pil = image_params.get("image_pil", None)
            return "image_pil", image_pil
        if "image_embeds" in part:
            # "image_embeds" could be None if UUID is provided.
            image_params = cast(  # type: ignore
                ChatCompletionContentPartImageEmbedsParam, part
            )
            image_embeds = image_params.get("image_embeds", None)
            return "image_embeds", image_embeds
        if "audio_embeds" in part:
            # "audio_embeds" could be None if UUID is provided.
            audio_params = cast(  # type: ignore[assignment]
                ChatCompletionContentPartAudioEmbedsParam, part
            )
            audio_embeds = audio_params.get("audio_embeds", None)
            return "audio_embeds", audio_embeds
        if "audio_url" in part:
            audio_params = cast(  # type: ignore[assignment]
                CustomChatCompletionContentSimpleAudioParam, part
            )
            audio_url = audio_params.get("audio_url", None)
            if isinstance(audio_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                audio_url = audio_url.get("url", None)
            return "audio_url", audio_url
        if "input_audio" in part:
            # "input_audio" could be None if UUID is provided. Return the
            # nested payload (not the whole part), the same value
            # MM_PARSER_MAP["input_audio"] returns for typed parts, so that
            # downstream parsing finds its "data"/"format" keys.
            input_audio = cast(InputAudio, part.get("input_audio", None))
            return "input_audio", input_audio
        if "video_url" in part:
            video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
            video_url = video_params.get("video_url", None)
            if isinstance(video_url, dict):
                # Can potentially happen if user provides a uuid
                # with url as a dict of {"url": url}
                video_url = video_url.get("url", None)
            return "video_url", video_url
        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")
    return part_type, "unknown part_type content"


1324
# Part types that are dropped, with a warning, when their parsed content is
# None instead of being passed on.
PART_TYPES_TO_SKIP_NONE_CONTENT = ("text", "refusal")
1328

1329

1330
1331
1332
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Parse the content parts of one message into a conversation message.

    Media parts are registered through a parser created from ``mm_tracker``.
    With ``wrap_dicts`` the message content is the list of wrapped parts.
    Otherwise the text pieces are joined into a single string, combined with
    multimodal placeholders when the parser recorded any.
    """
    mm_parser = mm_tracker.create_parser()

    parsed_parts = (
        _parse_chat_message_content_part(
            part,
            mm_parser,
            wrap_dicts=wrap_dicts,
            interleave_strings=interleave_strings,
        )
        for part in parts
    )
    # Falsy results (skipped parts, empty strings, None) are dropped.
    content: list[_ContentPart] = [parsed for parsed in parsed_parts if parsed]

    if wrap_dicts:
        # Parsing wraps images and texts as interleaved dictionaries
        return [ConversationMessage(role=role, content=content)]  # type: ignore

    texts = cast(list[str], content)
    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    if not mm_placeholder_storage:
        return [ConversationMessage(role=role, content="\n".join(texts))]

    text_prompt = _get_full_multimodal_text_prompt(
        mm_placeholder_storage, texts, interleave_strings
    )
    return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
    part: ChatCompletionContentPartParam,
    mm_parser: BaseMultiModalContentParser,
    *,
    wrap_dicts: bool,
    interleave_strings: bool,
) -> _ContentPart | None:
    """Parses a single part of a conversation.

    If wrap_dicts is True, text parts come back as
    ``{"type": "text", "text": ...}`` and media parts as ``{"type": <modality>}``.
    Otherwise texts come back as plain strings. Media parts then yield the
    modality's interleave placeholder when ``interleave_strings`` is set,
    else ``None``. Media data itself is always handed to ``mm_parser``.

    Returns ``None`` for text/refusal parts whose content is None.

    Raises:
        NotImplementedError: For a part type that has no handler here.
    """
    # Plain strings are text and need no further parsing.
    if isinstance(part, str):
        return part

    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # Only the types in PART_TYPES_TO_SKIP_NONE_CONTENT (text/refusal) are
    # skipped, with a warning, when their content is None.
    if content is None and part_type in PART_TYPES_TO_SKIP_NONE_CONTENT:
        logger.warning(
            "Skipping multimodal part '%s' (type: '%s') "
            "with empty / unparsable content.",
            part,
            part_type,
        )
        return None

    if part_type in ("text", "input_text", "output_text", "refusal", "thinking"):
        text = cast(str, content)
        return {"type": "text", "text": text} if wrap_dicts else text

    # Media items may carry a user-supplied UUID, normalised to str here.
    # Without one the item is registered with uuid=None.
    raw_uuid = part.get("uuid", None)
    uuid = None if raw_uuid is None else str(raw_uuid)

    if part_type == "image_pil":
        pil_image = cast(Image.Image, content) if content is not None else None
        mm_parser.parse_image_pil(pil_image, uuid)
        modality = "image"
    elif part_type in ("image_url", "input_image"):
        mm_parser.parse_image(cast(str, content), uuid)
        modality = "image"
    elif part_type == "image_embeds":
        embeds = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_image_embeds(embeds, uuid)
        modality = "image"
    elif part_type == "audio_embeds":
        embeds = cast(str | dict[str, str], content) if content is not None else None
        mm_parser.parse_audio_embeds(embeds, uuid)
        modality = "audio"
    elif part_type == "audio_url":
        mm_parser.parse_audio(cast(str, content), uuid)
        modality = "audio"
    elif part_type == "input_audio":
        mm_parser.parse_input_audio(cast(InputAudio, content), uuid)
        modality = "audio"
    elif part_type == "video_url":
        mm_parser.parse_video(cast(str, content), uuid)
        modality = "video"
    else:
        raise NotImplementedError(f"Unknown part type: {part_type}")

    if wrap_dicts:
        return {"type": modality}
    if interleave_strings:
        return MODALITY_PLACEHOLDERS_MAP[modality]
    return None
1446
1447


1448
1449
1450
1451
1452
# No need to validate using Pydantic again: these ``cast`` partials only
# narrow the static type of a message and perform no runtime check.
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)


1453
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: ChatTemplateContentFormat,
    interleave_strings: bool,
) -> list[ConversationMessage]:
    """Convert one chat message into internal conversation message(s).

    String content is treated as a single text part and missing content as
    no parts. Role-specific fields are copied onto each resulting message:
    - assistant: tool calls and reasoning;
    - tool: tool_call_id;
    - developer: tools;
    - any role: a string name.
    """
    role = message["role"]
    reasoning = message.get("reasoning")

    raw_content = message.get("content")
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [ChatCompletionContentPartTextParam(type="text", text=raw_content)]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
        wrap_dicts=content_format == "openai",
        interleave_strings=interleave_strings,
    )

    for result_msg in result:
        if role == "assistant":
            assistant_msg = _AssistantParser(message)

            # The not-None check (not just key presence) keeps compatibility
            # with downstream code that does not strictly follow the OpenAI
            # spec.
            tool_calls = assistant_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Include reasoning if present for interleaved thinking.
            if reasoning is not None:
                result_msg["reasoning"] = cast(str, reasoning)
                result_msg["reasoning_content"] = cast(str, reasoning)  # compat
        elif role == "tool":
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                result_msg["tool_call_id"] = tool_msg["tool_call_id"]

        name = message.get("name")
        if isinstance(name, str):
            result_msg["name"] = name

        if role == "developer":
            result_msg["tools"] = message.get("tools", None)

    return result

1502

1503
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
    """Normalise tool-call arguments of assistant messages, in place.

    Per the Transformers docs & maintainers, tool-use chat templates expect
    tool call arguments as dicts, not the JSON strings carried by the OpenAI
    format. String arguments are therefore decoded with ``json.loads``.
    Falsy arguments (None, empty string) become ``{}``. An empty
    ``tool_calls`` list is removed from its message.
    """
    for message in messages:
        if message["role"] != "assistant" or "tool_calls" not in message:
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        if not tool_calls:
            # Drop empty tool_calls to keep templates on the normal assistant path.
            message.pop("tool_calls", None)
            continue

        for item in tool_calls:
            function = item["function"]
            arguments = function.get("arguments")
            if not arguments:
                function["arguments"] = {}
            elif not isinstance(arguments, (dict, list)):
                function["arguments"] = json.loads(arguments)
1527
1528


1529
def parse_chat_messages(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages and collect their multimodal items synchronously.

    Returns:
        The post-processed conversation, plus the multimodal data and UUIDs
        produced by ``MultiModalItemTracker.resolve_items``.
    """
    mm_tracker = MultiModalItemTracker(model_config)

    # Same for every message, so it is computed once. Placeholders are only
    # interleaved into string-format content when the multimodal config
    # enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1560
1561


1562
async def parse_chat_messages_async(
    messages: list[ChatCompletionMessageParam],
    model_config: ModelConfig,
    content_format: ChatTemplateContentFormat,
) -> tuple[
    list[ConversationMessage],
    MultiModalDataDict | None,
    MultiModalUUIDDict | None,
]:
    """Parse chat messages and collect their multimodal items asynchronously.

    Returns:
        The post-processed conversation, plus the multimodal data and UUIDs
        produced by awaiting ``AsyncMultiModalItemTracker.resolve_items``.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config)

    # Same for every message, so it is computed once. Placeholders are only
    # interleaved into string-format content when the multimodal config
    # enables it.
    interleave_strings = (
        content_format == "string"
        and model_config.multimodal_config is not None
        and model_config.multimodal_config.interleave_mm_strings
    )

    conversation: list[ConversationMessage] = []
    for msg in messages:
        conversation.extend(
            _parse_chat_message_content(
                msg,
                mm_tracker,
                content_format,
                interleave_strings=interleave_strings,
            )
        )

    _postprocess_messages(conversation)

    mm_data, mm_uuids = await mm_tracker.resolve_items()
    return conversation, mm_data, mm_uuids
1593

1594

1595
1596
1597
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
    """Return the total number of tool calls made by assistant messages."""
    total_tool_calls = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            total_tool_calls += len(list(tool_calls))
    return total_tool_calls


1604
1605
1606
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build a tool-call ID.

    ``id_type="kimi_k2"`` yields ``functions.<func_name>:<idx>``. Any other
    ``id_type`` yields ``chatcmpl-tool-<random_uuid()>``.
    """
    if id_type != "kimi_k2":
        # by default return random
        return f"chatcmpl-tool-{random_uuid()}"
    return f"functions.{func_name}:{idx}"