parse.py 21.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
7
8
9
10
11
12
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    NamedTuple,
13
14
    TypeAlias,
    TypeGuard,
15
16
    TypeVar,
)
17
18
19

import numpy as np
import torch
20
from typing_extensions import assert_never
21

22
from vllm.utils.collection_utils import is_list_of
23
from vllm.utils.import_utils import LazyLoader
24

25
from .audio import AudioResampler, AudioSpec, normalize_audio
26
27
28
29
30
31
32
33
34
35
36
37
from .inputs import (
    AudioItem,
    HfAudioItem,
    HfImageItem,
    HfVideoItem,
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
38
from .media import MediaWithBytes
39
40
41
42

# Type variables for the generic data-item classes defined below.
_T = TypeVar("_T")
_I = TypeVar("_I")

# Type checkers see the real `PIL.Image` module; at runtime a `LazyLoader`
# proxy is used instead (presumably deferring the import until first use).
if not TYPE_CHECKING:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
else:
    import PIL.Image as PILImage

48
49

class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Subclasses implement `get_count`/`get` to expose individual items, and
    `get_processor_data`/`get_passthrough_data` to describe which data goes
    to the HF processor and which goes directly to the model.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        # Data exactly as supplied by the caller.
        self.data: _T = data
        # Modality name, e.g. "audio", "image" or "video".
        self.modality = modality

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # No runtime `__iter__`: Python falls back to calling `__getitem__`
        # with 0, 1, 2, ... until IndexError. This stub only tells type
        # checkers the element type.
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Return how many data items are held."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Return the data item at position `index`."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Return every data item, in index order."""
        return list(map(self.get, range(self.get_count())))

    def get_item_for_hash(self, index: int) -> object:
        """Return the object used to hash item `index`.

        Defaults to the item itself; subclasses may return something else.
        """
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """Return the hashing representation of every item, in index order."""
        return list(map(self.get_item_for_hash, range(self.get_count())))

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Return the keyword data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the keyword data to pass directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """
    Base class for data items that are arranged in a list.

    Entries may be wrapped in `MediaWithBytes`. Accessors return the unwrapped
    media, while hashing receives the raw entry (wrapper included).
    """

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Return the media inside a `MediaWithBytes` wrapper, else `item`."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        raw = self.data[index]
        return self._unwrap(raw)

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        # Hand the raw entry to the hasher so a wrapper's original bytes
        # (if any) remain available.
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Pluralized modality name, e.g. "audios" / "images" / "videos".
        key = f"{self.modality}s"
        return {key: self.get_all()}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Everything goes through the processor; nothing bypasses it.
        return {}


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def validate_embedding_ndim(
    tensor: torch.Tensor,
    modality: str,
    index: int | None = None,
) -> None:
    """Validate tensor ndim for multimodal embeddings.

    Single embeddings should be 2D (seq_len, hidden_size).
    Batched embeddings should be 3D (batch, seq_len, hidden_size).

    Args:
        tensor: The tensor to validate.
        modality: The modality name for error messages (e.g., "image", "audio").
        index: Optional index for list items, included in error messages.
    """
    if tensor.ndim < 2 or tensor.ndim > 3:
        idx_str = f" [{index}]" if index is not None else ""
        raise ValueError(
            f"{modality.capitalize()} embedding{idx_str} must be 2D "
            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
        )


153
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).

    List entries may be wrapped in `MediaWithBytes`; the wrapper is stripped
    both during validation and when items are read via `get`.
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        """
        Args:
            data: A batched embedding tensor, or a list of per-item tensors.
            modality: Modality name, used in error messages and as the prefix
                of the passthrough key (`f"{modality}_embeds"`).
            expected_hidden_size: If given, the last dimension of every
                embedding must equal this value.

        Raises:
            ValueError: If an embedding has an unsupported number of
                dimensions or a mismatched hidden dimension.
        """
        super().__init__(data, modality)

        # Validate ndim first: the hidden-size check reads `shape[-1]`, which
        # is only meaningful once the layout is known to be correct.
        self._validate_ndim()

        # Validate hidden dimension if expected size is provided
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _iter_list_tensors(self) -> Iterator[tuple[int, torch.Tensor]]:
        """Yield `(index, tensor)` for list data, unwrapping `MediaWithBytes`.

        Unwrapping here keeps validation consistent with `get`, which also
        unwraps entries; otherwise a wrapped entry's `.ndim`/`.shape` would be
        read from the wrapper rather than the tensor.
        """
        for idx, item in enumerate(self.data):
            yield idx, self._unwrap(item)

    def _validate_ndim(self) -> None:
        """Validate that embedding tensors have correct ndim (2D or 3D)."""
        if isinstance(self.data, torch.Tensor):
            validate_embedding_ndim(self.data, self.modality)
            return

        # List of tensors: each should be 2D (seq_len, hidden_size)
        for idx, tensor in self._iter_list_tensors():
            if tensor.ndim != 2:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding [{idx}] must be "
                    f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                    f"with shape {tuple(tensor.shape)}"
                )

    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Validate that embedding hidden dimension matches expected size.

        Embeddings with a valid ndim but the wrong hidden dimension would
        otherwise pass the initial checks and only fail during model
        inference, when the dimensions don't match.
        """
        if isinstance(self.data, torch.Tensor):
            # Batched tensor: the hidden size is the last dimension.
            actual_hidden_size = self.data.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual_hidden_size}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
                )
            return

        # List of tensors: each has shape (seq_len, hidden_size)
        for idx, tensor in self._iter_list_tensors():
            actual_hidden_size = tensor.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding [{idx}] hidden "
                    f"dimension mismatch: got {actual_hidden_size}, but model "
                    f"expects {expected_hidden_size}. "
                    f"Embedding shape: {tuple(tensor.shape)}"
                )

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Extract media from wrapper if present."""
        return item.media if isinstance(item, MediaWithBytes) else item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self._unwrap(self.data[index])

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings are not sent to the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Forwarded to the model under e.g. "image_embeds".
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Return the length of the first dimension of item `item_idx`.

        For a 2D `(seq_len, hidden_size)` item this is `seq_len`.
        """
        return len(self.get(item_idx))

239

240
241
242
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    Individual items are extracted from the dictionary using the field
    configs produced by `fields_factory`.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """
        Args:
            data: Mapping from field name to tensor.
            modality: Modality whose items are read from `data`.
            required_fields: Keys that must be present both in `data` and in
                the configs returned by `fields_factory`.
            fields_factory: Builds the field configs for `data`.

        Raises:
            ValueError: If `data` or the generated field configs are missing
                any of `required_fields`.
        """
        # Deferred import: transformers is only needed when building items.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        if required_fields - data.keys():
            data_keys = set(data.keys())
            raise ValueError(
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )

        fields_config = fields_factory(data)
        if required_fields - fields_config.keys():
            # NOTE: the `{name=}` specifiers embed these variable names in the
            # message, so `fields` must keep this name.
            fields = set(fields_config.keys())
            raise ValueError(f"{required_fields=} should be a subset of {fields=}")

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Grouped kwargs, indexed first by modality and then by item
        # (see `get`).
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        item = self._kwargs[self.modality][index]
        return item.get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        # Nothing is sent to the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data


300
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
        # `None` is stored as a single `None` item.
        super().__init__([None] if data is None else data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Return `len()` of the audio item at `item_idx`.

        For 1-D audio this is the number of samples.
        """
        return len(self.get(item_idx))

310
311

class AudioEmbeddingItems(EmbeddingItems):
    """Audio data supplied as embedding tensors."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data,
            modality="audio",
            expected_hidden_size=expected_hidden_size,
        )
318
319
320
321
322
323
324
325


class ImageSize(NamedTuple):
    """Width and height of an image (or video frame), in that order."""

    width: int
    height: int


class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfImageItem] | None) -> None:
        # `None` is stored as a single `None` item.
        super().__init__([None] if data is None else data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the width and height of the image at `item_idx`.

        Array/tensor images are read as `(_, height, width)`, i.e. they are
        assumed to be channel-first.
        """
        img = self.get(item_idx)

        if isinstance(img, PILImage.Image):
            # PIL reports size as (width, height).
            return ImageSize(*img.size)
        if isinstance(img, (np.ndarray, torch.Tensor)):
            _, height, width = img.shape
            return ImageSize(width, height)

        assert_never(img)


class ImageEmbeddingItems(EmbeddingItems):
    """Image data supplied as embedding tensors."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data,
            modality="image",
            expected_hidden_size=expected_hidden_size,
        )
350
351
352


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items that are passed to the HF processor, plus optional metadata."""

    def __init__(
        self,
        data: Sequence[HfVideoItem] | None,
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        # `None` is stored as a single `None` item.
        super().__init__([None] if data is None else data, "video")

        # Either a single dict or a list of per-video entries; stored as given.
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return `len()` of the video at `item_idx` (its frame count)."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the width and height of the first frame of a video.

        The video is assumed to be non-empty. Array/tensor frames are read as
        `(_, height, width)`, i.e. assumed to be channel-first.
        """
        first_frame = self.get(item_idx)[0]

        if isinstance(first_frame, PILImage.Image):
            # PIL reports size as (width, height).
            return ImageSize(*first_frame.size)
        if isinstance(first_frame, (np.ndarray, torch.Tensor)):
            _, height, width = first_frame.shape
            return ImageSize(width, height)

        assert_never(first_frame)

377
378

class VideoEmbeddingItems(EmbeddingItems):
    """Video data supplied as embedding tensors."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data,
            modality="video",
            expected_hidden_size=expected_hidden_size,
        )
385
386
387
388
389
390
391


# The concrete data-items type requested from `MultiModalDataItems.get_items`.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def _missing_modality_error(self, modality: str) -> KeyError:
        """Build the `KeyError` raised when `modality` has no entry."""
        available_modalities = set(self.keys())
        return KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {available_modalities}"
        )

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If the modality is not present, raise [`KeyError`][] when `strict=True`
        (the default), or return `0` when `strict=False`.
        """
        if modality in self:
            return self[modality].get_count()
        if strict:
            raise self._missing_modality_error(modality)
        return 0

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {name: items.get_count() for name, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If there is no entry for `modality`.
            TypeError: If the entry is not an instance of `typ`.
        """
        if modality not in self:
            raise self._missing_modality_error(modality)

        items = self[modality]
        if isinstance(items, typ):
            return items  # type: ignore[return-value]

        raise TypeError(
            f"Invalid type of data items for {modality=}. "
            f"Expected type: {typ}, but "
            f"found type: {type(items)}"
        )
444
445


446
# Parses the raw data for one modality; returning `None` drops the entry.
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]], ModalityDataItems[Any, Any] | None
]


class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        target_channels (int, optional): Target number of audio channels.
            If provided, normalizes audio to this many channels (e.g., 1 for mono).
            If None, audio channels are passed through unchanged.
        audio_resample_method: Resampling backend, `"librosa"` or `"scipy"`.
        video_needs_metadata (bool, optional): If True, every video item must
            be a `(video, metadata)` tuple; otherwise a `ValueError` is raised.
            If False, any metadata supplied with a video is discarded.
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
    """

    def __init__(
        self,
        *,
        target_sr: float | None = None,
        target_channels: int | None = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        self.target_channels = target_channels
        self.video_needs_metadata = video_needs_metadata
        self.expected_hidden_size = expected_hidden_size

    @classmethod
    def is_embeddings(
        cls, data: object
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
        """Return whether `data` holds embeddings.

        That is, either a 3D tensor or a non-empty list whose first element
        is a 2D tensor.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            # Guard before peeking at the first element, in case an empty
            # list is accepted by `is_list_of`.
            return bool(data) and data[0].ndim == 2  # type: ignore[index]

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        """Return whether `data` is an empty list, array, or tensor."""
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, torch.Tensor):
            # `Tensor.size` is a method (unlike `ndarray.size`), so comparing
            # it to 0 is always False; count the elements instead.
            return data.numel() == 0
        if isinstance(data, np.ndarray):
            return data.size == 0

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, float | None]:
        """Split an audio item into `(audio, orig_sr)`.

        A tuple item is returned unchanged. Any other item is converted to a
        NumPy array and paired with `orig_sr=None`.
        """
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
        """Split a video item into `(video, metadata)`.

        A tuple item is returned unchanged. Any other item is converted to a
        NumPy array and paired with `metadata=None`.
        """
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio input into processor items or embedding items.

        Returns `None` for empty input so the modality is dropped.
        """
        if data is None:
            return AudioProcessorItems(None)

        # also check single audio item with sampling rate
        if self._is_empty(data) or (
            isinstance(data, tuple) and self._is_empty(data[0])
        ):
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[AudioItem]
        if (
            is_list_of(data, float)
            or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
            or isinstance(data, tuple)
        ):
            # A single audio item (a waveform, or an `(audio, sr)` tuple).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # Multiple audio items stacked along the first axis.
            data_items = list(data)
        else:
            data_items = data  # type: ignore[assignment]

        # The channel-normalization spec is the same for every item.
        spec = (
            None
            if self.target_channels is None
            else AudioSpec(target_channels=self.target_channels)
        )

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)

            # Apply channel normalization if target_channels is set
            if spec is not None:
                new_audio = normalize_audio(new_audio, spec)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image input into processor items or embedding items.

        Returns `None` for empty input so the modality is dropped.
        """
        if data is None:
            return ImageProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data, self.expected_hidden_size)

        if isinstance(data, (PILImage.Image, MediaWithBytes)) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
        ):
            # A single image.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # Multiple images stacked along the first axis.
            data_items = list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video input into processor items or embedding items.

        Returns `None` for empty input so the modality is dropped.

        Raises:
            ValueError: If `video_needs_metadata` is set and a video item has
                no metadata.
        """
        if data is None:
            return VideoProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[VideoItem]
        if is_list_of(data, PILImage.Image) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
        ):
            # A single video (a list of frames, or a 4D array/tensor).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # Multiple videos stacked along the first axis.
            data_items = list(data)
        elif isinstance(data, tuple) and len(data) == 2:
            # A single `(video, metadata)` tuple.
            data_items = [data]
        else:
            data_items = data  # type: ignore[assignment]

        # Bare videos when metadata isn't needed, `(video, metadata)` otherwise.
        new_videos: list[np.ndarray | tuple[np.ndarray, dict[str, Any]]] = []
        metadata_lst: list[dict[str, Any] | None] = []
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if not self.video_needs_metadata:
                # Any supplied metadata is discarded.
                new_videos.append(video)
                continue

            if metadata is None:
                raise ValueError(
                    "Video metadata is required but not found in mm input. "
                    "Please check your video input in `multi_modal_data`"
                )
            new_videos.append((video, metadata))
            metadata_lst.append(metadata)

        # NOTE: `metadata_lst` stays empty when metadata is not needed.
        return VideoProcessorItems(new_videos, metadata=metadata_lst)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to its parser."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse each modality of `mm_data` with its matching parser.

        Modalities whose parser returns `None` (empty input) are omitted.

        Raises:
            ValueError: If `mm_data` contains an unsupported modality.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items