import inspect
import math
import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union

import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override

from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
    is_librosa_available,
    is_pillow_available,
    is_pyav_available,
    is_transformers_version_greater_than,
)


if is_librosa_available():
    import librosa


if is_pillow_available():
    from PIL import Image
    from PIL.Image import Image as ImageObject


if is_pyav_available():
    import av


if is_transformers_version_greater_than("4.45.0"):
    from transformers.models.mllama.processing_mllama import (
        convert_sparse_cross_attention_mask_to_dense,
        get_cross_attention_token_mask,
    )


if TYPE_CHECKING:
    from av.stream import Stream
    from numpy.typing import NDArray
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
    from transformers.image_processing_utils import BaseImageProcessor

    class EncodedImage(TypedDict):
        path: Optional[str]
        bytes: Optional[bytes]

    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
    VideoInput = str
    AudioInput = Union[str, NDArray]


def _get_paligemma_token_type_ids(
    imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
    r"""
    Gets paligemma token type ids for computing loss.

    Returns:
        batch_token_type_ids: shape (batch_size, sequence_length)
    """
    batch_token_type_ids = []
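    # PaliGemma packs all image tokens at the start of each sequence, so the first
    # imglen * image_seqlen positions get token type 0 and the remaining text positions get 1.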
    for imglen, seqlen in zip(imglens, seqlens):
        image_seqlen = imglen * getattr(processor, "image_seqlen")
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids


@dataclass
class MMPluginMixin:
    image_token: Optional[str]
    video_token: Optional[str]
    audio_token: Optional[str]
    expand_mm_tokens: bool = True

    def _validate_input(
        self,
        processor: Optional["ProcessorMixin"],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
    ) -> None:
        r"""
        Validates whether this model accepts the input modalities.
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
        if len(images) != 0 and self.image_token is None:
            raise ValueError(
                "This model does not support image input. Please check whether the correct `template` is used."
            )

        if len(videos) != 0 and self.video_token is None:
            raise ValueError(
                "This model does not support video input. Please check whether the correct `template` is used."
            )

        if len(audios) != 0 and self.audio_token is None:
            raise ValueError(
                "This model does not support audio input. Please check whether the correct `template` is used."
            )

        if self.image_token is not None and processor is None:
            raise ValueError("Processor was not found, please check and update your processor config.")

        if self.image_token is not None and image_processor is None:
            raise ValueError("Image processor was not found, please check and update your processor config.")

        if self.audio_token is not None and feature_extractor is None:
            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")

    def _preprocess_image(
        self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
    ) -> "ImageObject":
        r"""
        Pre-processes a single image.
        """
        if (image.width * image.height) > image_max_pixels:
            resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if (image.width * image.height) < image_min_pixels:
            resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image

    def _get_video_sample_indices(
        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
    ) -> List[int]:
        r"""
        Computes video sample indices according to fps.
        """
        total_frames = video_stream.frames
        if total_frames == 0:  # infinite video
            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)

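        # desired number of samples = video duration (in seconds) * target fps,
        # capped by the actual frame count and by `video_maxlen`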
        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
        r"""
        Regularizes images to avoid errors, including reading and pre-processing.
        """
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, bytes):
                image = Image.open(BytesIO(image))
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError(f"Expected input to be a list of images, but got {type(image)}.")

            results.append(self._preprocess_image(image, **kwargs))

        return results

    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid errors, including reading, resizing and converting.
        """
        results = []
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
                if frame_idx in sample_indices:
                    frames.append(frame.to_image())

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results

    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
        r"""
        Regularizes audios to avoid errors, including reading and resampling.
        """
        results = []
        for audio in audios:
            if isinstance(audio, str):
                audio = librosa.load(audio, sr=sampling_rate)[0]

            if not isinstance(audio, np.ndarray):
                raise ValueError(f"Expected input to be a list of audios, but got {type(audio)}.")

            results.append(audio)

        return results

    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width

        It holds that num_patches == torch.prod(image_grid_thw).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
        mm_inputs = {}

        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
249
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
            else:  # for llava_next_video
                mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs


@dataclass
class BasePlugin(MMPluginMixin):
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Pre-processes input messages before tokenization for VLMs.
        """
        self._validate_input(processor, images, videos, audios)
        return messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Pre-processes token ids after tokenization for VLMs.
        """
        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds batched multimodal inputs for VLMs.

        Arguments:
            images: a list of image inputs, shape (num_images,)
            videos: a list of video inputs, shape (num_videos,)
            audios: a list of audio inputs, shape (num_audios,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
            audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images, videos and audios
        """
        self._validate_input(processor, images, videos, audios)
        return {}


@dataclass
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaNextPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    if self.expand_mm_tokens:
                        orig_height, orig_width = next(image_sizes)
                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                        if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                            image_seqlen -= 1
                    else:
                        image_seqlen = 1

                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                    num_image_tokens += 1

                message["content"] = content.replace("{{image}}", self.image_token)

        if "pixel_values_videos" in mm_inputs:
            if self.expand_mm_tokens:
                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(pixel_values_video[0])
                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divided by 4 due to the average pooling layer
            else:
                video_seqlen = 1

            for message in messages:
                content = message["content"]
                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class MiniCPMVPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        mm_inputs = {}
        audio_inputs = {}
        if len(images) != 0 and len(videos) != 0:
            raise ValueError("MiniCPM-V does not support image and video inputs at the same time.")

        if len(videos) != 0:
            max_slice_nums = 2
            use_image_id = False
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
        else:
            max_slice_nums = image_processor.max_slice_nums
            use_image_id = image_processor.use_image_id

        for i, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                num_video_tokens += 1

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                num_audio_tokens += 1

            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
                "{{audio}}", "(<audio>./</audio>)"
            )

        if num_image_tokens > 0:
            mm_inputs = self._get_mm_inputs(images, [], [], processor)

        if num_audio_tokens > 0:
            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

        if mm_inputs:
            pattern = "(<image>./</image>)"
            image_sizes = mm_inputs["image_sizes"]
            idx = 0
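            # expand each "(<image>./</image>)" tag into the processor's full slice
            # placeholder, which depends on the original image size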
            for index, message in enumerate(messages):
                text = message["content"]
                image_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(image_tags)):
                    final_text = (
                        final_text
                        + text_chunks[i]
                        + image_processor.get_slice_image_placeholder(
                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                        )
                    )
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if audio_inputs:
            pattern = "(<audio>./</audio>)"
            idx = 0
            for index, message in enumerate(messages):
                text = message["content"]
                audio_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(audio_tags)):
                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
                    final_text = final_text + text_chunks[i] + audio_placeholder
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if "valid_image_nums_ls" in kwargs:
                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                new_images = []
                idx = 0
                for valid_image_nums in valid_image_nums_ls:
                    new_images.append(images[idx : idx + valid_image_nums])
                    idx += valid_image_nums

                images = new_images

            image_inputs = image_processor(
                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
            )
            mm_inputs.update(image_inputs)

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
            mm_inputs.update(video_inputs)

        if len(audios) != 0:
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            if "valid_audio_nums_ls" in kwargs:
                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                audios_ls = []
                idx = 0
                for valid_audio_nums in valid_audio_nums_ls:
                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                    idx += valid_audio_nums
            else:
                audios_ls = [audios]

            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                audios_ls,
                chunk_input=True,
                sampling_rate=16000,
            )
            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
            if kwargs.get("ret_phs", False):
                mm_inputs.update({"audio_phs": audio_phs})

        return mm_inputs

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
        image_bounds_list = []
        valid_image_nums_ls = []
        for i, input_ids in enumerate(batch_ids):
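            # record the (start, end) spans between image/slice begin and end tokens;
            # MiniCPM-V uses these bounds to place visual embeddings into the sequence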
            input_ids_ = torch.tensor(input_ids)
            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                input_ids_ == processor.tokenizer.slice_start_id
            )
            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
            image_start_tokens = torch.where(start_cond)[0]
            image_start_tokens += 1
            image_end_tokens = torch.where(end_cond)[0]
            valid_image_nums_ls.append(imglens[i])
            image_bounds = torch.hstack(
                [
                    image_start_tokens.unsqueeze(-1),
                    image_end_tokens.unsqueeze(-1),
                ]
            )
            image_bounds_list.append(image_bounds)

        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
        if "tgt_sizes" not in mm_inputs:
            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})

        mm_inputs.update({"image_bound": image_bounds_list})

        if len(audios) > 0:
            # audio bound
            audio_bounds_ls = []
            spk_bounds_ls = []
            valid_audio_nums_ls = []

            for input_ids, audiolen in zip(batch_ids, audlens):
                input_ids_ = torch.tensor(input_ids)
                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
                assert len(audio_start_idx) == len(audio_end_idx)
                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
                audio_bounds_ls.append(audio_bounds)
                valid_audio_nums_ls.append(audiolen)

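                # likewise record the spans between the tokenizer's speaker start/end
                # tokens (spk_start_id / spk_end_id)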
                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
                assert len(spk_start_idx) == len(spk_end_idx)
                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
                spk_bounds_ls.append(spk_bounds)

            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
            mm_inputs.update(audio_inputs)
            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})

        return mm_inputs


@dataclass
class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        imglens: List[int],
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Returns:
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        mm_inputs = {}
        if len(images) > 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
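            # regroup the flat image list into per-sample lists, since the mllama
            # processor only accepts List[List[ImageInput]]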
            batch_images = []
            for image_length in imglens:
                batch_images.append(images[:image_length])
                images = images[image_length:]

            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))

        return mm_inputs

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
            image_token_id = getattr(processor, "image_token_id")
            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
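            # build the dense mask that tells every text token which image tiles it may
            # cross-attend to, as expected by mllama's `cross_attention_mask` input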
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
            mm_inputs["cross_attention_mask"] = torch.from_numpy(
                convert_sparse_cross_attention_mask_to_dense(
                    cross_attention_token_mask,
                    num_tiles=num_tiles,
                    max_num_tiles=max_image_tiles,
                    length=max(len(input_ids) for input_ids in batch_ids),
                )
            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

        return mm_inputs


@dataclass
class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", "")

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
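        # paligemma expects image tokens as a prefix of the sequence; mask them
        # in the labels so they do not contribute to the loss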
        input_ids = [image_token_id] * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seqlen + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs


@dataclass
class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
        image_end_token = getattr(processor, "image_end_token")

        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    height, width = next(image_sizes)
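                    # pixtral lays out image tokens row by row: every row of patches ends
                    # with an image-break token and the final token marks the image end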
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_sizes", None)
        return mm_inputs


@dataclass
class Qwen2AudioPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs([], [], audios, processor)
        if "feature_attention_mask" in mm_inputs:
            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        num_audio_tokens = 0
        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    audio_length = audio_lengths.pop(0)
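                    # mirrors the Qwen2-Audio encoder's output length formula: the mel
                    # features go through two stride-2 reductions (conv stem, then pooling)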
                    input_length = (audio_length - 1) // 2 + 1
                    audio_seqlen = (input_length - 2) // 2 + 1
                else:
                    audio_seqlen = 1

                content = content.replace(
                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                )
                num_audio_tokens += 1

            message["content"] = content

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class Qwen2VLPlugin(BasePlugin):
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        image = super()._preprocess_image(image, **kwargs)
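        # qwen2-vl needs both sides to be at least 28 pixels
        # (patch size 14 * merge size 2), so upscale tiny images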
        if min(image.width, image.height) < 28:
            width, height = max(image.width, 28), max(image.height, 28)
            image = image.resize((width, height))

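        # the processor also rejects aspect ratios above 200, so squash
        # extremely wide or tall images to roughly 180:1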
        if image.width / image.height > 200:
            width, height = image.height * 180, image.height
            image = image.resize((width, height))

        if image.height / image.width > 200:
            width, height = image.width, image.width * 180
            image = image.resize((width, height))

        return image

    @override
    def _regularize_videos(
        self, videos: Sequence["VideoInput"], **kwargs
    ) -> Tuple[List[List["ImageObject"]], List[float]]:
        results, fps_per_video = [], []
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
                if frame_idx in sample_indices:
                    frames.append(frame.to_image())

            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
                frames.append(frames[-1])

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)
            if video_stream.duration is None:
                fps_per_video.append(2.0)  # duration unknown, fall back to the default 2 fps
            else:  # effective fps = sampled frames / stream duration in seconds
                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

        return results, fps_per_video
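
    # Example: sampling 20 frames from a 10 s stream yields fps_per_video = 2.0;
    # `duration * time_base` converts the stream duration to seconds.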

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos, fps_per_video = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
            mm_inputs["fps_per_video"] = fps_per_video

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_grid_thw):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if num_video_tokens >= len(video_grid_thw):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                )
                num_video_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
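
    # Example of the seqlen arithmetic above (illustrative): with merge_size 2,
    # merge_length = 4, so an image_grid_thw of (1, 36, 36) expands one image
    # placeholder into 1 * 36 * 36 // 4 = 324 image tokens wrapped in
    # <|vision_start|>/<|vision_end|>.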

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        fps_per_video = mm_inputs.pop("fps_per_video", [])
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]

        return mm_inputs
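
    # Example: with temporal_patch_size = 2 and a video sampled at 2.0 fps,
    # second_per_grid_ts is 1.0, i.e. each temporal grid step spans one second.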


@dataclass
class VideoLlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
        if has_images or has_videos:
            if self.expand_mm_tokens:
                if has_images:
                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                    num_frames = 1

                if has_videos:
                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                    height, width = get_image_size(pixel_values_video[0])
                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1  # +1 for the CLS token
                video_seqlen = image_seqlen * num_frames
                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                    image_seqlen -= 1  # the default strategy drops the CLS feature
            else:
                image_seqlen, video_seqlen = 1, 1

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                    num_image_tokens += 1

                while VIDEO_PLACEHOLDER in content:
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                    num_video_tokens += 1

                content = content.replace("{{image}}", self.image_token)
                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
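
    # Example (illustrative): 224x224 frames with patch_size 14 give
    # 16 * 16 + 1 = 257 tokens per frame; an 8-frame video then uses
    # 257 * 8 = 2056 tokens, while a single image drops the CLS token
    # under the "default" strategy, leaving 256.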

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "minicpm_v": MiniCPMVPlugin,
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_audio": Qwen2AudioPlugin,
    "qwen2_vl": Qwen2VLPlugin,
    "video_llava": VideoLlavaPlugin,
}


def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
    r"""
    Registers a multimodal plugin.
    """
    if name in PLUGINS:
        raise ValueError(f"Multimodal plugin {name} already exists.")

    PLUGINS[name] = plugin_class


def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
    audio_token: Optional[str] = None,
) -> "BasePlugin":
    r"""
    Gets plugin for multimodal inputs.
    """
    if name not in PLUGINS:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

    return PLUGINS[name](image_token, video_token, audio_token)
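
# Example usage (illustrative; the token strings depend on the target model,
# and MyPlugin is a hypothetical BasePlugin subclass):
#   register_mm_plugin("my_plugin", MyPlugin)
#   plugin = get_mm_plugin("qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
#   messages = plugin.process_messages(messages, images, videos, audios, processor)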