parse.py 22.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
7
8
9
10
11
12
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    NamedTuple,
13
14
    TypeAlias,
    TypeGuard,
15
16
    TypeVar,
)
17
18
19

import numpy as np
import torch
20
from typing_extensions import assert_never
21

22
from vllm.utils.collection_utils import is_list_of
23
from vllm.utils.import_utils import LazyLoader
24

25
from .audio import AudioResampler, AudioSpec, normalize_audio
26
27
28
29
30
31
32
33
34
35
36
37
from .inputs import (
    AudioItem,
    HfAudioItem,
    HfImageItem,
    HfVideoItem,
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
38
from .media import MediaWithBytes
39
40
41
42

# Type variables for the generic item containers in this module.
# In `ModalityDataItems[_T, _I]`, `_T` is the type of the stored `data` and
# `_I` is the type of a single item returned by `get`.
_T = TypeVar("_T")
_I = TypeVar("_I")

43
44
45
46
47
# Type checkers see the real `PIL.Image` module; at runtime `PILImage` is bound
# to a `LazyLoader` for "PIL.Image" instead.
# NOTE(review): presumably this defers the PIL import until first attribute
# access -- confirm against `vllm.utils.import_utils.LazyLoader`.
if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")

48
49

class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    `_T` is the type of the container kept in `data`; `_I` is the type of a
    single item returned by `get`.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.modality = modality
        self.data: _T = data

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Declared for type checkers only; nothing is defined at runtime here.
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Return how many data items are held."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Return the data item at position `index`."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Return every data item, in index order."""
        return list(map(self.get, range(self.get_count())))

    def get_item_for_hash(self, index: int) -> object:
        """Return the object to hash for item `index` (defaults to `get`)."""
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """Return the hash objects of every item, in index order."""
        return list(map(self.get_item_for_hash, range(self.get_count())))

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Return the data that should be handed to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the data that should be handed directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are stored as a sequence, one per entry."""

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Return `item.media` for a `MediaWithBytes` wrapper, else `item`."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        raw_item = self.data[index]
        return self._unwrap(raw_item)

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        # Hashing uses the stored entry as-is (no unwrapping), so a
        # `MediaWithBytes` wrapper is handed over intact.
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # The key is the pluralised modality name, e.g. "image" -> "images".
        processor_key = f"{self.modality}s"
        return {processor_key: self.get_all()}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Nothing bypasses the HF processor for this kind of item.
        return {}


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def validate_embedding_ndim(
    tensor: torch.Tensor,
    modality: str,
    index: int | None = None,
) -> None:
    """Validate tensor ndim for multimodal embeddings.

    Single embeddings should be 2D (seq_len, hidden_size).
    Batched embeddings should be 3D (batch, seq_len, hidden_size).

    Args:
        tensor: The tensor to validate.
        modality: The modality name for error messages (e.g., "image", "audio").
        index: Optional index for list items, included in error messages.
    """
    if tensor.ndim < 2 or tensor.ndim > 3:
        idx_str = f" [{index}]" if index is not None else ""
        raise ValueError(
            f"{modality.capitalize()} embedding{idx_str} must be 2D "
            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
        )


153
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items given either as one batched embedding tensor
    or as a list of embedding tensors (one per item).
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, modality)

        # The ndim check runs first: the hidden-size check reads the last
        # dimension and only makes sense once the rank is known to be valid.
        self._validate_ndim()

        # The hidden-size check is skipped when no expected size is given.
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _validate_ndim(self) -> None:
        """Raise `ValueError` unless the embeddings have an accepted rank."""
        if isinstance(self.data, torch.Tensor):
            # A single tensor may be 2D or 3D.
            validate_embedding_ndim(self.data, self.modality)
            return

        # In list form, every entry must be 2D (seq_len, hidden_size).
        for idx, tensor in enumerate(self.data):
            if tensor.ndim == 2:
                continue
            raise ValueError(
                f"{self.modality.capitalize()} embedding [{idx}] must be "
                f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                f"with shape {tuple(tensor.shape)}"
            )

    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Raise `ValueError` if the last dimension is not the expected size.

        An embedding with a valid rank but the wrong hidden dimension would
        pass the rank check and only fail later during model inference, so
        the mismatch is rejected here instead.
        """
        if isinstance(self.data, torch.Tensor):
            # Single tensor: the hidden size is the last dimension.
            actual_hidden_size = self.data.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual_hidden_size}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
                )
            return

        # List form: check each (seq_len, hidden_size) tensor individually.
        for idx, tensor in enumerate(self.data):
            actual_hidden_size = tensor.shape[-1]
            if actual_hidden_size == expected_hidden_size:
                continue
            raise ValueError(
                f"{self.modality.capitalize()} embedding [{idx}] hidden "
                f"dimension mismatch: got {actual_hidden_size}, but model "
                f"expects {expected_hidden_size}. "
                f"Embedding shape: {tuple(tensor.shape)}"
            )

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Return `item.media` for a `MediaWithBytes` wrapper, else `item`."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        raw_item = self.data[index]
        return self._unwrap(raw_item)

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings are not sent through the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        passthrough_key = f"{self.modality}_embeds"
        return {passthrough_key: self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Return the length of the first dimension of item `item_idx`."""
        embedding = self.get(item_idx)
        return len(embedding)

239

240
241
242
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        # Imported inside the constructor rather than at module level.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        # Every required field has to be present in the supplied data ...
        if required_fields - data.keys():
            data_keys = set(data.keys())
            raise ValueError(
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )

        # ... and has to be described by the generated field configuration.
        fields_config = fields_factory(data)
        if required_fields - fields_config.keys():
            fields = set(fields_config.keys())
            raise ValueError(f"{required_fields=} should be a subset of {fields=}")

        self.fields_config = fields_config
        self.required_fields = required_fields

        hf_inputs = BatchFeature(dict(data))
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(hf_inputs, fields_config)

    def get_count(self) -> int:
        modality_items = self._kwargs[self.modality]
        return len(modality_items)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        modality_items = self._kwargs[self.modality]
        return modality_items[index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        # Nothing is sent through the HF processor for this kind of item.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # The original mapping is handed to the model unchanged.
        return self.data


300
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items for the "audio" modality, stored as a sequence."""

    def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
        # `None` is stored as a single placeholder entry.
        super().__init__([None] if data is None else data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Return `len()` of the audio item at `item_idx`."""
        return len(self.get(item_idx))

310
311

class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "audio" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data=data,
            modality="audio",
            expected_hidden_size=expected_hidden_size,
        )
318
319
320
321
322
323
324
325


class ImageSize(NamedTuple):
    """Size as a named tuple; note the field order is (width, height)."""

    width: int
    height: int


class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items for the "image" modality, stored as a sequence."""

    def __init__(self, data: Sequence[HfImageItem] | None) -> None:
        # `None` is stored as a single placeholder entry.
        super().__init__([None] if data is None else data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the size of the image at `item_idx`.

        PIL images report their own `size`. Arrays and tensors are unpacked
        as three dimensions, the last two being taken as (height, width).
        """
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)

        if isinstance(image, (np.ndarray, torch.Tensor)):
            _leading, height, width = image.shape
            return ImageSize(width=width, height=height)

        assert_never(image)


class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "image" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data=data,
            modality="image",
            expected_hidden_size=expected_hidden_size,
        )
350
351
352


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items for the "video" modality, with optional metadata."""

    def __init__(
        self,
        data: Sequence[HfVideoItem] | None,
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        # `None` is stored as a single placeholder entry.
        super().__init__([None] if data is None else data, "video")
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return `len()` of the video at `item_idx`."""
        video = self.get(item_idx)
        return len(video)

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the size of the first frame of the video at `item_idx`.

        PIL frames report their own `size`. Array and tensor frames are
        unpacked as three dimensions, the last two being taken as
        (height, width).
        """
        # Only the first frame is inspected; the video is assumed non-empty.
        first_frame = self.get(item_idx)[0]

        if isinstance(first_frame, PILImage.Image):
            return ImageSize(*first_frame.size)

        if isinstance(first_frame, (np.ndarray, torch.Tensor)):
            _leading, height, width = first_frame.shape
            return ImageSize(width=width, height=height)

        assert_never(first_frame)

377
378

class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "video" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(
            data=data,
            modality="video",
            expected_hidden_size=expected_hidden_size,
        )
385
386


Roger Wang's avatar
Roger Wang committed
387
388
389
390
391
392
393
class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
    """Processor items for vision chunks (unified image and video chunks).

    The items are registered under the "vision_chunk" modality.
    """

    def __init__(self, data: Sequence[Any]) -> None:
        super().__init__(data=data, modality="vision_chunk")


394
395
396
397
398
# Type variable for any concrete `ModalityDataItems` subclass; used by
# `MultiModalDataItems.get_items` so the return type follows the requested
# class.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        When the modality is not found, [`KeyError`][] is raised if
        `strict=True`, and `0` is returned if `strict=False`.
        """
        if modality in self:
            return self[modality].get_count()

        if not strict:
            return 0

        available_modalities = set(self.keys())
        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {available_modalities}"
        )

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {modality: self[modality].get_count() for modality in self}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises [`KeyError`][] when the modality is absent and
        [`TypeError`][] when the stored items are not an instance of `typ`.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {available_modalities}"
            )

        items = self[modality]
        if isinstance(items, typ):
            return items  # type: ignore[return-value]

        raise TypeError(
            f"Invalid type of data items for {modality=}. "
            f"Expected type: {typ}, but "
            f"found type: {type(items)}"
        )
451
452


453
# Signature of a per-modality parser: it receives the raw data of one modality
# and returns the parsed items, or `None` when there is nothing to keep.
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]],
    ModalityDataItems[Any, Any] | None,
]
456
457
458
459


class MultiModalDataParser:
    """
460
461
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
462
463
464
465

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
466
467
468
        target_channels (int, optional): Target number of audio channels.
            If provided, normalizes audio to this many channels (e.g., 1 for mono).
            If None, audio channels are passed through unchanged.
469
470
471
472
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
473
474
    """

475
476
477
    def __init__(
        self,
        *,
478
        target_sr: float | None = None,
479
        target_channels: int | None = None,
480
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
481
        video_needs_metadata: bool = False,
482
        expected_hidden_size: int | None = None,
483
    ) -> None:
484
485
        super().__init__()

486
487
488
489
        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
490
        self.target_channels = target_channels
491
        self.video_needs_metadata = video_needs_metadata
492
        self.expected_hidden_size = expected_hidden_size
493

494
495
496
    @classmethod
    def is_embeddings(
        cls, data: object
497
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
498
499
500
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
501
            return data[0].ndim == 2  # type: ignore[index]
502
503
504
505
506
507
508
509

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, (np.ndarray, torch.Tensor)):
            return data.size == 0
510
511
512
513
514
515

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
516
    ) -> tuple[np.ndarray, float | None]:
517
518
519
520
521
522
523
524
525
526
527
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

528
529
530
    def _get_video_with_metadata(
        self,
        video: VideoItem,
531
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
532
533
534
535
536
537
538
539
540
541
542
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

543
544
545
    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
546
    ) -> ModalityDataItems[Any, Any] | None:
547
548
549
        if data is None:
            return AudioProcessorItems(None)

550
        # also check single audio item with sampling rate
551
552
553
        if self._is_empty(data) or (
            isinstance(data, tuple) and self._is_empty(data[0])
        ):
554
555
            return None

556
        if self.is_embeddings(data):
557
            return AudioEmbeddingItems(data, self.expected_hidden_size)
558

559
        data_items: list[AudioItem]
560
561
562
563
564
565
        if (
            is_list_of(data, float)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 1
            or isinstance(data, tuple)
        ):
566
567
568
569
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
570
            data_items = data  # type: ignore[assignment]
571
572
573
574
575
576
577

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
578
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
579

580
581
582
583
584
            # Apply channel normalization if target_channels is set
            if self.target_channels is not None:
                spec = AudioSpec(target_channels=self.target_channels)
                new_audio = normalize_audio(new_audio, spec)

585
586
587
588
589
590
591
            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
592
    ) -> ModalityDataItems[Any, Any] | None:
593
594
595
        if data is None:
            return ImageProcessorItems(None)

596
597
598
        if self._is_empty(data):
            return None

599
        if self.is_embeddings(data):
600
            return ImageEmbeddingItems(data, self.expected_hidden_size)
601

602
        if (
603
            isinstance(data, (PILImage.Image, MediaWithBytes))
604
605
606
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 3
        ):
607
608
609
610
611
612
613
614
615
616
617
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
618
    ) -> ModalityDataItems[Any, Any] | None:
619
620
621
        if data is None:
            return VideoProcessorItems(None)

622
623
624
        if self._is_empty(data):
            return None

625
        if self.is_embeddings(data):
626
            return VideoEmbeddingItems(data, self.expected_hidden_size)
627

628
        data_items: list[VideoItem]
629
630
631
632
633
        if (
            is_list_of(data, PILImage.Image)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 4
        ):
634
635
636
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
637
638
        elif isinstance(data, tuple) and len(data) == 2:
            data_items = [data]
639
        else:
640
            data_items = data  # type: ignore[assignment]
641

642
643
        new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
        metadata_lst: list[dict[str, Any] | None] = []
644
645
646
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
647
648
649
650
651
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
652
653
654
655
656
657
658
659
660
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                new_videos.append(video)

        if not self.video_needs_metadata:
            metadata = None

        return VideoProcessorItems(new_videos, metadata=metadata_lst)
661

Roger Wang's avatar
Roger Wang committed
662
663
664
665
666
667
668
669
670
671
672
    def _parse_vision_chunk_data(
        self,
        data: ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse vision chunk data (unified image and video chunks)."""
        if data is None or self._is_empty(data):
            return None
        if self.is_embeddings(data):
            raise ValueError("Do not support embedding data for vision_chunk right now")
        return VisionChunkProcessorItems(data)

673
674
675
676
677
    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
Roger Wang's avatar
Roger Wang committed
678
            "vision_chunk": self._parse_vision_chunk_data,
679
680
        }

681
    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
682
683
684
685
686
687
688
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

689
690
691
            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data
692
693

        return mm_items