parse.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from abc import ABC, abstractmethod
from collections import UserDict
6
from collections.abc import Callable, Iterator, Mapping, Sequence, Set
7
8
9
10
11
12
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    NamedTuple,
13
14
    TypeAlias,
    TypeGuard,
15
16
    TypeVar,
)
17
18
19

import numpy as np
import torch
20
from typing_extensions import assert_never
21

22
from vllm.inputs import ModalityData, MultiModalDataDict, MultiModalUUIDDict
23
from vllm.utils.collection_utils import is_list_of
24
from vllm.utils.import_utils import LazyLoader
25

26
from .audio import AudioResampler, AudioSpec, normalize_audio
27
28
29
30
31
32
33
34
35
36
from .inputs import (
    AudioItem,
    HfAudioItem,
    HfImageItem,
    HfVideoItem,
    ImageItem,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
37
from .media import MediaWithBytes
38
39
40
41

# Type parameters for the generic item containers in this module:
# `_T` is the type of the stored data and `_I` is the type of a single item
# (see `ModalityDataItems(ABC, Generic[_T, _I])`).
_T = TypeVar("_T")
_I = TypeVar("_I")

42
43
44
45
46
# Type checkers see the real `PIL.Image` module; at runtime the name is bound
# to a `LazyLoader` instead.
# NOTE(review): presumably this defers importing PIL until first use - confirm
# against `vllm.utils.import_utils.LazyLoader`.
if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")

47
48

class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    `_T` is the type of the container stored in `self.data` and `_I` is the
    type of a single item returned by `get`.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # No `__iter__` exists at runtime: iteration falls back to Python's
        # sequence protocol, which calls `__getitem__` with 0, 1, 2, ...
        # until it raises `IndexError`. This stub only tells type checkers
        # what the items are.
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    def get_item_for_hash(self, index: int) -> object:
        """Get the object to hash for the item at `index`.

        Defaults to the item itself; subclasses may return a different
        representation.
        """
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """Get the objects to hash for all items, in index order."""
        return [self.get_item_for_hash(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Extract media from wrapper if present."""
        return item.media if isinstance(item, MediaWithBytes) else item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self._unwrap(self.data[index])

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        # Return raw item for hashing (preserves original_bytes if present)
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # The keyword is the pluralized modality name, e.g. "image" -> "images".
        return {f"{self.modality}s": self.get_all()}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Everything goes through the processor; nothing bypasses it.
        return {}


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def validate_embedding_ndim(
    tensor: torch.Tensor,
    modality: str,
    index: int | None = None,
) -> None:
    """Validate tensor ndim for multimodal embeddings.

    Single embeddings should be 2D (seq_len, hidden_size).
    Batched embeddings should be 3D (batch, seq_len, hidden_size).

    Args:
        tensor: The tensor to validate.
        modality: The modality name for error messages (e.g., "image", "audio").
        index: Optional index for list items, included in error messages.
    """
    if tensor.ndim < 2 or tensor.ndim > 3:
        idx_str = f" [{index}]" if index is not None else ""
        raise ValueError(
            f"{modality.capitalize()} embedding{idx_str} must be 2D "
            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
        )


152
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        """
        Args:
            data: A single tensor, or a list of tensors (one per item).
            modality: The modality name (e.g., "image", "audio").
            expected_hidden_size: If given, the size that the last dimension
                of every embedding must have.

        Raises:
            ValueError: If any tensor has an unsupported number of dimensions
                or a hidden size different from `expected_hidden_size`.
        """
        super().__init__(data, modality)

        # Validate ndim first (before hidden_size which depends on correct ndim)
        self._validate_ndim()

        # Validate hidden dimension if expected size is provided
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _validate_ndim(self) -> None:
        """Validate that embedding tensors have correct ndim (2D or 3D)."""
        if isinstance(self.data, torch.Tensor):
            validate_embedding_ndim(self.data, self.modality)
        else:
            # List of tensors: each should be 2D (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                if tensor.ndim != 2:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] must be "
                        f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                        f"with shape {tuple(tensor.shape)}"
                    )

    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Validate that embedding hidden dimension matches expected size.

        This validates hidden dimensions to prevent vulnerabilities: Embeddings
        with correct ndim but wrong hidden dimension could bypass initial
        checks and cause crashes during model inference when dimensions don't match.
        """
        if isinstance(self.data, torch.Tensor):
            # Single tensor (2D or 3D): the hidden size is the last dimension
            actual_hidden_size = self.data.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual_hidden_size}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
                )
        else:
            # List of tensors: each has shape (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                actual_hidden_size = tensor.shape[-1]
                if actual_hidden_size != expected_hidden_size:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] hidden "
                        f"dimension mismatch: got {actual_hidden_size}, but model "
                        f"expects {expected_hidden_size}. "
                        f"Embedding shape: {tuple(tensor.shape)}"
                    )

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Extract media from wrapper if present."""
        return item.media if isinstance(item, MediaWithBytes) else item

    def get_count(self) -> int:
        # For a tensor this is the size of its first dimension.
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self._unwrap(self.data[index])

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings are not sent through the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Get the length of the first dimension of the item at `item_idx`."""
        return len(self.get(item_idx))

238

239
240
241
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """
        Args:
            data: The tensors, keyed by field name.
            modality: The modality name (e.g., "image", "audio").
            required_fields: Field names that must appear both in `data` and
                in the config returned by `fields_factory`.
            fields_factory: Builds the per-field config from `data`.

        Raises:
            ValueError: If a required field is missing from `data` or from
                the config returned by `fields_factory`.
        """
        # Imported at call time rather than at module import time.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )
            raise ValueError(msg)

        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Per-item view of `data`; used for counting and indexing items.
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        return self._kwargs[self.modality][index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        # Nothing is sent through the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data


299
300
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem | None]):
    """Audio items to be passed to the HF processor (modality `"audio"`)."""

    def __init__(self, data: Sequence[HfAudioItem | None]) -> None:
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Get `len()` of the audio at `item_idx`.

        Raises:
            ValueError: If the item is `None`, which the error message
                describes as a cached audio.
        """
        audio = self.get(item_idx)
        if audio is None:
            raise ValueError(f"Cannot get length of cached audio at {item_idx}")

        return len(audio)

310
311

class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items for the `"audio"` modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "audio", expected_hidden_size)
318
319
320
321
322
323
324


class ImageSize(NamedTuple):
    """Size of an image, in `(width, height)` order."""

    # Horizontal extent.
    width: int
    # Vertical extent.
    height: int


325
326
class ImageProcessorItems(ProcessorBatchItems[HfImageItem | None]):
    """Image items to be passed to the HF processor (modality `"image"`)."""

    def __init__(self, data: Sequence[HfImageItem | None]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Get the size of the image at `item_idx`.

        Raises:
            ValueError: If the item is `None`, which the error message
                describes as a cached image.
        """
        image = self.get(item_idx)
        if image is None:
            raise ValueError(f"Cannot get size of cached image at {item_idx}")

        if isinstance(image, PILImage.Image):
            # PIL reports size as (width, height) already.
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Arrays must be 3D; the last two dimensions are read as
            # (height, width) and the first one is ignored.
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)


class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items for the `"image"` modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "image", expected_hidden_size)
350
351


352
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem | None]):
    """Video items to be passed to the HF processor (modality `"video"`)."""

    def __init__(
        self,
        data: Sequence[HfVideoItem | None],
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        super().__init__(data, "video")

        # Stored as given; this class does not interpret it.
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Get `len()` of the video at `item_idx`.

        Raises:
            ValueError: If the item is `None`, which the error message
                describes as a cached video.
        """
        video = self.get(item_idx)
        if video is None:
            raise ValueError(f"Cannot get length of cached video at {item_idx}")

        return len(video)

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Get the size of the first frame of the video at `item_idx`.

        Raises:
            ValueError: If the item is `None` or has no frames.
        """
        video = self.get(item_idx)
        if video is None:
            raise ValueError(f"Cannot get size of cached video at {item_idx}")
        if len(video) == 0:
            raise ValueError(f"Cannot get size of empty video at {item_idx}")

        # Only the first frame is inspected.
        image = video[0]

        if isinstance(image, PILImage.Image):
            # PIL reports size as (width, height) already.
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # A frame must be 3D; the last two dimensions are read as
            # (height, width) and the first one is ignored.
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)

386
387

class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items for the `"video"` modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "video", expected_hidden_size)
394
395


Roger Wang's avatar
Roger Wang committed
396
397
398
399
400
401
402
class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
    """Processor items for vision chunks (unified image and video chunks).

    The items are untyped (`Any`) and are registered under the
    `"vision_chunk"` modality.
    """

    def __init__(self, data: Sequence[Any]) -> None:
        # The modality name is fixed; callers only supply the items.
        super().__init__(data, "vision_chunk")


403
404
405
406
407
# Type parameter bound to `ModalityDataItems`, so it can only be solved to an
# item-container type.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    A normalized [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
    such that each entry corresponds to a list.
    """

    def _missing_modality_error(self, modality: str) -> KeyError:
        """Build the error raised when `modality` has no entry."""
        available_modalities = set(self.keys())
        return KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {available_modalities}"
        )

    def select(self, modalities: Set[str]) -> "MultiModalDataItems":
        """
        Construct a new `MultiModalDataItems` instance containing only the
        selected modalities.

        Raises [`KeyError`][] if any selected modality is not present.
        """
        return MultiModalDataItems(
            {modality: self[modality] for modality in modalities}
        )

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        when the modality is not found.
        """
        if modality not in self:
            if strict:
                raise self._missing_modality_error(modality)

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises [`KeyError`][] if the modality is not found, and
        [`TypeError`][] if its items are not an instance of `typ`.
        """
        if modality not in self:
            raise self._missing_modality_error(modality)

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(
                f"Invalid type of data items for {modality=}. "
                f"Expected type: {typ}, but "
                f"found type: {type(items)}"
            )

        return items  # type: ignore[return-value]
469
470


471
# Signature of a per-modality parser: takes the raw data for one modality and
# returns the parsed items, or `None` when there is nothing to register.
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]], ModalityDataItems[Any, Any] | None
]
474
475
476
477


class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        target_channels (int, optional): Target number of audio channels.
            If provided, normalizes audio to this many channels (e.g., 1 for mono).
            If None, audio channels are passed through unchanged.
        audio_resample_method (str, optional): Resampling method forwarded
            to the audio resampler (`"pyav"` or `"scipy"`).
        video_needs_metadata (bool, optional): If True, every video item must
            come with metadata, and parsing fails otherwise. If False, any
            metadata supplied with a video is dropped.
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
    """

    def __init__(
        self,
        *,
        target_sr: float | None = None,
        target_channels: int | None = None,
        audio_resample_method: Literal["pyav", "scipy"] = "pyav",
        video_needs_metadata: bool = False,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        self.target_channels = target_channels
        self.video_needs_metadata = video_needs_metadata
        self.expected_hidden_size = expected_hidden_size

    @classmethod
    def is_embeddings(
        cls, data: object
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
        """Tell whether `data` should be treated as precomputed embeddings.

        This holds for a 3D tensor, or for a non-empty list of tensors whose
        first element is 2D. Only the first element of a list is inspected.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor) and len(data) > 0:
            return data[0].ndim == 2  # type: ignore[index]

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, float | None]:
        """Convert an audio item into an `(audio, sampling rate)` pair.

        A tuple is returned unchanged. Any other supported input is converted
        to a NumPy array and paired with `None` as its sampling rate.
        """
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
        """Convert a video item into a `(video, metadata)` pair.

        A tuple is returned unchanged. Any other supported input is converted
        to a NumPy array and paired with `None` as its metadata.
        """
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the `"audio"` entry into embedding or processor items."""
        if data is None:
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[AudioItem]
        if (
            (is_list_of(data, float) and len(data) > 0)
            or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
            or isinstance(data, tuple)
        ):
            # A single audio item: wrap it in a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A stacked array: one item per entry along the first dimension.
            data_items = list(data)
        else:
            data_items = data  # type: ignore[assignment]

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            # Without a known source sampling rate, resampling is skipped.
            if orig_sr is None:
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)

            # Apply channel normalization if target_channels is set
            if self.target_channels is not None:
                spec = AudioSpec(target_channels=self.target_channels)
                new_audio = normalize_audio(new_audio, spec)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the `"image"` entry into embedding or processor items."""
        if data is None:
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data, self.expected_hidden_size)

        if isinstance(data, (PILImage.Image, MediaWithBytes)) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
        ):
            # A single image: wrap it in a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A stacked array: one item per entry along the first dimension.
            data_items = list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the `"video"` entry into embedding or processor items.

        Raises:
            ValueError: If `video_needs_metadata` is set and a video item
                has no metadata.
        """
        if data is None:
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[VideoItem]
        if (is_list_of(data, PILImage.Image) and len(data) > 0) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
        ):
            # A single video: wrap it in a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A stacked array: one item per entry along the first dimension.
            data_items = list(data)
        elif isinstance(data, tuple) and len(data) == 2:
            # A single two-element tuple, later unpacked as (video, metadata).
            data_items = [data]
        else:
            data_items = data  # type: ignore[assignment]

        # Holds (video, metadata) pairs when metadata is required,
        # and bare videos otherwise.
        new_videos: list[np.ndarray | tuple[np.ndarray, dict[str, Any]]] = []
        metadata_lst: list[dict[str, Any] | None] = []
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                new_videos.append(video)

        # `metadata_lst` stays empty when metadata is not required.
        return VideoProcessorItems(new_videos, metadata=metadata_lst)

    def _parse_vision_chunk_data(
        self,
        data: ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse vision chunk data (unified image and video chunks)."""
        if data is None:
            return None

        if self.is_embeddings(data):
            raise ValueError("Do not support embedding data for vision_chunk right now")

        # A single chunk given as a dict: wrap it in a list.
        if isinstance(data, dict):
            data = [data]

        return VisionChunkProcessorItems(data)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to the method that parses it."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
            "vision_chunk": self._parse_vision_chunk_data,
        }

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse every modality in `mm_data` with its matching subparser.

        Raises:
            ValueError: If `mm_data` contains a modality with no subparser.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # A subparser returns None when no data was given for the
            # modality; such entries are left out of the result.
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items
692
693
694
695


MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
"""
696
697
A normalized [`MultiModalUUIDDict`][vllm.inputs.MultiModalUUIDDict]
such that each entry corresponds to a list.
698
699
700
701
702
703
704
705
706
707
708
"""


def parse_mm_uuids(mm_uuids: MultiModalUUIDDict | None) -> MultiModalUUIDItems:
    """Normalize per-modality UUIDs so that every entry is a sequence.

    A bare string is wrapped in a one-element list; any other value is kept
    as the same object. `None` yields an empty dict.
    """
    if mm_uuids is None:
        return {}

    normalized: MultiModalUUIDItems = {}
    for modality, uuids in mm_uuids.items():
        if isinstance(uuids, str):
            normalized[modality] = [uuids]
        else:
            normalized[modality] = uuids

    return normalized