minicpmv.py 57.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Alphi's avatar
Alphi committed
24
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
25
26
import math
import re
27
from collections import Counter
28
from functools import cached_property, partial
29
30
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Set, Tuple, TypedDict, Union)
31

32
import numpy as np
33
import torch
Alphi's avatar
Alphi committed
34
import torch.types
35
36
from PIL import Image
from torch import nn
37
from transformers import BatchFeature, PretrainedConfig
38
from typing_extensions import TypeVar
39
40

from vllm.attention import AttentionMetadata
41
from vllm.config import VllmConfig
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
44
                                                  get_2d_sincos_pos_embed)
Joe Runde's avatar
Joe Runde committed
45
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Jee Jee Li's avatar
Jee Jee Li committed
46
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
47
48
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
49
from vllm.model_executor.models.module_mapping import MultiModelKeys
50
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
53
54
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, PlaceholderRange)
55
56
57
58
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
                                   ModalityData, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser,
                                   VideoItem)
59
60
61
62
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
63

Jee Jee Li's avatar
Jee Jee Li committed
64
from .idefics2_vision_model import Idefics2VisionTransformer
65
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
66
from .utils import AutoWeightsLoader, maybe_prefix
67

68
# Shared handle for the CPU device.
CPU_DEVICE = torch.device("cpu")

# A raw image may arrive either as a PIL image or as a tensor.
RawImageType = Union[Image.Image, torch.Tensor]


class MiniCPMVImagePixelInputs(TypedDict):
    """Image inputs given as preprocessed pixel values, one entry per slice."""

    type: Literal["pixel_values"]
    data: List[torch.Tensor]
    """
    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_slices, 2)`

    This should be in `(start, stop)` format.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_slices, 2)`

    This should be in `(height, width)` format.
    """


class MiniCPMVImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings, one entry per slice."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_slices,
             image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_slices, 2)`

    This should be in `(start, stop)` format.
    """


# Either form of image input: raw pixel values or precomputed embeddings.
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                            MiniCPMVImageEmbeddingInputs]

# Default normalization-layer factory for the resamplers:
# ``nn.LayerNorm`` with ``eps=1e-6``.
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class Resampler2_5(BaseResampler):
    """Cross-attention resampler with a growable 2D position-embedding cache.

    Query vectors (``self.query`` from ``BaseResampler``) attend over the
    projected vision features of each row. Each row's keys get a 2D
    sine-cosine position embedding sliced to that row's ``(height, width)``
    patch grid. The embedding table is kept as a non-persistent buffer and is
    rebuilt larger whenever an input grid exceeds the cached size.
    """

    def __init__(self,
                 num_queries: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 max_size: Tuple[int, int] = (70, 70),
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Build the resampler and its initial position-embedding cache.

        All arguments except ``max_size`` are forwarded to ``BaseResampler``.

        Args:
            max_size: Initial ``(height, width)`` capacity of the cached 2D
                position table. ``forward`` grows it when needed.
        """
        super().__init__(num_queries,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         quant_config=quant_config,
                         prefix=prefix)

        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        """(Re)build the 2D sin-cos position table for ``max_size``.

        The table is registered with ``persistent=False``, so it is excluded
        from the state dict and recomputed instead of loaded.
        """
        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                max_size,
                                                version=(2, 5))
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        """Grow the position cache if any grid in ``tgt_sizes`` exceeds it.

        Args:
            tgt_sizes: Per-row grid sizes; column 0 is height and column 1 is
                width.
            device: Device on which the cache is rebuilt if it grows.
        """
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)

        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            # Each dimension grows independently and never shrinks.
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Resample each row of ``x`` into a fixed number of query tokens.

        Args:
            x: Batch-first vision features (``B * L * D`` per the inline
                shape notes). In row ``i``, positions at or beyond
                ``height * width`` are treated as padding.
            tgt_sizes: ``(B, 2)`` grid sizes in ``(height, width)`` format.

        Returns:
            ``B * Q * ...`` tensor: the attention output after ``ln_post``,
            multiplied by ``self.proj``.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        # Number of valid (non-padded) patches in each row.
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)

        # True marks padded key positions. Presumably self.attn follows the
        # nn.MultiheadAttention convention that True means "ignore"; confirm
        # against BaseResampler.
        key_padding_mask = torch.zeros((bs, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i].tolist()
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
                (tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True
        # Pad rows to the longest one, then BLD => L * B * D.
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)

        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        # Position embeddings are added to the keys only, not the values.
        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D +  L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        #  out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    """Return the MiniCPM-V version of ``config`` as a tuple of ints.

    An explicit ``config.version`` (e.g. ``2.6``) is stringified and split on
    ``"."``. Configs without a version (missing or ``None``) are classified
    as ``(2, 0)`` when ``hidden_size == 2304`` and ``query_num == 64``, and
    as ``(2, 5)`` otherwise.
    """
    declared = getattr(config, "version", None)
    if declared is not None:
        return tuple(map(int, str(declared).split(".")))

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    looks_like_v2_0 = config.hidden_size == 2304 and config.query_num == 64
    return (2, 0) if looks_like_v2_0 else (2, 5)


229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Describe how each MiniCPM-V processor output is split into items.

    Per-slice outputs (pixel values, target sizes, embeddings) are split with
    ``flat_from_sizes`` using the per-item ``*_num_slices`` counts. Per-item
    outputs (sizes and slice counts) are ``batched``. A missing slice-count
    entry falls back to an empty tensor.
    """
    image_slice_counts = hf_inputs.get("image_num_slices", torch.empty(0))
    video_slice_counts = hf_inputs.get("video_num_slices", torch.empty(0))

    per_image_slice = partial(MultiModalFieldConfig.flat_from_sizes, "image",
                              image_slice_counts)
    per_video_slice = partial(MultiModalFieldConfig.flat_from_sizes, "video",
                              video_slice_counts)
    per_image = partial(MultiModalFieldConfig.batched, "image")
    per_video = partial(MultiModalFieldConfig.batched, "video")

    return {
        "pixel_values": per_image_slice(),
        "image_sizes": per_image(),
        "tgt_sizes": per_image_slice(),
        "image_num_slices": per_image(),
        "image_embeds": per_image_slice(),
        "video_pixel_values": per_video_slice(),
        "video_image_sizes": per_video(),
        "video_tgt_sizes": per_video_slice(),
        "video_embeds": per_video_slice(),
        "video_num_slices": per_video(),
    }


class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
    """Image items supplied as a dict of precomputed embeddings.

    The dict must contain ``image_embeds`` and ``image_sizes``. The per-item
    split of each key comes from ``fields_config``.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_config: Mapping[str, MultiModalFieldConfig],
    ) -> None:
        required_keys = {"image_embeds", "image_sizes"}
        super().__init__(data,
                         modality="image",
                         fields_config=fields_config,
                         required_fields=required_keys)

    def get_image_size(self, index: int) -> ImageSize:
        """Return the size of image ``index``.

        Each ``image_sizes`` entry is read as ``(width, height)``.
        """
        item = self.get(index)
        size = item["image_sizes"].tolist()
        width = size[0]
        height = size[1]
        return ImageSize(width=width, height=height)


class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
    """Video items supplied as a dict of precomputed embeddings.

    The dict must contain ``video_embeds`` and ``video_image_sizes``.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_config: Mapping[str, MultiModalFieldConfig],
    ) -> None:
        super().__init__(
            data,
            modality="video",
            fields_config=fields_config,
            required_fields={"video_embeds", "video_image_sizes"},
        )

    def get_frame_size(self, index: int) -> ImageSize:
        """Return the frame size of video ``index``.

        An item's ``video_image_sizes`` may be a single ``(width, height)``
        pair or one pair per frame (the layout ``get_num_frames`` counts).
        In the per-frame case the first frame's size is used, since all frames
        of a video are assumed to share one size.
        """
        frame_size = self.get(index)["video_image_sizes"].tolist()
        if frame_size and isinstance(frame_size[0], list):
            # Per-frame sizes: previously this built an ImageSize out of
            # nested lists.
            frame_size = frame_size[0]
        return ImageSize(width=frame_size[0], height=frame_size[1])

    def get_num_frames(self, index: int) -> int:
        """Return the number of frames in video ``index``.

        This counts the entries of the item's ``video_image_sizes``, so it
        assumes one size per frame.
        NOTE(review): for a single ``(width, height)`` pair this returns 2.
        """
        return len(self.get(index)["video_image_sizes"])


class MiniCPMVMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of precomputed embeddings.

    A ``dict`` passed as image or video data is wrapped in the matching
    MiniCPM-V embedding-items class. Any other input goes to the default
    parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Wrap an image-embedding dict, or defer to the default parser."""
        if isinstance(data, dict):
            return MiniCPMVImageEmbeddingItems(
                data,
                fields_config=_minicpmv_field_config(data),
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Wrap a video-embedding dict, or defer to the default parser."""
        if isinstance(data, dict):
            return MiniCPMVVideoEmbeddingItems(
                data,
                fields_config=_minicpmv_field_config(data),
            )

        return super()._parse_video_data(data)


class MiniCPMVProcessingInfo(BaseProcessingInfo):
    """Model-specific processing info for MiniCPM-V.

    Exposes the HF config and processor, the modalities each model version
    supports (video only for version ``(2, 6)``), and the token-count
    estimates used to size image and video inputs.
    """

    # Placeholder text marking one image / one video in a prompt.
    image_pattern = "(<image>./</image>)"
    video_pattern = "(<video>./</video>)"

    def get_hf_config(self):
        """Return the HF model config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(
        self,
        **kwargs: object,
    ):
        """Return the HF processor, with image mean/std converted to lists.

        NOTE(review): ``kwargs`` is accepted but not forwarded to
        ``self.ctx.get_hf_processor()``.
        """
        hf_processor = self.ctx.get_hf_processor()

        # NumPy arrays are considered as Iterable but not Sequence in
        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
        image_processor = hf_processor.image_processor  # type: ignore
        for attr in ("mean", "std"):
            val = getattr(image_processor, attr)
            if isinstance(val, np.ndarray):
                setattr(image_processor, attr, val.tolist())

        return hf_processor

    def get_image_processor(self):
        """Return the image processor of the (normalized) HF processor."""
        hf_processor = self.get_hf_processor()
        image_processor = hf_processor.image_processor  # type: ignore
        return image_processor

    def get_model_version(self):
        """Return the model version as a tuple, e.g. ``(2, 6)``."""
        return get_version_by_config(self.get_hf_config())

    def get_supported_mm_modalities(self) -> List[str]:
        """Return the modalities this model version accepts."""
        if self.get_model_version() == (2, 6):
            return ["image", "video"]
        else:
            return ["image"]

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return per-modality item limits (``None`` for each modality)."""
        if self.get_model_version() == (2, 6):
            return {"image": None, "video": None}
        else:
            return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum prompt tokens for one image and, for version
        ``(2, 6)``, for one video. ``mm_counts`` is not used."""
        mm_max_tokens = {"image": self.get_max_image_tokens()}
        if self.get_model_version() == (2, 6):
            mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
        return mm_max_tokens

    def get_max_video_frame_tokens(self) -> int:
        """Return the token count of one frame at the largest frame size."""
        frame_size = self.get_video_frame_size_with_most_features()
        return self.get_num_image_tokens(frame_size,
                                         self.get_video_max_slice_num())

    def get_max_video_tokens(self, seq_len: int) -> int:
        """Return per-frame tokens times the max frames that fit ``seq_len``."""
        return self.get_max_video_frame_tokens(
        ) * self.get_num_frames_with_most_features(seq_len)

    def get_slice_query_num(self) -> int:
        """Return the query tokens per slice (``config.query_num``, default
        64)."""
        hf_config = self.get_hf_config()
        query_num = getattr(hf_config, "query_num", 64)
        return query_num

    def get_max_slice_num(self) -> int:
        """Return the max slices per image (``config.max_slice_num``,
        default 9)."""
        hf_config = self.get_hf_config()
        max_slice_num = getattr(hf_config, "max_slice_num", 9)
        return max_slice_num

    def get_sliced_grid(self, image_size: ImageSize,
                        max_slice_num: int) -> Optional[Tuple[int, int]]:
        """Return the slice grid for ``image_size`` from the image processor.

        Callers treat ``None`` as "not sliced". Only version ``(2, 6)`` passes
        ``max_slice_num`` through to the image processor.
        """
        if self.get_model_version() == (2, 6):
            slice_grid = self.get_image_processor().get_sliced_grid(
                image_size, max_slice_num)
        else:
            slice_grid = self.get_image_processor().get_sliced_grid(image_size)
        return slice_grid

    def get_num_image_tokens(self, image_size: ImageSize,
                             max_slice_num: int) -> int:
        """Return the number of prompt tokens for one image.

        The base image contributes ``query_num + 2`` tokens. When the image is
        sliced, each grid cell adds another ``query_num + 2``. On top of that
        come ``grid[1] - 1`` tokens (presumably separators between grid rows;
        confirm against the HF processor) and, for versions other than
        ``(2, 6)``, 2 more tokens.
        """
        slice_grid = self.get_sliced_grid(image_size, max_slice_num)
        query_num = self.get_slice_query_num()
        num_tokens = query_num + 2  # <image>(<unk> * query_num)</image>
        if slice_grid is not None:
            if self.get_model_version() == (2, 6):
                num_additional_tokens = 0
            else:
                # <slice><image>(<unk> * query_num)</image></slice>
                num_additional_tokens = 2
            num_slices = slice_grid[0] * slice_grid[1]
            num_tokens += ((query_num + 2) * num_slices + slice_grid[1] - 1 +
                           num_additional_tokens)
        return num_tokens

    def get_image_slice_nums(self, image_size: torch.Tensor,
                             max_slice_nums: int) -> int:
        """Return slices per image: 1 when not sliced, otherwise the grid
        cells plus 1 for the base image."""
        grid = self.get_sliced_grid(image_size, max_slice_nums)
        return 1 if grid is None else grid[0] * grid[1] + 1

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with the most features."""
        image_size = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_size, self.get_max_slice_num())

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size that yields the most features."""
        # Result in the max possible feature size
        # (h:w = max_slice_num:1, i.e. 9:1 with the default slice limit)
        return self.get_default_image_sizes(self.get_max_slice_num())

    def get_video_max_slice_num(self) -> int:
        """Return the slice limit for video frames (fixed at 1)."""
        return 1

    def get_video_frame_size_with_most_features(self) -> ImageSize:
        """Return the frame size that yields the most features."""
        return self.get_default_image_sizes(self.get_video_max_slice_num())

    def get_max_video_frames(self, max_tokens: int) -> int:
        """Return how many largest-size frames fit in ``max_tokens``."""
        num_frame_tokens = self.get_max_video_frame_tokens()
        num_frames = max_tokens // num_frame_tokens
        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Return the frames per video that fit in ``seq_len`` once the
        maximum image tokens are reserved (always at least 1)."""
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)

        # count <image_idx></image_idx> tokens
        # which are not in get_max_image_tokens
        max_image_tokens = self.get_max_image_tokens(
        ) * max_images + 4 * max_images
        max_total_frames = self.get_max_video_frames(seq_len -
                                                     max_image_tokens)

        num_frames = max(max_total_frames // max(max_videos, 1), 1)

        return num_frames

    def get_default_image_sizes(self, num_slices: int) -> ImageSize:
        """Return a size ``image_size`` wide and ``image_size * num_slices``
        tall, where ``image_size`` is ``config.image_size`` (default 448)."""
        image_size = getattr(self.get_hf_config(), "image_size", 448)
        return ImageSize(width=image_size, height=image_size * num_slices)


# Processing-info type parameter for the builder/processor generics; it
# defaults to MiniCPMVProcessingInfo when left unparameterized.
_I = TypeVar(
    "_I",
    bound=MiniCPMVProcessingInfo,
    default=MiniCPMVProcessingInfo,
)


class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and media at the sizes with the most features."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt with one placeholder per requested item, plus
        matching dummy media.

        Images use the size with the most features. Each video consists of
        ``get_num_frames_with_most_features(seq_len)`` frames at the frame
        size with the most features.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_width, image_height = \
            self.info.get_image_size_with_most_features()
        video_width, video_height = \
            self.info.get_video_frame_size_with_most_features()
        num_video_frames = \
            self.info.get_num_frames_with_most_features(seq_len)

        mm_data = {
            "image":
            self._get_dummy_images(width=image_width,
                                   height=image_height,
                                   num_images=num_images),
            # Every dummy video shares the same list of dummy frames.
            "video": [
                self._get_dummy_images(width=video_width,
                                       height=video_height,
                                       num_images=num_video_frames)
            ] * num_videos,
        }

        image_prompt_texts = self.info.image_pattern * num_images
        video_prompt_texts = self.info.video_pattern * num_videos

        return ProcessorInputs(prompt_text=image_prompt_texts +
                               video_prompt_texts,
                               mm_data=mm_data)


class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor for MiniCPM-V images and (version 2.6) videos.

    The HF processor handles media one modality at a time
    (``process_images`` / ``process_videos``). The prompt placeholders
    ``(<image>./</image>)`` and ``(<video>./</video>)`` are expanded by
    ``_get_prompt_replacements`` rather than by the HF processor.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        return MiniCPMVMultiModalDataParser()

    def get_slice_image_placeholder(self, image_size: ImageSize,
                                    **kwargs) -> str:
        """Return the HF image processor's placeholder text for an image.

        For versions ``(2, 0)`` and ``(2, 5)``, ``kwargs`` is not forwarded.
        For other versions it is passed through.
        """
        image_processor = self.info.get_image_processor()
        version = self.info.get_model_version()
        if version in ((2, 0), (2, 5)):
            return image_processor.get_slice_image_placeholder(image_size)
        return image_processor.get_slice_image_placeholder(
            image_size, **kwargs)

    def get_image_prompt_texts(self,
                               image_size: ImageSize,
                               image_idx: int = 0) -> str:
        """Return the expanded prompt text for image number ``image_idx``."""
        prompt_texts = self.get_slice_image_placeholder(image_size,
                                                        image_idx=image_idx)
        return prompt_texts

    def get_video_prompt_texts(self, image_size: ImageSize,
                               num_frames: int) -> str:
        """Return the expanded prompt text for a video of ``num_frames``.

        Every frame gets the same placeholder: ``image_idx=0``, the video
        slice limit, and no image-id prefix.
        """
        frames = range(num_frames)
        if not frames:
            return ""
        # The per-frame placeholder does not depend on the frame, so build
        # it once instead of once per frame.
        frame_placeholder = self.get_slice_image_placeholder(
            image_size=image_size,
            image_idx=0,
            max_slice_nums=self.info.get_video_max_slice_num(),
            use_image_id=False)
        return "".join(frame_placeholder for _ in frames)

    def get_special_tokens(self) -> Dict[str, torch.Tensor]:
        """Return the image delimiter token ids, plus the slice delimiter ids
        when the tokenizer defines them."""
        tokenizer = self.info.get_tokenizer()
        special_tokens = {
            "im_start_id": torch.tensor(tokenizer.im_start_id),
            "im_end_id": torch.tensor(tokenizer.im_end_id)
        }
        if hasattr(tokenizer, "slice_start_id"):
            special_tokens["slice_start_id"] = torch.tensor(
                tokenizer.slice_start_id)
            special_tokens["slice_end_id"] = torch.tensor(
                tokenizer.slice_end_id)
        return special_tokens

    @staticmethod
    def repack_processor_outputs(outputs: Any) -> Dict[str, Any]:
        """Keep only pixel values and sizes, taking element 0 of each.

        Presumably element 0 is the single prompt's entry in the HF batch
        output (``process_images`` sends one prompt).
        """
        valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"]
        return {key: outputs[key][0] for key in valid_keys}

    def process_images(self, mm_data: Dict[str, Any],
                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
        """Preprocess images, or pass image embeddings through.

        Pops ``images``, ``image_embeds`` and, on the embedding path,
        ``image_sizes`` from ``mm_data``. Returns ``{}`` when neither images
        nor embeddings are present.
        """
        images = mm_data.pop("images", [])
        image_embeds = mm_data.pop("image_embeds", [])
        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(images, (list, torch.Tensor)) and len(images) > 0:
            image_outputs = super()._call_hf_processor(
                prompt=self.info.image_pattern * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs)
            image_outputs = MiniCPMVMultiModalProcessor.\
                repack_processor_outputs(image_outputs)
        elif len(image_embeds) > 0:
            image_sizes = mm_data.pop("image_sizes", None)
            image_outputs = {
                "image_embeds": torch.cat(image_embeds),
                "image_sizes": image_sizes
            }
        else:
            image_outputs = {}
        return image_outputs

    def process_videos(self, mm_data: Dict[str, Any],
                       mm_kwargs: Mapping[str, object]) -> Dict[str, object]:
        """Preprocess videos frame by frame, or pass video embeddings through.

        Pops ``videos``, ``video_embeds`` and, on the embedding path,
        ``image_sizes`` and ``num_frames`` from ``mm_data``. Only the first
        frame's size is recorded per video. Returns ``{}`` when neither
        videos nor embeddings are present.
        """
        videos = mm_data.pop("videos", [])
        video_embeds = mm_data.pop("video_embeds", [])
        # A single video given as a flat list of frames.
        if len(videos) > 0 and isinstance(videos[0], Image.Image):
            videos = [videos]
        if isinstance(videos, list) and len(videos) > 0:
            video_outputs = {
                "video_pixel_values": [],
                "video_image_sizes": [],
                "video_tgt_sizes": [],
                "num_frames": []
            }
            for video in videos:
                parsed_video = []
                for frame in video:
                    if isinstance(frame, np.ndarray):
                        parsed_video.append(Image.fromarray(frame))
                    else:
                        parsed_video.append(frame)
                video = parsed_video
                single_video_outputs = super()._call_hf_processor(
                    prompt=self.info.image_pattern * len(video),
                    mm_data={"images": video},
                    mm_kwargs={
                        **mm_kwargs, "max_slice_nums":
                        self.info.get_video_max_slice_num()
                    })
                video_outputs["num_frames"].append(len(video))
                for key in single_video_outputs:
                    if "video_" + key in video_outputs:
                        if key == "image_sizes":
                            video_outputs["video_" + key].append(
                                single_video_outputs[key][0][0])
                        else:
                            video_outputs["video_" +
                                          key] += single_video_outputs[key][0]
        elif len(video_embeds) > 0:
            image_sizes = mm_data.pop("image_sizes", None)
            num_frames = mm_data.pop("num_frames", None)
            video_outputs = {
                "video_embeds": torch.cat(video_embeds),
                "video_image_sizes": image_sizes,
                "num_frames": num_frames
            }
        else:
            video_outputs = {}
        return video_outputs

    def get_placeholder_match_pattern(self) -> str:
        """Regex matching one placeholder; group 1 is the modality name."""
        return r"\(<(image|video)>./</\1>\)"

    def get_placeholder_split_pattern(self) -> str:
        """Regex matching any placeholder, without capturing groups."""
        return r"\(<(?:image|video)>./</(?:image|video)>\)"

    def process_mm_inputs(self, mm_data, mm_kwargs) -> object:
        """Process images, then videos; both pop entries from ``mm_data``."""
        return {
            "image": self.process_images(mm_data, mm_kwargs),
            "video": self.process_videos(mm_data, mm_kwargs)
        }

    def get_input_modalities(self, mm_data) -> List[str]:
        """Return the supported modalities with non-empty processed data."""
        supported_mm_modalities = self.info.get_supported_mm_modalities()
        input_modalities = []
        for modality in supported_mm_modalities:
            if modality in mm_data and mm_data[modality] != {}:
                input_modalities.append(modality)
        return input_modalities

    def get_modality_num_counter(self, modality: str) -> str:
        """Return the key whose length gives the item count of ``modality``.

        Raises:
            ValueError: If ``modality`` is neither ``"image"`` nor
                ``"video"``. Previously ``None`` was returned, which failed
                later as an obscure lookup error.
        """
        if modality == "image":
            return "image_sizes"
        elif modality == "video":
            return "video_image_sizes"
        raise ValueError(f"Unexpected modality: {modality}")

    def get_num_slices_by_modality(self, inputs: Dict[str, object],
                                   modality: str, index: int) -> int:
        """Return the slice count of item ``index``.

        For a video this is its per-frame slice count times ``num_frames``.
        """
        if modality == "image":
            return self.info.get_image_slice_nums(
                inputs[modality]["image_sizes"][index],
                self.info.get_max_slice_num())
        elif modality == "video":
            return self.info.get_image_slice_nums(
                inputs[modality]["video_image_sizes"][index],
                self.info.get_video_max_slice_num()
            ) * inputs[modality]["num_frames"][index]
        else:
            raise ValueError(f"Unexpected modality: {modality}")

    def check_mm_inputs(self, inputs: Dict[str, object],
                        matches: List[str]) -> None:
        """Check that each placeholder in ``matches`` has a matching input.

        Raises:
            ValueError: If a modality named in ``matches`` has no data, or if
                its item count differs from its placeholder count.
        """
        counts = Counter(matches)
        for modality, count in counts.items():
            if modality not in inputs or not inputs[modality]:
                raise ValueError(f"The prompt requires {modality} inputs, "
                                 f"but no {modality} data was passed.")
            counter_key = self.get_modality_num_counter(modality)
            num_passed = len(inputs[modality][counter_key])
            if num_passed != count:
                raise ValueError(f"The prompt requires {count} "
                                 f"{modality} inputs while you pass "
                                 f"{num_passed}")

    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
                                     modality: str, index: int) -> str:
        """Return the expanded prompt text for item ``index``."""
        if modality == "image":
            return self.get_image_prompt_texts(
                inputs["image"]["image_sizes"][index], index)
        elif modality == "video":
            return self.get_video_prompt_texts(
                inputs["video"]["video_image_sizes"][index],
                inputs["video"]["num_frames"][index])
        else:
            raise ValueError(f"Unexpected modality: {modality}")

    def call_base_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the base class's HF processor, bypassing this override."""
        return super()._call_hf_processor(prompt=prompt,
                                          mm_data=mm_data,
                                          mm_kwargs=mm_kwargs)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` as-is and preprocess each modality separately.

        Returns ``input_ids``, every per-modality output of
        ``process_mm_inputs``, and ``{modality}_num_slices`` for each
        modality that has inputs.
        """
        # Do not support combination inputs of images and videos for now
        # Try to handle interleaved multimodal data
        tokenizer = self.info.get_tokenizer()
        # process_images / process_videos pop entries from the mapping they
        # receive; work on a shallow copy so the caller's mm_data is intact.
        inputs = self.process_mm_inputs(dict(mm_data), mm_kwargs)
        mm_input_modalities = self.get_input_modalities(inputs)
        num_mm_slices = {modality: [] for modality in mm_input_modalities}
        for modality in mm_input_modalities:
            num_counter_key = self.get_modality_num_counter(modality)
            for index in range(len(inputs[modality][num_counter_key])):
                num_mm_slices[modality].append(
                    self.get_num_slices_by_modality(inputs, modality, index))
        return {
            "input_ids": np.array([tokenizer.encode(prompt)]),
            **{
                key: value
                for modality in inputs
                for key, value in inputs[modality].items()
            },
            **{
                f"{modality}_num_slices": num_mm_slices[modality]
                for modality in mm_input_modalities
            }
        }

    def _hf_processor_applies_repl(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """The HF call above never expands placeholders, so vLLM must."""
        return False

    def _get_prompt_replacements(
            self, mm_items: MultiModalDataItems,
            hf_processor_mm_kwargs: Mapping[str, Any],
            out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
        """Replace each image/video placeholder with its expanded text."""
        placeholder = {
            "image": self.info.image_pattern,
            "video": self.info.video_pattern,
        }

        def get_replacement_minicpmv(item_idx: int, modality: str):
            if modality == "image":
                return self.get_image_prompt_texts(
                    mm_items["image"].get_image_size(item_idx), item_idx)
            else:  # video
                return self.get_video_prompt_texts(
                    mm_items["video"].get_frame_size(item_idx),
                    mm_items["video"].get_num_frames(item_idx))

        return [
            PromptReplacement(modality=modality,
                              target=placeholder[modality],
                              replacement=partial(get_replacement_minicpmv,
                                                  modality=modality))
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _minicpmv_field_config(hf_inputs)

    def apply(
        self,
        prompt: Union[str, List[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Run base processing, then add MiniCPM-V specific outputs.

        Adds ``{modality}_orders`` to ``mm_kwargs``: the positions of that
        modality's placeholders among all placeholders in the prompt. Also
        adds the special token ids. For version ``(2, 6)``, the leading
        ``<image_id>{idx}</image_id>`` is removed from each image's
        placeholder range.
        """
        supported_mm_modalities = self.info.get_supported_mm_modalities()
        if isinstance(prompt, list):
            prompt = self.info.get_tokenizer().decode(prompt)
        matches = re.findall(self.get_placeholder_match_pattern(), prompt)
        mm_orders = {
            f"{modality}_orders":
            torch.tensor(
                [index for index, m in enumerate(matches) if m == modality])
            for modality in supported_mm_modalities
        }
        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
        # Exclude <image_id>x</image_id> from placeholders
        if "image" in result["mm_placeholders"] and \
            self.info.get_model_version() == (2, 6):
            # The id prefix is two marker tokens plus one token per decimal
            # digit of idx. This agrees with `3 + idx // 10` for idx < 20,
            # which is wrong from idx 20 on.
            # NOTE(review): assumes numbers are tokenized digit by digit.
            result["mm_placeholders"]["image"] = [
                PlaceholderRange(offset=p["offset"] + 2 + len(str(idx)),
                                 length=p["length"] - 2 - len(str(idx)))
                for idx, p in enumerate(result["mm_placeholders"]["image"])
            ]
        result["mm_kwargs"].update(**mm_orders)
        result["mm_kwargs"].update(**self.get_special_tokens())
        return result


class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    The abstract class of MiniCPMV can only be inherited, but cannot be
    instantiated.

    Subclasses provide the version-specific parts (``init_llm``,
    ``init_vision_module``, ``init_resampler``, ``get_vision_embedding`` and
    ``get_vision_hidden_states``). This class wires them together, parses the
    multimodal kwargs, and writes vision features into the language-model
    input embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings` but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
        # and config class
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "llm"))
        self.vpm = self.init_vision_module(config,
                                           quant_config,
                                           prefix=maybe_prefix(prefix, "vpm"))
        # The 2.0 vision module (a timm model) exposes `embed_dim` directly;
        # later versions expose it under `embeddings`.
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size

        self.resampler = self.init_resampler(self.embed_dim,
                                             self.vision_dim,
                                             quant_config=quant_config,
                                             prefix=maybe_prefix(
                                                 prefix, "resampler"))

        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.llm, "sampler"):
            return self.llm.sampler

        return get_sampler()

    def get_embedding_with_vision(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and overwrite image-token rows with vision
        features.

        Returns:
            The merged embeddings and the vision hidden states (an empty
            tensor when ``image_inputs`` is None).
        """
        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            if image_inputs["type"] == "image_embeds":
                vision_hidden_states = (image_inputs["data"].type(
                    vlm_embedding.dtype).to(vlm_embedding.device))
            else:
                vision_hidden_states = self.get_vision_hidden_states(
                    image_inputs)

            # See NOTE in _parse_and_validate_image_inputs
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                # Concatenate the per-image position ranges. For equal-length
                # bounds this matches the former `stack(...).view(-1)`, and it
                # also works when bounds differ in length.
                image_indices = torch.cat([
                    torch.arange(start, end, dtype=torch.long)
                    for start, end in image_bounds.tolist()
                ]).to(vlm_embedding.device)
                # Write the vision features, row by row, into the positions
                # covered by the bounds.
                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1,
                                                     vlm_embedding.shape[-1]),
                    vision_hidden_states.view(-1,
                                              vision_hidden_states.shape[-1]),
                )

        return vlm_embedding, vision_hidden_states

    def _get_image_bounds(
            self,
            input_ids: torch.Tensor,
            im_start_id: torch.Tensor,
            im_end_id: torch.Tensor,
            slice_start_id: Optional[torch.Tensor] = None,
            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Find the token spans enclosed by image (and slice) markers.

        Returns:
            A ``(num_spans, 2)`` long tensor. Each row is ``[start, end)``:
            ``start`` is the position right after a start marker and ``end``
            is the position of an end marker.
        """
        # All the images in the batch should share the same special image
        # bound token ids.
        start_cond = input_ids == im_start_id[0]
        end_cond = input_ids == im_end_id[0]
        if slice_start_id is not None:
            start_cond |= (input_ids == slice_start_id[0])
            end_cond |= (input_ids == slice_end_id[0])

        image_start_tokens, = torch.where(start_cond)
        # Skip the start marker itself.
        image_start_tokens += 1
        image_end_tokens, = torch.where(end_cond)
        # NOTE(review): with `max` the slices below are no-ops, so unequal
        # start/end counts make `hstack` raise; counts are expected to match.
        valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            # Same dtype as the non-empty result below.
            return torch.zeros((0, 2),
                               dtype=torch.long,
                               device=input_ids.device)

        return torch.hstack([
            image_start_tokens[:valid_image_nums].unsqueeze(-1),
            image_end_tokens[:valid_image_nums].unsqueeze(-1),
        ])

    def _parse_and_validate_image_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        """Build the vision inputs for this batch from the model kwargs.

        Returns:
            ``MiniCPMVImageEmbeddingInputs`` when image or video embeddings
            are passed; ``MiniCPMVImagePixelInputs`` when pixel values are
            passed; ``None`` when there is nothing to encode or
            ``im_start_id`` is missing.

        Raises:
            ValueError: if both image and video embeddings are given, or on
                wrongly typed / inconsistently sized inputs.
        """
        mm_data = {
            "image": {
                key: kwargs.pop(key, [])
                for key in ["pixel_values", "tgt_sizes", "image_num_slices"]
            },
            "video": {
                "pixel_values": kwargs.pop("video_pixel_values", []),
                "tgt_sizes": kwargs.pop("video_tgt_sizes", []),
                "video_num_slices": kwargs.pop("video_num_slices", [])
            }
        }
        im_start_id = kwargs.pop("im_start_id", None)
        im_end_id = kwargs.pop("im_end_id", None)
        slice_start_id = kwargs.pop("slice_start_id", None)
        slice_end_id = kwargs.pop("slice_end_id", None)
        mm_orders = {
            f"{modality}": kwargs.pop(f"{modality}_orders", None)
            for modality in ["image", "video", "audio"]
        }
        batch_size = max(len(mm_data["image"]["pixel_values"]),
                         len(mm_data["video"]["pixel_values"]))
        image_embeds = kwargs.pop("image_embeds", None)
        video_embeds = kwargs.pop("video_embeds", None)
        if image_embeds is not None and video_embeds is not None:
            raise ValueError(
                "Incorrect inputs for vision embeddings. "
                "Image embeds and video embeds can not exist simultaneously.")
        if video_embeds is not None:
            image_embeds = video_embeds
        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(f"Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")
            image_embeds = torch.concat(
                [image_embeds[i] for i in range(len(image_embeds))])

            return MiniCPMVImageEmbeddingInputs(
                image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                    im_end_id, slice_start_id,
                                                    slice_end_id),
                data=image_embeds,
                type="image_embeds",
            )

        for modality, modality_mm_data in mm_data.items():
            if not isinstance(modality_mm_data["pixel_values"],
                              (torch.Tensor, list)):
                raise ValueError(
                    "Incorrect type of pixel values. "
                    f"Got type: {type(modality_mm_data['pixel_values'])}")

            if not isinstance(modality_mm_data["tgt_sizes"],
                              (torch.Tensor, list)):
                raise ValueError(
                    "Incorrect type of target sizes. "
                    f"Got type: {type(modality_mm_data['tgt_sizes'])}")

            if len(modality_mm_data["pixel_values"]) != len(
                    modality_mm_data["tgt_sizes"]):
                raise ValueError(
                    "Inconsistent batch lengths, found: "
                    f"{len(modality_mm_data['pixel_values'])} vs. "
                    f"{len(modality_mm_data['tgt_sizes'])}")

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for b in range(batch_size):
            # Walk the items of sample `b` in prompt order (videos only for
            # 2.6), taking each item's slices from its modality's data.
            mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \
                        else {"image": 0}
            mm_slice_counts = {"image": 0, "video": 0} \
                               if self.version == (2, 6) else {"image": 0}
            mm_orders_b = [(index, modality) for modality in mm_counts
                           for index in mm_orders[modality][b]]
            for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
                pos = mm_counts[modality]
                num_slices = mm_data[modality][f"{modality}_num_slices"][b][
                    pos]
                slice_start_idx = mm_slice_counts[modality]
                slice_end_idx = slice_start_idx + num_slices
                pixel_values_flat += mm_data[modality]["pixel_values"][b][
                    slice_start_idx:slice_end_idx]
                tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][
                    slice_start_idx:slice_end_idx]
                mm_counts[modality] += 1
                mm_slice_counts[modality] += num_slices

        # NOTE: Input IDs does not contain image tokens during memory profiling,
        # so we allow it to be empty
        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        if len(pixel_values_flat) == 0:
            return None

        if im_start_id is None:
            return None

        return MiniCPMVImagePixelInputs(
            image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                im_end_id, slice_start_id,
                                                slice_end_id),
            data=pixel_values_flat,
            tgt_sizes=torch.stack(tgt_sizes_flat),
            type="pixel_values",
        )

    def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                   **kwargs: object):
        return self._parse_and_validate_image_inputs(input_ids, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model on text embeddings merged with vision
        features.

        When ``intermediate_tensors`` is given, no embeddings are built here;
        the tensors are passed through to ``self.llm.model``.
        """
        if intermediate_tensors is not None:
            vlm_embeddings = None
        else:
            image_inputs = \
                self._parse_and_validate_inputs(input_ids, **kwargs)
            vlm_embeddings, _ = self.get_embedding_with_vision(
                input_ids, image_inputs)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        output = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )
        return output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.llm.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(language_model="llm",
                                                connector="resampler",
                                                tower_model="vpm")

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the language model. Implemented by subclasses."""
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the vision encoder (``self.vpm``). Implemented by
        subclasses."""
        raise NotImplementedError

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build the vision-to-LLM resampler. Implemented by subclasses."""
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        raise NotImplementedError


class MiniCPMV2_0(MiniCPMVBaseModel):
    """MiniCPM-V 2.0: ``MiniCPMForCausalLM`` + timm SigLIP ViT +
    ``Resampler2``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 0)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Create the timm ViT vision tower.

        Raises:
            ImportError: if ``timm`` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain the actual import failure, not the exception class.
            raise ImportError("Please install timm==0.9.10") from err

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        model = model.to(dtype=torch.get_default_dtype())

        # Drop the attention-pooling head.
        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # This class has no `self.model` attribute (the former body raised
        # AttributeError); delegate to the wrapped language model, as
        # `get_embedding_with_vision` does.
        return self.llm.get_input_embeddings(input_ids)

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(embed_dim=embed_dim,
                                   num_heads=embed_dim // 128,
                                   grid_size=int(
                                       math.sqrt(self.config.query_num)),
                                   kv_dim=vision_dim,
                                   adaptive=False,
                                   do_post_projection=True,
                                   quant_config=quant_config,
                                   prefix=prefix)

        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode each image on its own (batch of one) and resample it.

        ``patch_attn_mask`` and ``tgt_sizes`` are ignored. The patch grid is
        derived from each image's height/width instead.
        """
        res = []
        dtype = self.vpm.pos_embed.data.dtype
        patch_h, patch_w = self.vpm.patch_embed.patch_size[:2]
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            # Number of patches along each axis (rounded up).
            tgt_size = (
                math.ceil(H / patch_h),
                math.ceil(W / patch_w),
            )
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))
            # Strip prefix (e.g. class) tokens so only patch tokens remain.
            if (hasattr(self.vpm, "num_prefix_tokens")
                    and self.vpm.num_prefix_tokens > 0):
                vision_embedding = vision_embedding[:, self.vpm.
                                                    num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))
        return torch.vstack(res)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["data"]

        return self.get_vision_embedding(pixel_values)


class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.5: ``LlamaForCausalLM`` + ``Idefics2VisionTransformer`` +
    ``Resampler2_5``."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 5)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if not self.config.drop_vision_last_layer:
            return vision_tower
        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        head_count = embed_dim // 128
        with set_default_torch_dtype(torch.float16):
            resampler_module = Resampler2_5(num_queries=self.config.query_num,
                                            embed_dim=embed_dim,
                                            num_heads=head_count,
                                            kv_dim=vision_dim,
                                            quant_config=quant_config,
                                            prefix=prefix)

        return resampler_module.to(device="cuda",
                                   dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode with the vision tower, then resample to LLM width."""
        encoded = self.vpm(pixel_values, patch_attention_mask=patch_attn_mask)
        return self.resampler(encoded, tgt_sizes)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Pad all slices to one batch, build the patch mask, and encode."""
        pixel_values = data["data"]
        tgt_sizes = data["tgt_sizes"]

        pos_weight = self.vpm.embeddings.position_embedding.weight
        device = pos_weight.device
        dtype = pos_weight.dtype

        # Each slice becomes a 2-D tensor so `pad_sequence` can right-pad
        # them to a common length along the first dimension.
        per_slice = [
            pv.flatten(end_dim=1).permute(1, 0) for pv in pixel_values
        ]

        patches_per_slice = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        max_patches = patches_per_slice.max().item()
        assert isinstance(max_patches, int)

        padded = torch.nn.utils.rnn.pad_sequence(per_slice,
                                                 batch_first=True,
                                                 padding_value=0.0)
        num_slices, seq_len, _ = padded.shape
        padded = padded.permute(0, 2, 1).reshape(num_slices, 3, -1, seq_len)

        patch_attn_mask = torch.zeros((num_slices, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        # NOTE(review): unlike MiniCPMV2_6 this indexes `[idx, :n]`, which
        # slices the size-1 second dim and so marks every patch valid. Kept
        # unchanged (changing it alters encoder inputs) -- confirm intent.
        for idx in range(num_slices):
            patch_attn_mask[idx, :patches_per_slice[idx]] = True

        return self.get_vision_embedding(padded.type(dtype), patch_attn_mask,
                                         tgt_sizes)


class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.6: ``Qwen2ForCausalLM`` + ``Idefics2VisionTransformer`` +
    ``Resampler2_5``."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 6)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if not self.config.drop_vision_last_layer:
            return vision_tower
        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        head_count = embed_dim // 128
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler_module = Resampler2_5(num_queries=self.config.query_num,
                                            embed_dim=embed_dim,
                                            num_heads=head_count,
                                            kv_dim=vision_dim,
                                            quant_config=quant_config,
                                            prefix=prefix)

        return resampler_module.to(device="cuda",
                                   dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the vision tower only (no resampling)."""
        return self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Pad all slices to one batch, mask padding patches, encode and
        resample."""
        pixel_values = data["data"]
        tgt_sizes = data["tgt_sizes"]

        pos_weight = self.vpm.embeddings.position_embedding.weight
        device = pos_weight.device
        dtype = pos_weight.dtype

        # Each slice becomes a 2-D tensor so `pad_sequence` can right-pad
        # them to a common length along the first dimension.
        per_slice = [
            pv.flatten(end_dim=1).permute(1, 0) for pv in pixel_values
        ]

        patches_per_slice = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        max_patches = patches_per_slice.max().item()
        assert isinstance(max_patches, int)

        padded = torch.nn.utils.rnn.pad_sequence(per_slice,
                                                 batch_first=True,
                                                 padding_value=0.0)
        num_slices, seq_len, _ = padded.shape
        padded = padded.permute(0, 2, 1).reshape(num_slices, 3, -1, seq_len)

        # Only the first `h * w` patches of each slice are real.
        patch_attn_mask = torch.zeros((num_slices, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for idx in range(num_slices):
            patch_attn_mask[idx, 0, :patches_per_slice[idx]] = True

        encoded = self.get_vision_embedding(padded.type(dtype),
                                            patch_attn_mask, tgt_sizes)
        return self.resampler(encoded, tgt_sizes)


# Implementation class for each supported (major, minor) MiniCPM-V version;
# `MiniCPMV.__new__` looks the detected version up in this table.
_SUPPORT_VERSION: Dict[Tuple[int, int], type] = {
    (2, 0): MiniCPMV2_0,
    (2, 5): MiniCPMV2_5,
    (2, 6): MiniCPMV2_6,
}


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMVMultiModalProcessor,
    info=MiniCPMVProcessingInfo,
    dummy_inputs=MiniCPMVDummyInputsBuilder)
class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
    """
    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in vLLM. Therefore, it is necessary to separate them.

    Constructing ``MiniCPMV`` returns an instance of the version-specific
    class from ``_SUPPORT_VERSION`` rather than of ``MiniCPMV`` itself.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        """Dispatch to the implementation class for the config's version.

        Raises:
            ValueError: if the detected version is not in
                ``_SUPPORT_VERSION``.
        """
        config = vllm_config.model_config.hf_config
        # Use the same detection as `MiniCPMVBaseModel.__init__` so the
        # chosen class always agrees with its own `self.version` assert.
        version = get_version_by_config(config)
        # Dispatch class based on version
        instance_cls = _SUPPORT_VERSION.get(version)
        if instance_cls is None:
            raise ValueError(
                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")

        # quant_config references base class members,
        # so update values before init is called.
        # These class attributes are shared across calls; merge in place but
        # skip entries already present, so constructing several models does
        # not keep appending duplicates (the former `+=` did).
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        for module in instance_cls.supported_lora_modules:
            if module not in cls.supported_lora_modules:
                cls.supported_lora_modules.append(module)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        for module in instance_cls.embedding_padding_modules:
            if module not in cls.embedding_padding_modules:
                cls.embedding_padding_modules.append(module)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)