minicpmv.py 46.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Alphi's avatar
Alphi committed
24
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
25
import math
26
from collections import defaultdict
27
from collections.abc import Iterable, Mapping, Sequence
28
from functools import cached_property, partial
29
30
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
                    Union)
31

32
import numpy as np
33
import torch
Alphi's avatar
Alphi committed
34
import torch.types
35
from torch import nn
36
from transformers import BatchFeature, PretrainedConfig
37
from typing_extensions import TypeVar
38

39
from vllm.config import VllmConfig
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
42
                                                  get_2d_sincos_pos_embed)
Joe Runde's avatar
Joe Runde committed
43
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Jee Jee Li's avatar
Jee Jee Li committed
44
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
45
46
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
47
from vllm.model_executor.models.module_mapping import MultiModelKeys
48
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
50
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
51
52
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    NestedTensors)
53
54
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
                                   ImageProcessorItems, ImageSize,
55
56
                                   ModalityData, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser,
57
                                   VideoItem, VideoProcessorItems)
58
from vllm.multimodal.processing import (BaseMultiModalProcessor,
59
                                        BaseProcessingInfo, PromptReplacement,
60
                                        PromptUpdate, PromptUpdateDetails)
61
from vllm.multimodal.profiling import BaseDummyInputsBuilder
62
from vllm.platforms import current_platform
63
from vllm.sequence import IntermediateTensors
64
from vllm.utils import flatten_2d_lists
65

Jee Jee Li's avatar
Jee Jee Li committed
66
from .idefics2_vision_model import Idefics2VisionTransformer
67
68
69
70
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
71

72
73
74
# For profile run: cap on the number of frames per dummy video when
# sizing video inputs in `get_num_frames_with_most_features`.
_MAX_FRAMES_PER_VIDEO: int = 16

75

Jee Jee Li's avatar
Jee Jee Li committed
76
class MiniCPMVImagePixelInputs(TypedDict):
    """Pixel-value inputs for the sliced images fed to MiniCPM-V."""

    type: Literal["pixel_values"]

    # Shape:
    # `(batch_size * num_images * num_slices, num_channels, height, width)`
    #
    # Note that the image size may vary, so it is passed as a list of
    # tensors instead of a single batched tensor.
    pixel_values: list[torch.Tensor]

    # Shape: `(batch_size * num_images * num_slices, 2)`
    #
    # Each row is in `(height, width)` format.
    tgt_sizes: torch.Tensor

    # Shape: `(batch_size * num_images)`
    num_slices: torch.Tensor

Jee Jee Li's avatar
Jee Jee Li committed
96

97
98
class MiniCPMVImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings passed instead of pixel values."""

    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, num_slices, hidden_size)` when given
    # as a single tensor.
    #
    # Because the number of slices may differ between images, a list of
    # per-image tensors may be passed instead of a batched tensor (assumed
    # per-tensor layout `(num_slices, hidden_size)` — TODO confirm).
    #
    # `hidden_size` must match the hidden size of the language model backbone.
    image_embeds: Union[torch.Tensor, list[torch.Tensor]]


# Image inputs arrive either as raw pixel values or as precomputed embeddings.
MiniCPMVImageInputs = Union[
    MiniCPMVImagePixelInputs,
    MiniCPMVImageEmbeddingInputs,
]

# Default norm-layer factory for the resampler: LayerNorm with eps=1e-6.
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class Resampler2_5(BaseResampler):
    """Resampler using version-(2, 5) 2D sin-cos position embeddings.

    The learned queries (`self.query`, from `BaseResampler`) cross-attend to
    the projected visual features. The position embeddings are added on the
    key path only. The position table covers `max_size` (height, width)
    patches and is rebuilt larger whenever an input grid exceeds it.
    """

    def __init__(self,
                 num_queries: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 max_size: Tuple[int, int] = (70, 70),
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(num_queries,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         quant_config=quant_config,
                         prefix=prefix)

        # Current extent of the cached position table; may grow in forward.
        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        """(Re)build the non-persistent `pos_embed` buffer for `max_size`."""
        table = get_2d_sincos_pos_embed(self.embed_dim,
                                        max_size,
                                        version=(2, 5))
        self.register_buffer("pos_embed",
                             torch.from_numpy(table).float().to(device),
                             persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        """Grow the position table so it covers every grid in `tgt_sizes`.

        `tgt_sizes` rows are `(height, width)` patch counts.
        """
        needed_h = tgt_sizes[:, 0].max().item()
        needed_w = tgt_sizes[:, 1].max().item()
        assert isinstance(needed_h, int) and isinstance(needed_w, int)

        cur_h, cur_w = self.max_size
        if needed_h <= cur_h and needed_w <= cur_w:
            # The cached table is already large enough.
            return

        self.max_size = (max(needed_h, cur_h), max(needed_w, cur_w))
        self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Resample per-sample patch features into a fixed query count.

        Args:
            x: Batched patch features; dim 0 is the batch and must match
                `tgt_sizes` (layout `B * L * D` per the inline comments).
            tgt_sizes: `(B, 2)` tensor of `(height, width)` patch grids.

        Returns:
            `B * Q * D` tensor after the final projection `self.proj`.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        batch = x.shape[0]
        device, dtype = x.device, x.dtype

        # Number of real (non-padded) patches for each sample.
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        longest = patch_len.max().item()
        assert isinstance(longest, int)

        # True marks key positions past each sample's patch count
        # (presumably ignored by `self.attn`, MultiheadAttention-style).
        key_padding_mask = torch.zeros((batch, longest),
                                       dtype=torch.bool,
                                       device=device)

        per_sample_pos = []
        for idx in range(batch):
            h, w = tgt_sizes[idx].tolist()
            grid = self.pos_embed[:h, :w, :]
            per_sample_pos.append(grid.reshape((h * w, -1)).to(dtype))
            key_padding_mask[idx, patch_len[idx]:] = True

        # Pad to the longest sample, then BLD => L * B * D.
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            per_sample_pos,
            batch_first=True,
            padding_value=0.0,
        ).permute(1, 0, 2)

        kv, _ = self.kv_proj(x)  # B * L * D
        kv = self.ln_kv(kv).permute(1, 0, 2)  # L * B * D

        queries = self.ln_q(self.query)  # Q * D

        attn_out = self.attn(
            self._repeat(queries, batch),  # Q * B * D
            kv + pos_embed,  # keys carry the position embeddings
            kv,
            key_padding_mask=key_padding_mask,
        )[0]  # Q * B * D

        out = self.ln_post(attn_out.permute(1, 0, 2))  # B * Q * D
        return out @ self.proj


207
208
209
210
211
212
213
214
215
216
217
218
219
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    """Return the MiniCPM-V version of `config` as a tuple of ints.

    When `config.version` is set (e.g. ``2.6``), its string form is split on
    ``.``. Configs without it are classified as ``(2, 0)`` when
    ``hidden_size == 2304`` and ``query_num == 64``, and as ``(2, 5)``
    otherwise.
    """
    version = getattr(config, "version", None)
    if version is not None:
        return tuple(map(int, str(version).split(".")))

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    is_v2_0 = config.hidden_size == 2304 and config.query_num == 64
    return (2, 0) if is_v2_0 else (2, 5)


220
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Map each MiniCPM-V processor output key to its field config.

    Per-item tensors are batched along their modality ("image" or "video").
    The placeholder token ids are shared across all items of that modality,
    with the item counts taken from the lengths of `pixel_values` and
    `video_pixel_values` (0 when absent).
    """
    num_images = len(hf_inputs.get("pixel_values", torch.empty(0)))
    num_videos = len(hf_inputs.get("video_pixel_values", torch.empty(0)))

    image_keys = ("pixel_values", "image_sizes", "tgt_sizes", "image_embeds")
    video_keys = ("video_pixel_values", "video_image_sizes",
                  "video_tgt_sizes", "video_embeds")

    fields = {
        key: MultiModalFieldConfig.batched("image")
        for key in image_keys
    }
    fields.update(
        (key, MultiModalFieldConfig.batched("video")) for key in video_keys)
    fields["image_token_id"] = MultiModalFieldConfig.shared(
        "image", num_images)
    fields["video_token_id"] = MultiModalFieldConfig.shared(
        "video", num_videos)
    return fields


class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
    """Image items given as precomputed `image_embeds` plus `image_sizes`."""

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        required = {"image_embeds", "image_sizes"}
        super().__init__(
            data,
            modality="image",
            required_fields=required,
            fields_factory=fields_factory,
        )

    def get_image_size(self, index: int) -> ImageSize:
        """Return the size of image `index`.

        The entry's `image_sizes` value is read as `(width, height)`.
        """
        size = self.get(index)["image_sizes"].tolist()
        width, height = size[0], size[1]
        return ImageSize(width=width, height=height)


class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
    """Video items given as precomputed `video_embeds` plus frame sizes."""

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        super().__init__(
            data,
            modality="video",
            required_fields={"video_embeds", "video_image_sizes"},
            fields_factory=fields_factory,
        )

    def get_frame_size(self, index: int) -> ImageSize:
        """Return the frame size of video `index`.

        A video's `video_image_sizes` holds one `(width, height)` pair per
        frame (the layout `get_num_frames` relies on). In that case the first
        frame's size is used. A single 1-D `(width, height)` pair is also
        accepted, as before.
        """
        sizes = self.get(index)["video_image_sizes"]
        if sizes.ndim > 1:
            # Per-frame sizes: `.tolist()` on the whole tensor would yield a
            # list of pairs, so select the first frame before reading w/h.
            sizes = sizes[0]
        frame_size = sizes.tolist()
        return ImageSize(width=frame_size[0], height=frame_size[1])

    def get_num_frames(self, index: int) -> int:
        """Return the number of frames of video `index`.

        Counts the entries in the video's `video_image_sizes` (one per frame).
        """
        return len(self.get(index)["video_image_sizes"])


288
289
290
291
292
class MiniCPMVMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of precomputed embeddings.

    Dict inputs are wrapped in the MiniCPM-V embedding-item classes. All
    other inputs fall through to the default parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        return MiniCPMVImageEmbeddingItems(
            data, fields_factory=_minicpmv_field_config)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        return MiniCPMVVideoEmbeddingItems(
            data, fields_factory=_minicpmv_field_config)


class MiniCPMVProcessingInfo(BaseProcessingInfo):
    """Prompt patterns, token budgets and limits for MiniCPM-V processing."""

    image_pattern = "(<image>./</image>)"
    video_pattern = "(<video>./</video>)"

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor with NumPy `mean`/`std` made into lists."""
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NumPy arrays are considered as Iterable but not Sequence in
        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
        image_processor = hf_processor.image_processor  # type: ignore
        for name in ("mean", "std"):
            current = getattr(image_processor, name)
            if isinstance(current, np.ndarray):
                setattr(image_processor, name, current.tolist())

        return hf_processor

    def get_image_processor(self):
        return self.get_hf_processor().image_processor  # type: ignore

    def get_model_version(self):
        return get_version_by_config(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only version 2.6 accepts video input; counts are unlimited (None).
        if self.get_model_version() == (2, 6):
            return {"image": None, "video": None}
        return {"image": None}

    def get_slice_image_placeholder(
        self,
        image_size: ImageSize,
        # For MiniCPM V/O 2.6
        image_idx: int = 0,
        max_slice_nums: Optional[int] = None,
        use_image_id: bool = True,
    ) -> str:
        """Return the placeholder text for one (possibly sliced) image.

        Versions 2.0 and 2.5 ignore the 2.6-only keyword arguments.
        """
        image_processor = self.get_image_processor()
        if self.get_model_version() in ((2, 0), (2, 5)):
            return image_processor.get_slice_image_placeholder(image_size)

        return image_processor.get_slice_image_placeholder(
            image_size,
            image_idx=image_idx,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
        )

    def get_sliced_grid(
        self,
        image_size: ImageSize,
        # For MiniCPM V/O 2.6
        max_slice_nums: Optional[int] = None,
    ) -> Optional[tuple[int, int]]:
        """Return the slice grid for `image_size`, or None if unsliced.

        For 2.6, `max_slice_nums` falls back to the image processor's own
        `max_slice_nums` when not given.
        """
        image_processor = self.get_image_processor()
        if self.get_model_version() in ((2, 0), (2, 5)):
            return image_processor.get_sliced_grid(image_size)

        slice_limit = (image_processor.max_slice_nums
                       if max_slice_nums is None else max_slice_nums)
        return image_processor.get_sliced_grid(
            image_size,
            max_slice_nums=slice_limit,
        )

    def get_num_image_tokens(
        self,
        image_size: ImageSize,
        max_slice_nums: Optional[int] = None,
    ) -> int:
        """Return the feature-token count for one image of `image_size`."""
        image_processor = self.get_image_processor()
        grid = self.get_sliced_grid(image_size, max_slice_nums=max_slice_nums)

        if grid is None:
            num_slices = 0
        else:
            ncols, nrows = grid
            num_slices = ncols * nrows

        # The extra `+ 1` presumably accounts for the un-sliced source image.
        return (num_slices + 1) * image_processor.image_feature_size

    def get_max_image_tokens(self) -> int:
        return self.get_num_image_tokens(
            self.get_image_size_with_most_features())

    def get_image_max_slice_num(self) -> int:
        return getattr(self.get_hf_config(), "max_slice_num", 9)

    def get_image_size_with_most_features(self) -> ImageSize:
        side = getattr(self.get_hf_config(), "image_size", 448)
        return ImageSize(width=side,
                         height=side * self.get_image_max_slice_num())

    def get_max_video_frame_tokens(self) -> int:
        return self.get_num_image_tokens(
            self.get_video_frame_size_with_most_features(),
            max_slice_nums=self.get_video_max_slice_num(),
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return frames * self.get_max_video_frame_tokens()

    def get_video_max_slice_num(self) -> int:
        return 1

    def get_video_frame_size_with_most_features(self) -> ImageSize:
        side = getattr(self.get_hf_config(), "image_size", 448)
        return ImageSize(width=side,
                         height=side * self.get_video_max_slice_num())

    def get_max_video_frames(self, max_tokens: int) -> int:
        return max_tokens // self.get_max_video_frame_tokens()

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit in `seq_len` after the image budget.

        The result is capped at `_MAX_FRAMES_PER_VIDEO` and is at least 1.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_budget = self.get_max_image_tokens() * num_images
        total_frames = self.get_max_video_frames(seq_len - image_budget)

        per_video = min(total_frames // max(num_videos, 1),
                        _MAX_FRAMES_PER_VIDEO)
        return max(per_video, 1)
466
467


468
469
470
471
472
473
# Processing-info type parameter; defaults to the MiniCPM-V info class.
_I = TypeVar(
    "_I",
    bound=MiniCPMVProcessingInfo,
    default=MiniCPMVProcessingInfo,
)


class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and media items for MiniCPM-V profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one placeholder pattern per requested image and video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        return "".join([
            self.info.image_pattern * num_images,
            self.info.video_pattern * num_videos,
        ])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images and videos sized by the `*_most_features` helpers."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        (image_width,
         image_height) = self.info.get_image_size_with_most_features()
        (frame_width,
         frame_height) = self.info.get_video_frame_size_with_most_features()
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)

        dummy_images = self._get_dummy_images(width=image_width,
                                              height=image_height,
                                              num_images=num_images)
        # Every dummy video shares the same list of frames.
        dummy_frames = self._get_dummy_images(width=frame_width,
                                              height=frame_height,
                                              num_images=frames_per_video)

        return {
            "image": dummy_images,
            "video": [dummy_frames] * num_videos,
        }


512
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor for MiniCPM-V.

    The prompt text is tokenized as-is. Images and videos go through the HF
    processor separately. The `(<image>./</image>)` / `(<video>./</video>)`
    patterns are expanded afterwards by `_get_prompt_updates`, because
    `_hf_processor_applies_updates` reports that HF does not do it.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        return MiniCPMVMultiModalDataParser()

    def get_image_prompt_texts(self,
                               image_size: ImageSize,
                               image_idx: int = 0) -> str:
        return self.info.get_slice_image_placeholder(image_size,
                                                     image_idx=image_idx)

    def get_video_prompt_texts(self, image_size: ImageSize,
                               num_frames: int) -> str:
        # All frames share one placeholder built with the video slice limit
        # and without image ids.
        frame_placeholder = self.info.get_slice_image_placeholder(
            image_size=image_size,
            image_idx=0,
            max_slice_nums=self.info.get_video_max_slice_num(),
            use_image_id=False,
        )
        return frame_placeholder * num_frames

    def process_images(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the HF processor over `mm_data["images"]`, if present.

        Embedding inputs skip the HF processor. `image_token_id` (the
        `<unk>` id) is always added.
        """
        images = mm_data.get("images")
        if images is None:
            return {}

        parser = self._get_data_parser()
        parsed_images = parser.parse_mm_data({"image": images}).get_items(
            "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems))

        image_inputs = {}
        if not isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
            # Each image is paired with its own single-pattern prompt.
            image_inputs = self._base_call_hf_processor(
                prompts=[self.info.image_pattern] * len(parsed_images),
                mm_data={"images": [[image] for image in parsed_images]},
                mm_kwargs=mm_kwargs,
                out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
            )

        unk_token_id = self.info.get_tokenizer().get_vocab()["<unk>"]
        image_inputs["image_token_id"] = torch.tensor(unk_token_id)
        return image_inputs

    def process_videos(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the HF processor over `mm_data["videos"]`, if present.

        Frames are processed as images with the video slice limit. Output
        keys are prefixed with `video_`, and `video_token_id` (the `<unk>`
        id) is always added.
        """
        videos = mm_data.get("videos")
        if videos is None:
            return {}

        parser = self._get_data_parser()
        parsed_videos = parser.parse_mm_data({"video": videos}).get_items(
            "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))

        raw_inputs = {}
        if not isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
            # One image pattern per frame; each video is a list of frames.
            raw_inputs = self._base_call_hf_processor(
                prompts=[
                    self.info.image_pattern * len(video)
                    for video in parsed_videos
                ],
                mm_data={"images": list(parsed_videos)},
                mm_kwargs={
                    **mm_kwargs,
                    "max_slice_nums": self.info.get_video_max_slice_num(),
                },
                out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
            )

        video_inputs = {
            f"video_{key}": value
            for key, value in raw_inputs.items()
        }

        unk_token_id = self.info.get_tokenizer().get_vocab()["<unk>"]
        video_inputs["video_token_id"] = torch.tensor(unk_token_id)
        return video_inputs

    def process_mm_inputs(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        image_inputs = self.process_images(mm_data, mm_kwargs)
        video_inputs = self.process_videos(mm_data, mm_kwargs)
        return {**image_inputs, **video_inputs}

    def _base_call_hf_processor(
        self,
        prompts: list[str],
        mm_data: Mapping[str, Sequence[object]],
        mm_kwargs: Mapping[str, object],
        *,
        out_keys: set[str],
    ) -> dict[str, NestedTensors]:
        """Call the HF processor and keep only `out_keys`.

        Version 2.6 takes all prompts in a single call. Older versions are
        called once per prompt with the i-th entry of each `mm_data` value,
        and the single item produced for each key is collected into a list.
        """
        if self.info.get_model_version() == (2, 6):
            # This processor supports zipping prompt and mm_data together
            batched = super()._call_hf_processor(
                prompt=prompts,  # type: ignore
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
            )
            return {key: batched[key] for key in out_keys}

        collected = defaultdict[str, list[torch.Tensor]](list)
        for idx, prompt in enumerate(prompts):
            single = super()._call_hf_processor(
                prompt=prompt,
                mm_data={key: values[idx]
                         for key, values in mm_data.items()},
                mm_kwargs=mm_kwargs,
            )
            for key, value in single.items():
                assert len(value) == 1, (key, len(value))
                collected[key].append(value[0])

        return {key: collected[key] for key in out_keys}

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize `prompt` and attach the processed multimodal inputs."""
        tokenizer = self.info.get_tokenizer()
        input_ids = torch.tensor([tokenizer.encode(prompt)])
        return BatchFeature({
            "input_ids": input_ids,
            **self.process_mm_inputs(mm_data, mm_kwargs),
        })

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        # Placeholders are never expanded by `_call_hf_processor` above.
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/video pattern with its slice placeholders.

        Only the `<unk>` tokens of each replacement are selected as
        embedding positions.
        """

        def get_image_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems))
            image_size = images.get_image_size(item_idx)
            return PromptUpdateDetails.select_text(
                self.get_image_prompt_texts(image_size, item_idx),
                "<unk>",
            )

        def get_video_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))
            frame_size = videos.get_frame_size(item_idx)
            num_frames = videos.get_num_frames(item_idx)
            return PromptUpdateDetails.select_text(
                self.get_video_prompt_texts(frame_size, num_frames),
                "<unk>",
            )

        return [
            PromptReplacement(modality="image",
                              target=self.info.image_pattern,
                              replacement=get_image_replacement),
            PromptReplacement(modality="video",
                              target=self.info.video_pattern,
                              replacement=get_video_replacement),
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _minicpmv_field_config(hf_inputs)
721

722
723

class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    The abstract class of MiniCPMV can only be inherited, but cannot be
    instantiated.

    Subclasses supply the version-specific parts through ``init_llm``,
    ``init_vision_module``, ``init_resampler`` and
    ``get_vision_hidden_states`` (all of which raise NotImplementedError
    here). Input parsing, embedding merging, forward, logits, sampling and
    weight loading are shared.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings` but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
        # and config class
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "llm"))
        self.vpm = self.init_vision_module(config,
                                           quant_config,
                                           prefix=maybe_prefix(prefix, "vpm"))
        # The 2.0 vision tower exposes `embed_dim` directly; the other
        # versions expose it on their `embeddings` submodule.
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size

        self.resampler = self.init_resampler(self.embed_dim,
                                             self.vision_dim,
                                             quant_config=quant_config,
                                             prefix=maybe_prefix(
                                                 prefix, "resampler"))

        # Placeholder token ids collected while parsing multimodal inputs;
        # passed to `merge_multimodal_embeddings` in `get_input_embeddings`.
        self.mm_token_ids = set[int]()

        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the wrapped LLM if it has one, else `get_sampler()`."""
        if hasattr(self.llm, "sampler"):
            return self.llm.sampler

        return get_sampler()

    def _parse_and_validate_vision_input(
        self,
        modality: str,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        """Parse image-style inputs (pixel values or precomputed embeddings).

        Reads ``pixel_values``, ``image_embeds``, ``image_token_id`` and
        ``tgt_sizes`` from ``kwargs``. Video inputs reach this method after
        their keys have been renamed by
        ``_parse_and_validate_multimodal_inputs``. ``modality`` is only used
        in error messages.

        Returns None when neither pixel values nor embeddings are present.
        Raises ValueError on unexpected input types or when the flattened
        pixel values and target sizes differ in length.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        image_token_id = kwargs.pop("image_token_id", None)
        if image_token_id is not None:
            assert isinstance(image_token_id, torch.Tensor)
            # Record every distinct placeholder id; unlike `.item()`,
            # `.tolist()` does not fail when more than one id is present.
            self.mm_token_ids.update(
                image_token_id.flatten().unique().tolist())

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image_embeds for {modality=}. "
                    f"Got type: {type(image_embeds)}")

            image_embeds_flat = flatten_bn(image_embeds)

            return MiniCPMVImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds_flat,
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel_values for {modality=}. "
                f"Got type: {type(pixel_values)}")

        tgt_sizes = kwargs.pop("tgt_sizes")
        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
                             f"Got type: {type(tgt_sizes)}")

        # num_slices[b][i] == len(pixel_values[b][i]); used later to split
        # the flat vision features back per item.
        num_slices = [[len(p) for p in ps] for ps in pixel_values]
        num_slices_flat = flatten_bn(torch.tensor(num_slices))

        pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
        tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)

        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        return MiniCPMVImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values_flat,
            tgt_sizes=tgt_sizes_flat,
            num_slices=num_slices_flat,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse the image and/or video inputs present in ``kwargs``.

        Returns a dict keyed by ``"images"`` / ``"videos"``, holding only
        the modalities that actually carry data. Insertion order follows
        the first occurrence of each modality's key in ``kwargs``.
        """
        modalities = {}

        # Videos reuse the image parser. Only `video_`-prefixed keys are
        # forwarded, renamed to their image counterparts. Forwarding the
        # un-prefixed image keys as well would let image data leak into
        # (or, depending on kwargs order, overwrite) the video inputs.
        def _image_key(video_key: str) -> str:
            if video_key == "video_token_id":
                return "image_token_id"
            if video_key == "video_embeds":
                # Plain prefix stripping would give "embeds", which the
                # image parser never reads.
                return "image_embeds"

            return video_key.removeprefix("video_")

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "images" not in modalities):
                image_input = self._parse_and_validate_vision_input(
                    "images", **kwargs)
                if image_input is not None:
                    modalities["images"] = image_input
            if (input_key in ("video_pixel_values", "video_embeds")
                    and "videos" not in modalities):
                video_kwargs = {
                    _image_key(k): v
                    for k, v in kwargs.items() if k.startswith("video_")
                }
                video_input = self._parse_and_validate_vision_input(
                    "videos", **video_kwargs)
                if video_input is not None:
                    modalities["videos"] = video_input

        return modalities

    def _process_vision_input(
        self,
        image_input: MiniCPMVImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Turn parsed vision inputs into per-item embeddings.

        Precomputed embeddings are returned unchanged. For pixel inputs, the
        vision features are split using ``num_slices``, and each item's
        chunk is flattened over its first two dimensions.
        """
        if image_input["type"] == "image_embeds":
            return image_input["image_embeds"]

        image_features_flat = self.get_vision_hidden_states(image_input)

        num_slices = image_input["num_slices"]
        return [
            e.flatten(0, 1)
            for e in image_features_flat.split(num_slices.tolist())
        ]

    def _process_multimodal_inputs(self, modalities: dict):
        """Concatenate the per-item embeddings of all modalities, in order."""
        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_features = self._process_vision_input(image_input)
                multimodal_embeddings += tuple(image_features)
            if modality == "videos":
                video_input = modalities["videos"]
                video_features = self._process_vision_input(video_input)
                multimodal_embeddings += tuple(video_features)

        return multimodal_embeddings

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.llm

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Embed the multimodal inputs in ``kwargs``; None if there are none."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        return self._process_multimodal_inputs(modalities)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        The merge uses the placeholder ids recorded in ``mm_token_ids``.
        """
        inputs_embeds = self.llm.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert len(self.mm_token_ids) > 0
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                list(self.mm_token_ids),
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model backbone and return its hidden states.

        With ``intermediate_tensors`` (non-first pipeline stage),
        ``inputs_embeds`` is discarded. Otherwise, a missing
        ``inputs_embeds`` is built here from ``input_ids`` and the
        multimodal ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.llm.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(language_model="llm",
                                                connector="resampler",
                                                tower_model="vpm")

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the language model; implemented by version subclasses."""
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the vision tower; implemented by version subclasses."""
        raise NotImplementedError

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build the resampler; implemented by version subclasses."""
        raise NotImplementedError

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs; implemented by version subclasses."""
        raise NotImplementedError


class MiniCPMV2_0(MiniCPMVBaseModel):
    """MiniCPM-V 2.0: MiniCPM LLM, timm SigLIP ViT tower and Resampler2."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 0)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Return the MiniCPM causal LM backbone."""
        return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the timm vision tower (requires the optional `timm` package).

        Raises ImportError if `timm` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception so the real import failure
            # (e.g. a broken dependency of timm) stays visible.
            raise ImportError("Please install timm==0.9.10") from err

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        model = model.to(dtype=torch.get_default_dtype())

        # The attention-pooling head is not used; the resampler consumes the
        # per-patch features instead.
        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build a Resampler2 with a sqrt(query_num) x sqrt(query_num) grid.

        It is constructed under a float16 default dtype, then moved to the
        platform device and the current default dtype.
        """
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(embed_dim=embed_dim,
                                   num_heads=embed_dim // 128,
                                   grid_size=int(
                                       math.sqrt(self.config.query_num)),
                                   kv_dim=vision_dim,
                                   adaptive=False,
                                   do_post_projection=True,
                                   quant_config=quant_config,
                                   prefix=prefix)

        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Encode each pixel slice on its own, resample, and stack results.

        Each slice's target grid is ceil(H / patch_h) x ceil(W / patch_w),
        computed from its own height and width. Prefix tokens (if the tower
        has any) are dropped before resampling.
        """
        pixel_values = data["pixel_values"]

        P_h, P_w = self.vpm.patch_embed.patch_size
        dtype: torch.dtype = self.vpm.pos_embed.data.dtype
        num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)

        res = list[torch.Tensor]()

        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))

            if num_prefix_tokens > 0:
                vision_embedding = vision_embedding[:, num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))

        return torch.vstack(res)


class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.5: Llama LLM, Idefics2 vision tower and Resampler2_5."""

    # Fused module name -> names of the projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 5)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Return the Llama causal LM backbone."""
        return LlamaForCausalLM(prefix=prefix, vllm_config=vllm_config)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the Idefics2 vision tower, optionally minus its last layer."""
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if self.config.drop_vision_last_layer:
            vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build a Resampler2_5 in float16, then move and cast it.

        The target device is the platform device type, and the target dtype
        is the default dtype in effect after the float16 scope ends.
        """
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(num_queries=self.config.query_num,
                                     embed_dim=embed_dim,
                                     num_heads=embed_dim // 128,
                                     kv_dim=vision_dim,
                                     quant_config=quant_config,
                                     prefix=prefix)

        target_dtype = torch.get_default_dtype()
        return resampler.to(device=current_platform.device_type,
                            dtype=target_dtype)

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Run all slices through the vision tower in one batch, then resample.

        Slices are zero-padded along their last dimension to the longest
        one. A boolean mask marks the first ``prod(tgt_sizes[i])`` patch
        positions of slice ``i`` as valid.
        """
        slices = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        ref = slices[0]
        num_slices = len(slices)
        p_dim = ref.shape[-2]
        max_len = max(s.shape[-1] for s in slices)

        padded = torch.zeros((num_slices, 3, p_dim, max_len),
                             dtype=ref.dtype,
                             device=ref.device)
        for row, s in enumerate(slices):
            padded[row, ..., :s.shape[-1]] = s

        patch_counts = tgt_sizes.prod(-1)
        max_patches = patch_counts.max().item()
        assert isinstance(max_patches, int)

        patch_attn_mask = torch.zeros((num_slices, max_patches),
                                      dtype=torch.bool,
                                      device=ref.device)
        for row, count in enumerate(patch_counts):
            patch_attn_mask[row, :count] = True

        vision_embedding = self.vpm(
            padded,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=None,
        )
        return self.resampler(vision_embedding, tgt_sizes)


class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.6: Qwen2 LLM, Idefics2 vision tower and Resampler2_5."""

    # Fused module name -> names of the projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 6)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Return the Qwen2 causal LM backbone."""
        return Qwen2ForCausalLM(prefix=prefix, vllm_config=vllm_config)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the Idefics2 vision tower, optionally minus its last layer."""
        tower = Idefics2VisionTransformer(config.vision_config,
                                          quant_config=quant_config,
                                          prefix=prefix)
        if self.config.drop_vision_last_layer:
            tower.encoder.layers = tower.encoder.layers[:-1]
        return tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build a Resampler2_5 in float16, then move and cast it."""
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(num_queries=self.config.query_num,
                                     embed_dim=embed_dim,
                                     num_heads=embed_dim // 128,
                                     kv_dim=vision_dim,
                                     quant_config=quant_config,
                                     prefix=prefix)

        target_dtype = torch.get_default_dtype()
        return resampler.to(device=current_platform.device_type,
                            dtype=target_dtype)

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Batch all slices through the vision tower, then resample.

        Same padding/masking scheme as 2.5, except that ``tgt_sizes`` is
        also forwarded to the vision tower.
        """
        slice_list = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        head = slice_list[0]
        device, dtype = head.device, head.dtype
        batch = len(slice_list)
        p_dim = head.shape[-2]
        max_len = max(item.shape[-1] for item in slice_list)

        # Zero-pad every slice on the right to the longest last dimension.
        padded = torch.zeros((batch, 3, p_dim, max_len),
                             dtype=dtype,
                             device=device)
        for k, item in enumerate(slice_list):
            padded[k, ..., :item.shape[-1]] = item

        counts = tgt_sizes.prod(-1)
        max_patches = counts.max().item()
        assert isinstance(max_patches, int)

        # True for the first `counts[k]` patch positions of slice k.
        patch_attn_mask = torch.zeros((batch, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for k, count in enumerate(counts):
            patch_attn_mask[k, :count] = True

        vision_embedding = self.vpm(
            padded,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)


# Implementation class for each supported MiniCPM-V version tuple;
# `MiniCPMV.__new__` dispatches on this mapping.
_SUPPORT_VERSION = {
    (2, 0): MiniCPMV2_0,
    (2, 5): MiniCPMV2_5,
    (2, 6): MiniCPMV2_6,
}


@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMVMultiModalProcessor,
    info=MiniCPMVProcessingInfo,
    dummy_inputs=MiniCPMVDummyInputsBuilder)
class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
    """
    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in vLLM. Therefore, it is necessary to separate them.

    Calling ``MiniCPMV(...)`` never yields an instance of this class:
    ``__new__`` looks up the version-specific subclass in
    ``_SUPPORT_VERSION`` and returns an instance of it instead.
    """

    # Class-level metadata owned by this class. Without these definitions,
    # the attributes would resolve to a base class, and the in-place
    # updates in `__new__` would mutate defaults shared with other models.
    # They start empty and are filled from the selected version class.
    packed_modules_mapping = {}
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the subclass matching the config's MiniCPM-V version.

        Raises ValueError for unsupported versions.
        """
        config = vllm_config.model_config.hf_config
        # Use the same detection as `MiniCPMVBaseModel.__init__`, so this
        # dispatch always agrees with the selected subclass's version assert.
        version = get_version_by_config(config)
        # Dispatch class based on version
        instance_cls = _SUPPORT_VERSION.get(version)
        if instance_cls is None:
            raise ValueError(
                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")

        # quant_config references base class members, so update values
        # before init is called. The updates must stay in place (the objects
        # are shared by reference). MiniCPMV2_0 does not derive from
        # SupportsLoRA and defines none of these attributes, hence the empty
        # defaults.
        cls.packed_modules_mapping.update(
            getattr(instance_cls, "packed_modules_mapping", {}))
        cls.embedding_modules.update(
            getattr(instance_cls, "embedding_modules", {}))
        padding_modules = getattr(instance_cls, "embedding_padding_modules",
                                  [])
        # Skip entries already present so repeated construction does not
        # accumulate duplicates.
        for module_name in padding_modules:
            if module_name not in cls.embedding_padding_modules:
                cls.embedding_padding_modules.append(module_name)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)