# parse.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
7
8
9
10
11
12
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    NamedTuple,
13
14
    TypeAlias,
    TypeGuard,
15
16
    TypeVar,
)
17
18
19

import numpy as np
import torch
20
from typing_extensions import assert_never
21

22
from vllm.utils.collection_utils import is_list_of
23
from vllm.utils.import_utils import LazyLoader
24

25
from .audio import AudioResampler
26
from .base import MediaWithBytes
27
28
29
30
31
32
33
34
35
36
37
38
from .inputs import (
    AudioItem,
    HfAudioItem,
    HfImageItem,
    HfVideoItem,
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
39
40
41
42

# Generic placeholders for the item containers below. In `ModalityDataItems`,
# `_T` annotates the stored data and `_I` a single item retrieved from it.
_T = TypeVar("_T")
_I = TypeVar("_I")

43
44
45
46
47
if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")

48
49

class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    The instance stores the raw `data` together with the `modality` name.
    Subclasses decide how items are counted and retrieved, and which part of
    the data is handed to the HF processor versus passed straight to the
    model.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        # `len(items)` is an alias of `get_count()`.
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        # `items[i]` is an alias of `get(i)`.
        return self.get(index)

    if TYPE_CHECKING:
        # Typing-only stub: no `__iter__` is defined at runtime, so this only
        # tells type checkers what iterating over the items yields.
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items, in index order."""
        count = self.get_count()
        return [self.get(pos) for pos in range(count)]

    def get_item_for_hash(self, index: int) -> object:
        """
        Get the object that represents item `index` for hashing.

        By default this is the same object as returned by `get`.
        """
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """Get the hashing representation of every item, in index order."""
        count = self.get_count()
        return [self.get_item_for_hash(pos) for pos in range(count)]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """
    Base class for data items that are arranged in a list.

    Individual entries may be wrapped in `MediaWithBytes`; `get` returns the
    unwrapped media while `get_item_for_hash` returns the entry as stored.
    """

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Return `item.media` for a `MediaWithBytes`, else `item` itself."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        """Get the number of entries in the underlying sequence."""
        return len(self.data)

    def get(self, index: int) -> _T:
        """Get the (unwrapped) media stored at `index`."""
        stored = self.data[index]
        return self._unwrap(stored)

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        """Get the entry at `index` exactly as stored, without unwrapping."""
        # Return raw item for hashing (preserves original_bytes if present)
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """
        Get the data to pass to the HF processor.

        The key is the modality name with an `"s"` appended.
        """
        # NOTE(review): the stored sequence is forwarded as-is, i.e. without
        # `_unwrap` being applied to its entries - confirm this is intended.
        processor_key = f"{self.modality}s"
        return {processor_key: self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Nothing is passed directly to the model for this kind of items."""
        return {}


129
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).

    The embeddings bypass the HF processor and are passed to the model under
    the key `"<modality>_embeds"`.
    """

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Return `item.media` for a `MediaWithBytes`, else `item` itself."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        """Get the number of items, i.e. `len()` of the stored data."""
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        """Get the (unwrapped) embedding stored at `index`."""
        stored = self.data[index]
        return self._unwrap(stored)

    def get_processor_data(self) -> Mapping[str, object]:
        """Nothing is sent to the HF processor for embedding inputs."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the stored embeddings keyed by `"<modality>_embeds"`."""
        passthrough_key = f"{self.modality}_embeds"
        return {passthrough_key: self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Get `len()` of the embedding at `item_idx` (its first dimension)."""
        embedding = self.get(item_idx)
        return len(embedding)

158

159
160
161
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.

    Args:
        data: Mapping from field name to tensor.
        modality: Name of the modality that the data belongs to.
        required_fields: Field names that must be present both in `data` and
            in the config returned by `fields_factory`.
        fields_factory: Callable that builds the per-field
            `MultiModalFieldConfig` mapping from `data`.

    Raises:
        ValueError: If a required field is missing from `data` or from the
            field config produced by `fields_factory`.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        # Imported lazily, inside the constructor rather than at module level.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        # 1) The raw data must provide every required field.
        if required_fields - data.keys():
            data_keys = set(data.keys())
            raise ValueError(
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )

        # 2) So must the field config derived from the data.
        fields_config = fields_factory(data)
        if required_fields - fields_config.keys():
            # The local names below appear verbatim in the message via `{...=}`.
            fields = set(fields_config.keys())
            raise ValueError(f"{required_fields=} should be a subset of {fields=}")

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Per-item view of the data; counting and indexing are served from it.
        hf_inputs = BatchFeature(dict(data))
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            hf_inputs,
            fields_config,
        )

    def get_count(self) -> int:
        """Get the number of items held for this modality."""
        modality_items = self._kwargs[self.modality]
        return len(modality_items)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Get the data of the item at `index`."""
        modality_items = self._kwargs[self.modality]
        return modality_items[index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        """Nothing is sent to the HF processor for dictionary embeddings."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """The original mapping is passed directly to the model."""
        return self.data


219
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem] | None) -> None:
        # `None` is stored as the one-element list `[None]`.
        # NOTE(review): presumably a placeholder for one absent item - confirm.
        super().__init__([None] if data is None else data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Get `len()` of the audio item at `item_idx`."""
        return len(self.get(item_idx))

229
230

class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the `"audio"` modality."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data=data, modality="audio")


class ImageSize(NamedTuple):
    """Size of an image or video frame, ordered as `(width, height)`."""

    width: int
    height: int


class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items that are passed to the HF processor."""

    def __init__(self, data: Sequence[HfImageItem] | None) -> None:
        # `None` is stored as the one-element list `[None]`.
        # NOTE(review): presumably a placeholder for one absent item - confirm.
        super().__init__([None] if data is None else data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Get the `(width, height)` of the image at `item_idx`.

        PIL images report their own size. Arrays and tensors must be 3-D and
        are read as `(_, height, width)` - assumes channel-first layout,
        TODO confirm.
        """
        image = self.get(item_idx)

        if isinstance(image, (np.ndarray, torch.Tensor)):
            _channels, height, width = image.shape
            return ImageSize(width, height)
        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)

        assert_never(image)


class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the `"image"` modality."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data=data, modality="image")


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """
    Video items that are passed to the HF processor.

    `metadata` is stored on the instance unchanged; this class does not
    interpret it.
    """

    def __init__(
        self,
        data: Sequence[HfVideoItem] | None,
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        # `None` is stored as the one-element list `[None]`.
        # NOTE(review): presumably a placeholder for one absent item - confirm.
        super().__init__([None] if data is None else data, "video")
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Get `len()` of the video at `item_idx`."""
        video = self.get(item_idx)
        return len(video)

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """
        Get the `(width, height)` of the first frame of the video at `item_idx`.

        PIL frames report their own size. Array and tensor frames must be 3-D
        and are read as `(_, height, width)` - assumes channel-first layout,
        TODO confirm.
        """
        video = self.get(item_idx)
        first_frame = video[0]  # Assume that the video isn't empty

        if isinstance(first_frame, (np.ndarray, torch.Tensor)):
            _channels, height, width = first_frame.shape
            return ImageSize(width, height)
        if isinstance(first_frame, PILImage.Image):
            return ImageSize(*first_frame.size)

        assert_never(first_frame)

288
289

class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the `"video"` modality."""

    def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None:
        super().__init__(data=data, modality="video")


# Any `ModalityDataItems` subclass; used by `MultiModalDataItems.get_items`
# so that its return type matches the type requested by the caller.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        even if the modality is not found.
        """
        if modality in self:
            return self[modality].get_count()

        if not strict:
            return 0

        available_modalities = set(self.keys())
        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {available_modalities}"
        )

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        counts: dict[str, int] = {}
        for name, modality_items in self.items():
            counts[name] = modality_items.get_count()
        return counts

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not present.
            TypeError: If the stored items are not an instance of `typ`.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {available_modalities}"
            )

        found = self[modality]
        if isinstance(found, typ):
            return found  # type: ignore[return-value]

        raise TypeError(
            f"Invalid type of data items for {modality=}. "
            f"Expected type: {typ}, but "
            f"found type: {type(found)}"
        )
351
352


353
# Signature of a per-modality parsing callback: it receives the raw data of
# one modality and returns the parsed items, or `None`.
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]],
    ModalityDataItems[Any, Any] | None,
]
356
357
358
359


class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        audio_resample_method (str, optional): Resampling method forwarded to
            `AudioResampler` (`"librosa"` or `"scipy"`).
        video_needs_metadata (bool, optional): If `True`, every video item
            must come with metadata, and parsed videos are kept as
            `(video, metadata)` pairs instead of bare arrays.
    """

    def __init__(
        self,
        *,
        target_sr: float | None = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        self.video_needs_metadata = video_needs_metadata

    @classmethod
    def is_embeddings(
        cls, data: object
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
        """
        Check whether `data` should be treated as precomputed embeddings.

        This is the case for a 3-D tensor, or for a non-empty list of tensors
        whose first element is 2-D.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            # Guard the index below in case `is_list_of` accepts an empty list.
            if not data:
                return False
            return data[0].ndim == 2  # type: ignore[index]

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        """
        Check whether `data` is an empty list, array or tensor.

        Any other kind of object is reported as non-empty.
        """
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, np.ndarray):
            return data.size == 0
        if isinstance(data, torch.Tensor):
            # `Tensor.size` is a method (unlike `ndarray.size`), so comparing
            # it with `0` is always false; count the elements instead.
            return data.numel() == 0

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, float | None]:
        """
        Normalize an audio item into an `(audio, sampling_rate)` pair.

        Tuples are returned unchanged; lists, arrays and tensors are converted
        to a NumPy array and paired with `None` as the sampling rate.
        """
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
        """
        Normalize a video item into a `(video, metadata)` pair.

        Tuples are returned unchanged; lists, arrays and tensors are converted
        to a NumPy array and paired with `None` as the metadata.
        """
        if isinstance(video, tuple):
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """
        Parse the `"audio"` entry of the multi-modal data.

        Returns `None` for empty data, `AudioEmbeddingItems` for embeddings,
        and otherwise `AudioProcessorItems` holding one NumPy array per item.
        Items that carry a sampling rate are passed through the resampler.
        """
        if data is None:
            return AudioProcessorItems(None)

        # also check single audio item with sampling rate
        if self._is_empty(data) or (
            isinstance(data, tuple) and self._is_empty(data[0])
        ):
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data)

        data_items: list[AudioItem]
        if (
            is_list_of(data, float)
            or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
            or isinstance(data, tuple)
        ):
            # A single audio item: wrap it so that we always handle a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array/tensor: split along the first dimension.
            data_items = list(data)
        else:
            data_items = data  # type: ignore[assignment]

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                # Without a known source rate, the audio is kept as-is.
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """
        Parse the `"image"` entry of the multi-modal data.

        Returns `None` for empty data, `ImageEmbeddingItems` for embeddings,
        and otherwise `ImageProcessorItems` holding one entry per image.
        """
        if data is None:
            return ImageProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data)

        if isinstance(data, PILImage.Image) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
        ):
            # A single image: wrap it so that we always handle a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array/tensor: split along the first dimension.
            data_items = list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """
        Parse the `"video"` entry of the multi-modal data.

        Returns `None` for empty data, `VideoEmbeddingItems` for embeddings,
        and otherwise `VideoProcessorItems`. When `video_needs_metadata` is
        set, each entry is a `(video, metadata)` pair; otherwise only the
        video array is kept and the collected metadata list stays empty.

        Raises:
            ValueError: If `video_needs_metadata` is set and an item has no
                metadata.
        """
        if data is None:
            return VideoProcessorItems(None)

        if self._is_empty(data):
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data)

        data_items: list[VideoItem]
        if is_list_of(data, PILImage.Image) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
        ):
            # A single video: wrap it so that we always handle a list.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array/tensor: split along the first dimension.
            data_items = list(data)
        elif isinstance(data, tuple) and len(data) == 2:
            # A single `(video, metadata)` pair.
            data_items = [data]
        else:
            data_items = data  # type: ignore[assignment]

        new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
        metadata_lst: list[dict[str, Any] | None] = []
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                new_videos.append(video)

        return VideoProcessorItems(new_videos, metadata=metadata_lst)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Get the parsing callback for each supported modality name."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """
        Parse every modality in `mm_data` with its registered subparser.

        Modalities whose subparser returns `None` are left out of the result.

        Raises:
            ValueError: If `mm_data` contains a modality without a subparser.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items