# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
7
8
9
10
11
12
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    NamedTuple,
13
14
    TypeAlias,
    TypeGuard,
15
16
    TypeVar,
)
17
18
19

import numpy as np
import torch
20
from typing_extensions import assert_never
21

22
from vllm.utils.collection_utils import is_list_of
23
from vllm.utils.import_utils import LazyLoader
24

25
from .audio import AudioResampler
26
from .base import MediaWithBytes
27
28
29
30
31
32
33
34
35
36
37
38
from .inputs import (
    AudioItem,
    HfAudioItem,
    HfImageItem,
    HfVideoItem,
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
39
40
41
42

# Generic parameters used by the item containers below: in `ModalityDataItems`,
# `_T` is the type of the stored `data` and `_I` is the type returned by `get`.
_T = TypeVar("_T")
_I = TypeVar("_I")

43
44
45
46
47
# Type checkers see the real `PIL.Image` module; at runtime the name is bound to
# a `LazyLoader` instead (presumably deferring the actual import until first
# use -- confirm against `vllm.utils.import_utils`).
if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")

48
49

class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    `_T` is the type of the container stored in `self.data`, and `_I` is the
    type of a single item as returned by `get`.
    """

    def __init__(self, data: _T, modality: str) -> None:
        """Store the raw `data` together with the name of its `modality`."""
        super().__init__()

        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        """Return the number of data items (same as `get_count`)."""
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        """Return the data item at `index` (same as `get`)."""
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    def get_item_for_hash(self, index: int) -> object:
        """Get the object that represents item `index` for hashing.

        Defaults to the item itself; subclasses may return a different object.
        """
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """Get the hashing representation of every data item, in order."""
        return [self.get_item_for_hash(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Extract media from wrapper if present."""
        return item.media if isinstance(item, MediaWithBytes) else item

    def get_count(self) -> int:
        """Return the number of items in the underlying sequence."""
        return len(self.data)

    def get(self, index: int) -> _T:
        """Return the item at `index`, unwrapped from `MediaWithBytes`."""
        return self._unwrap(self.data[index])

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        # Return raw item for hashing (preserves original_bytes if present)
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """Return all unwrapped items under the key `"<modality>s"`."""
        return {f"{self.modality}s": self.get_all()}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return an empty mapping: nothing bypasses the HF processor."""
        return {}


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def validate_embedding_ndim(
    tensor: torch.Tensor,
    modality: str,
    index: int | None = None,
) -> None:
    """Validate tensor ndim for multimodal embeddings.

    Single embeddings should be 2D (seq_len, hidden_size).
    Batched embeddings should be 3D (batch, seq_len, hidden_size).

    Args:
        tensor: The tensor to validate.
        modality: The modality name for error messages (e.g., "image", "audio").
        index: Optional index for list items, included in error messages.
    """
    if tensor.ndim < 2 or tensor.ndim > 3:
        idx_str = f" [{index}]" if index is not None else ""
        raise ValueError(
            f"{modality.capitalize()} embedding{idx_str} must be 2D "
            f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
            f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
        )


153
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        """Store the embeddings and validate their shape.

        Args:
            data: A single tensor or a list of tensors holding the embeddings.
            modality: The modality name, used in keys and error messages.
            expected_hidden_size: If given, the size that the last dimension
                of every embedding must have.

        Raises:
            ValueError: If an embedding has an unsupported number of
                dimensions, or a hidden size other than the expected one.
        """
        super().__init__(data, modality)

        # Validate ndim first (before hidden_size which depends on correct ndim)
        self._validate_ndim()

        # Validate hidden dimension if expected size is provided
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _validate_ndim(self) -> None:
        """Validate that embedding tensors have correct ndim (2D or 3D)."""
        if isinstance(self.data, torch.Tensor):
            validate_embedding_ndim(self.data, self.modality)
        else:
            # List of tensors: each should be 2D (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                if tensor.ndim != 2:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] must be "
                        f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                        f"with shape {tuple(tensor.shape)}"
                    )

    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Validate that embedding hidden dimension matches expected size.

        This validates hidden dimensions to prevent vulnerabilities: Embeddings
        with correct ndim but wrong hidden dimension could bypass initial
        checks and cause crashes during model inference when dimensions don't match.
        """
        if isinstance(self.data, torch.Tensor):
            # Single tensor: the hidden size is the last dimension
            actual_hidden_size = self.data.shape[-1]
            if actual_hidden_size != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual_hidden_size}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(self.data.shape)}"
                )
        else:
            # List of tensors: each has shape (seq_len, hidden_size)
            for idx, tensor in enumerate(self.data):
                actual_hidden_size = tensor.shape[-1]
                if actual_hidden_size != expected_hidden_size:
                    raise ValueError(
                        f"{self.modality.capitalize()} embedding [{idx}] hidden "
                        f"dimension mismatch: got {actual_hidden_size}, but model "
                        f"expects {expected_hidden_size}. "
                        f"Embedding shape: {tuple(tensor.shape)}"
                    )

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Extract media from wrapper if present."""
        return item.media if isinstance(item, MediaWithBytes) else item

    def get_count(self) -> int:
        """Return `len(self.data)`: list length or size of the first dim."""
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        """Return the embedding at `index`, unwrapped from `MediaWithBytes`."""
        return self._unwrap(self.data[index])

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping: embeddings skip the HF processor."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the raw embeddings under the key `"<modality>_embeds"`."""
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Return the length of the first dimension of item `item_idx`."""
        return len(self.get(item_idx))

239

240
241
242
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """Validate `data` and split it into per-item keyword arguments.

        Args:
            data: Mapping from field name to tensor.
            modality: The modality name these items belong to.
            required_fields: Field names that must be present both in `data`
                and in the config returned by `fields_factory`.
            fields_factory: Builds the field configuration from `data`.

        Raises:
            ValueError: If a required field is missing from `data` or from
                the field configuration.
        """
        # Imported here rather than at module level, as in the original code.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )
            raise ValueError(msg)

        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        """Return the number of parsed items for this modality."""
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Return the data of the parsed item at `index`."""
        return self._kwargs[self.modality][index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping: nothing is sent to the HF processor."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the original `data` mapping unchanged."""
        return self.data


300
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
        """Store `data` under the "audio" modality.

        A `None` input is replaced by a one-element list holding `None`.
        """
        if data is None:
            data = [None]
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Return `len()` of the audio item at `item_idx`."""
        audio = self.get(item_idx)
        return len(audio)

310
311

class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "audio" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        """Forward to `EmbeddingItems` with the modality fixed to "audio"."""
        super().__init__(data, "audio", expected_hidden_size)
318
319
320
321
322
323
324
325


class ImageSize(NamedTuple):
    """Width/height pair describing the size of an image or video frame."""

    width: int
    height: int


class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfImageItem] | None) -> None:
        """Store `data` under the "image" modality.

        A `None` input is replaced by a one-element list holding `None`.
        """
        if data is None:
            data = [None]
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the image at `item_idx`.

        PIL images report their own `size`. Arrays and tensors must have
        exactly three dimensions, of which the last two are read as
        height and width.
        """
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)


class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "image" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        """Forward to `EmbeddingItems` with the modality fixed to "image"."""
        super().__init__(data, "image", expected_hidden_size)
350
351
352


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items that are passed to the HF processor."""

    def __init__(
        self,
        data: Sequence[HfVideoItem] | None,
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        """Store `data` under the "video" modality and keep `metadata`.

        A `None` input is replaced by a one-element list holding `None`.
        """
        if data is None:
            data = [None]
        super().__init__(data, "video")
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return `len()` of the video at `item_idx`."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the first frame of a video.

        PIL frames report their own `size`. Array and tensor frames must have
        exactly three dimensions, of which the last two are read as
        height and width.
        """
        image = self.get(item_idx)[0]  # Assume that the video isn't empty

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)

377
378

class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items bound to the "video" modality."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        """Forward to `EmbeddingItems` with the modality fixed to "video"."""
        super().__init__(data, "video", expected_hidden_size)
385
386
387
388
389
390
391


# Any concrete `ModalityDataItems` subclass; lets `MultiModalDataItems.get_items`
# return the same type that the caller asked for.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        when the modality is not found.
        """
        if modality not in self:
            if strict:
                available_modalities = set(self.keys())
                raise KeyError(
                    f"Modality {modality!r} not found. "
                    f"Available modalities: {available_modalities}"
                )

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If `modality` is not present.
            TypeError: If the stored items are not an instance of `typ`.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {available_modalities}"
            )

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(
                f"Invalid type of data items for {modality=}. "
                f"Expected type: {typ}, but "
                f"found type: {type(items)}"
            )

        return items  # type: ignore[return-value]
444
445


446
# Signature shared by the per-modality parsers: takes the raw data of one
# modality and returns the parsed items, or `None` when nothing should be kept.
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]], ModalityDataItems[Any, Any] | None
]
449
450
451
452


class MultiModalDataParser:
    """
453
454
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
455
456
457
458

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
459
460
461
462
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
463
464
    """

465
466
467
    def __init__(
        self,
        *,
468
        target_sr: float | None = None,
469
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
470
        video_needs_metadata: bool = False,
471
        expected_hidden_size: int | None = None,
472
    ) -> None:
473
474
        super().__init__()

475
476
477
478
        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
479
        self.video_needs_metadata = video_needs_metadata
480
        self.expected_hidden_size = expected_hidden_size
481

482
483
484
    @classmethod
    def is_embeddings(
        cls, data: object
485
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
486
487
488
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
489
            return data[0].ndim == 2  # type: ignore[index]
490
491
492
493
494
495
496
497

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, (np.ndarray, torch.Tensor)):
            return data.size == 0
498
499
500
501
502
503

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
504
    ) -> tuple[np.ndarray, float | None]:
505
506
507
508
509
510
511
512
513
514
515
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

516
517
518
    def _get_video_with_metadata(
        self,
        video: VideoItem,
519
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
520
521
522
523
524
525
526
527
528
529
530
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

531
532
533
    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
534
    ) -> ModalityDataItems[Any, Any] | None:
535
536
537
        if data is None:
            return AudioProcessorItems(None)

538
        # also check single audio item with sampling rate
539
540
541
        if self._is_empty(data) or (
            isinstance(data, tuple) and self._is_empty(data[0])
        ):
542
543
            return None

544
        if self.is_embeddings(data):
545
            return AudioEmbeddingItems(data, self.expected_hidden_size)
546

547
        data_items: list[AudioItem]
548
549
550
551
552
553
        if (
            is_list_of(data, float)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 1
            or isinstance(data, tuple)
        ):
554
555
556
557
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
558
            data_items = data  # type: ignore[assignment]
559
560
561
562
563
564
565

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
566
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
567
568
569
570
571
572
573
574

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
575
    ) -> ModalityDataItems[Any, Any] | None:
576
577
578
        if data is None:
            return ImageProcessorItems(None)

579
580
581
        if self._is_empty(data):
            return None

582
        if self.is_embeddings(data):
583
            return ImageEmbeddingItems(data, self.expected_hidden_size)
584

585
        if (
586
            isinstance(data, (PILImage.Image, MediaWithBytes))
587
588
589
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 3
        ):
590
591
592
593
594
595
596
597
598
599
600
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
601
    ) -> ModalityDataItems[Any, Any] | None:
602
603
604
        if data is None:
            return VideoProcessorItems(None)

605
606
607
        if self._is_empty(data):
            return None

608
        if self.is_embeddings(data):
609
            return VideoEmbeddingItems(data, self.expected_hidden_size)
610

611
        data_items: list[VideoItem]
612
613
614
615
616
        if (
            is_list_of(data, PILImage.Image)
            or isinstance(data, (np.ndarray, torch.Tensor))
            and data.ndim == 4
        ):
617
618
619
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
620
621
        elif isinstance(data, tuple) and len(data) == 2:
            data_items = [data]
622
        else:
623
            data_items = data  # type: ignore[assignment]
624

625
626
        new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
        metadata_lst: list[dict[str, Any] | None] = []
627
628
629
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
630
631
632
633
634
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
635
636
637
638
639
640
641
642
643
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                new_videos.append(video)

        if not self.video_needs_metadata:
            metadata = None

        return VideoProcessorItems(new_videos, metadata=metadata_lst)
644
645
646
647
648
649
650
651

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

652
    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
653
654
655
656
657
658
659
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

660
661
662
            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data
663
664

        return mm_items