# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Alphi's avatar
Alphi committed
25
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
26
import math
27
from collections import defaultdict
28
from collections.abc import Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Any, Callable, Literal, Optional, TypedDict, Union
31

32
import numpy as np
33
import torch
Alphi's avatar
Alphi committed
34
import torch.types
35
from torch import nn
36
from transformers import BatchFeature, PretrainedConfig
37
from typing_extensions import TypeVar
38

39
from vllm.config import VllmConfig
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
42
                                                  get_2d_sincos_pos_embed)
Jee Jee Li's avatar
Jee Jee Li committed
43
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
44
45
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
46
from vllm.model_executor.models.module_mapping import MultiModelKeys
47
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
50
51
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    NestedTensors)
52
53
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
                                   ImageProcessorItems, ImageSize,
54
55
                                   ModalityData, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser,
56
                                   VideoItem, VideoProcessorItems)
57
from vllm.multimodal.processing import (BaseMultiModalProcessor,
58
                                        BaseProcessingInfo, PromptReplacement,
59
                                        PromptUpdate, PromptUpdateDetails)
60
from vllm.multimodal.profiling import BaseDummyInputsBuilder
61
from vllm.platforms import current_platform
62
from vllm.sequence import IntermediateTensors
63
from vllm.utils import flatten_2d_lists
64

Jee Jee Li's avatar
Jee Jee Li committed
65
from .idefics2_vision_model import Idefics2VisionTransformer
66
67
68
69
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
70

71
72
73
# For profile run: upper bound applied to the per-video frame count computed
# in `MiniCPMVProcessingInfo.get_num_frames_with_most_features`.
_MAX_FRAMES_PER_VIDEO = 16

74

Jee Jee Li's avatar
Jee Jee Li committed
75
class MiniCPMVImagePixelInputs(TypedDict):
    """Image inputs given as raw pixel values (`type == "pixel_values"`)."""

    type: Literal["pixel_values"]
    pixel_values: list[torch.Tensor]
    """
    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images * num_slices, 2)`

    This should be in `(height, width)` format.
    """

    num_slices: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

Jee Jee Li's avatar
Jee Jee Li committed
95

96
97
class MiniCPMVImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (`type == "image_embeds"`)."""

    type: Literal["image_embeds"]
    image_embeds: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_slices, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.

    The embeddings may be passed as a list of tensors instead of a single
    batched tensor.
    """


# Image inputs are either raw pixel values or precomputed embeddings.
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                            MiniCPMVImageEmbeddingInputs]

Jee Jee Li's avatar
Jee Jee Li committed
110
111
112
113
114
# LayerNorm factory with `eps=1e-6`; used as the default `norm_layer` of
# `Resampler2_5`.
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class Resampler2_5(BaseResampler):
    """Resampler that adds a cached 2D sin-cos position embedding to the keys.

    The position table is built for `max_size` (height, width) patches and is
    regrown on demand when `forward` receives a larger target size.
    """

    def __init__(self,
                 num_queries: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 max_size: tuple[int, int] = (70, 70),
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Build the base resampler and the initial position-embedding cache.

        Args:
            num_queries: Forwarded to `BaseResampler`.
            embed_dim: Forwarded to `BaseResampler`.
            num_heads: Forwarded to `BaseResampler`.
            kv_dim: Forwarded to `BaseResampler`.
            norm_layer: Forwarded to `BaseResampler`.
            max_size: Initial `(height, width)` extent of the position cache.
            quant_config: Forwarded to `BaseResampler`.
            prefix: Forwarded to `BaseResampler`.
        """
        super().__init__(num_queries,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         quant_config=quant_config,
                         prefix=prefix)

        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self,
                          max_size: tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        """(Re)build the `pos_embed` buffer for `max_size` on `device`."""
        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                max_size,
                                                version=(2, 5))
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        # Non-persistent: the table is recomputed, never loaded from weights.
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        """Grow the position cache if any target size exceeds `max_size`."""
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)

        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            # Only ever grow; keep the larger extent along each axis.
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Resample `x` into a fixed number of query tokens.

        Args:
            x: Input features; its first dimension must equal the first
                dimension of `tgt_sizes`.
            tgt_sizes: Per-item `(height, width)` patch counts, read as
                `tgt_sizes[:, 0]` and `tgt_sizes[:, 1]`.

        Returns:
            The attended queries after `ln_post` and the `proj` matmul
            (B * Q * D before the projection).
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        self._adjust_pos_cache(tgt_sizes, device=device)

        # Convert the sizes to Python ints once instead of once per item
        # inside the loop below.
        sizes = tgt_sizes.tolist()
        patch_lens = [tgt_h * tgt_w for tgt_h, tgt_w in sizes]
        max_patch_len = max(patch_lens)

        key_padding_mask = torch.zeros((bs, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)

        pos_embed = []
        for i, (tgt_h, tgt_w) in enumerate(sizes):
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
                (tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            # Positions past this item's own patches are padding.
            key_padding_mask[i, patch_lens[i]:] = True
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed,
            batch_first=True,
            padding_value=0.0,
        ).permute(1, 0, 2)  # BLD => L * B * D

        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D +  L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        #  out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x


206
def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
    """Return the MiniCPM-V version of `config` as a tuple of ints.

    If `config.version` is present it is stringified and split on ".", so
    `2.6` becomes `(2, 6)`. If it is missing (or `None`), the version is
    guessed from `hidden_size` and `query_num`: `(2, 0)` for 2304 / 64,
    otherwise `(2, 5)`.

    Raises:
        ValueError: If a component of the version string is not an integer.
    """
    version_float = getattr(config, "version", None)

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    if version_float is None:
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)

    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))


219
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Build the multimodal field configuration for MiniCPM-V inputs.

    Image fields are registered under the "image" modality and video fields
    (prefixed with `video_`) under "video". The token-id fields are declared
    with `MultiModalFieldConfig.shared`, sized by the length of the
    corresponding pixel-value entry in `hf_inputs`.

    NOTE(review): the counts come from `pixel_values` /
    `video_pixel_values` only, so they are 0 when those keys are absent
    (e.g. presumably when only embeddings are supplied) -- confirm this is
    intended.
    """
    # A missing key falls back to an empty tensor, i.e. a count of 0.
    pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
    num_images = len(pixel_values)

    video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
    num_videos = len(video_pixel_values)

    return dict(
        pixel_values=MultiModalFieldConfig.batched("image"),
        image_sizes=MultiModalFieldConfig.batched("image"),
        tgt_sizes=MultiModalFieldConfig.batched("image"),
        image_embeds=MultiModalFieldConfig.batched("image"),
        video_pixel_values=MultiModalFieldConfig.batched("video"),
        video_image_sizes=MultiModalFieldConfig.batched("video"),
        video_tgt_sizes=MultiModalFieldConfig.batched("video"),
        video_embeds=MultiModalFieldConfig.batched("video"),
        image_token_id=MultiModalFieldConfig.shared("image", num_images),
        video_token_id=MultiModalFieldConfig.shared("video", num_videos),
    )


class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
    """Image items supplied as a dict of precomputed embeddings.

    The dict must provide `image_embeds` and `image_sizes`.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """Register `data` as "image" modality embedding items."""
        super().__init__(
            data,
            modality="image",
            required_fields={"image_embeds", "image_sizes"},
            fields_factory=fields_factory,
        )

    def get_image_size(self, index: int) -> ImageSize:
        """Return the size of item `index`, read from `image_sizes`.

        Element 0 is used as the width and element 1 as the height.
        """
        image_size = self.get(index)["image_sizes"].tolist()
        return ImageSize(width=image_size[0], height=image_size[1])


class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
    """Video items supplied as a dict of precomputed embeddings.

    The dict must provide `video_embeds` and `video_image_sizes`.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """Register `data` as "video" modality embedding items."""
        super().__init__(
            data,
            modality="video",
            required_fields={"video_embeds", "video_image_sizes"},
            fields_factory=fields_factory,
        )

    def get_frame_size(self, index: int) -> ImageSize:
        """Return the frame size of item `index` from `video_image_sizes`.

        Element 0 is used as the width and element 1 as the height.
        """
        # NOTE(review): `get_num_frames` treats the first dimension of
        # `video_image_sizes` as the frame axis, while this method reads
        # elements 0 and 1 as width and height -- confirm the layout.
        frame_size = self.get(index)["video_image_sizes"].tolist()
        return ImageSize(width=frame_size[0], height=frame_size[1])

    def get_num_frames(self, index: int) -> int:
        """Return the length of `video_image_sizes` for item `index`."""
        return len(self.get(index)["video_image_sizes"])


287
288
289
290
291
class MiniCPMVMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of precomputed embeddings.

    A `dict` input is wrapped in the MiniCPM-V embedding item classes; any
    other input is handled by the base parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse image data, treating a dict as embedding inputs."""
        if isinstance(data, dict):
            return MiniCPMVImageEmbeddingItems(
                data,
                fields_factory=_minicpmv_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse video data, treating a dict as embedding inputs."""
        if isinstance(data, dict):
            return MiniCPMVVideoEmbeddingItems(
                data,
                fields_factory=_minicpmv_field_config,
            )

        return super()._parse_video_data(data)


class MiniCPMVProcessingInfo(BaseProcessingInfo):
    """Processing metadata for MiniCPM-V: placeholders and token budgets."""

    # Prompt placeholders that are replaced by the expanded slice text.
    image_pattern = "(<image>./</image>)"
    video_pattern = "(<video>./</video>)"

    def get_hf_config(self):
        """Return the HF config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor with list-valued `mean` / `std`."""
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NumPy arrays are considered as Iterable but not Sequence in
        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
        image_processor = hf_processor.image_processor  # type: ignore
        for attr in ("mean", "std"):
            val = getattr(image_processor, attr)
            if isinstance(val, np.ndarray):
                setattr(image_processor, attr, val.tolist())

        return hf_processor

    def get_image_processor(self):
        """Return the image processor attached to the HF processor."""
        hf_processor = self.get_hf_processor()
        image_processor = hf_processor.image_processor  # type: ignore
        return image_processor

    def get_model_version(self):
        """Return the model version tuple derived from the HF config."""
        return get_version_by_config(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return per-modality limits; "video" is added only for (2, 6)."""
        mm_limits = {"image": None}
        if self.get_model_version() == (2, 6):
            mm_limits["video"] = None

        return mm_limits

    def get_slice_image_placeholder(
        self,
        image_size: ImageSize,
        # For MiniCPM V/O 2.6
        image_idx: int = 0,
        max_slice_nums: Optional[int] = None,
        use_image_id: bool = True,
    ) -> str:
        """Return the placeholder text the image processor emits for an image.

        For versions (2, 0) and (2, 5) only `image_size` is forwarded; the
        remaining arguments are passed through for other versions.
        """
        image_processor = self.get_image_processor()
        version = self.get_model_version()

        if version in ((2, 0), (2, 5)):
            return image_processor.get_slice_image_placeholder(image_size)

        return image_processor.get_slice_image_placeholder(
            image_size,
            image_idx=image_idx,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
        )

    def get_sliced_grid(
        self,
        image_size: ImageSize,
        # For MiniCPM V/O 2.6
        max_slice_nums: Optional[int] = None,
    ) -> Optional[tuple[int, int]]:
        """Return the slice grid the image processor computes for an image.

        For versions (2, 0) and (2, 5) only `image_size` is forwarded. For
        other versions, a missing `max_slice_nums` defaults to the image
        processor's own `max_slice_nums`.
        """
        image_processor = self.get_image_processor()
        version = self.get_model_version()

        if version in ((2, 0), (2, 5)):
            return image_processor.get_sliced_grid(image_size)

        if max_slice_nums is None:
            max_slice_nums = image_processor.max_slice_nums

        return image_processor.get_sliced_grid(
            image_size,
            max_slice_nums=max_slice_nums,
        )

    def get_num_image_tokens(
        self,
        image_size: ImageSize,
        max_slice_nums: Optional[int] = None,
    ) -> int:
        """Return `(grid cells + 1) * image_feature_size` for an image.

        A `None` grid counts as zero cells.
        """
        image_processor = self.get_image_processor()

        grid = self.get_sliced_grid(
            image_size,
            max_slice_nums=max_slice_nums,
        )
        if grid is None:
            ncols = nrows = 0
        else:
            ncols, nrows = grid

        return (ncols * nrows + 1) * image_processor.image_feature_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the size with the most features."""
        image_size = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_size)

    def get_image_max_slice_num(self) -> int:
        """Return `max_slice_num` from the HF config (default 9)."""
        return getattr(self.get_hf_config(), "max_slice_num", 9)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an `image_size` wide, `image_size * max_slice_num` tall size.

        `image_size` is read from the HF config and defaults to 448.
        """
        image_size = getattr(self.get_hf_config(), "image_size", 448)
        max_slice_num = self.get_image_max_slice_num()
        return ImageSize(width=image_size, height=image_size * max_slice_num)

    def get_max_video_frame_tokens(self) -> int:
        """Return the token count of one largest-size video frame."""
        frame_size = self.get_video_frame_size_with_most_features()

        return self.get_num_image_tokens(
            frame_size,
            max_slice_nums=self.get_video_max_slice_num(),
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return per-frame tokens times the frame count for one video."""
        num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        num_video_tokens_total = self.get_max_video_frame_tokens() * num_frames
        return num_video_tokens_total

    def get_video_max_slice_num(self) -> int:
        """Return the slice limit used for video frames (always 1)."""
        return 1

    def get_video_frame_size_with_most_features(self) -> ImageSize:
        """Return the frame size built from `image_size` and the video limit.

        `image_size` is read from the HF config and defaults to 448.
        """
        image_size = getattr(self.get_hf_config(), "image_size", 448)
        max_slice_num = self.get_video_max_slice_num()
        return ImageSize(width=image_size, height=image_size * max_slice_num)

    def get_max_video_frames(self, max_tokens: int) -> int:
        """Return how many largest-size frames fit in `max_tokens`."""
        num_frame_tokens = self.get_max_video_frame_tokens()
        num_frames = max_tokens // num_frame_tokens
        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count used for the largest input.

        The tokens left after the maximum image tokens are split across the
        requested videos, capped at `_MAX_FRAMES_PER_VIDEO`, and never less
        than 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self.get_max_video_frames(seq_len -
                                                     max_image_tokens)
        # `max(max_videos, 1)` avoids dividing by zero when no video is used.
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)
465
466


467
468
469
470
471
472
# Type variable for the processing-info class; bounded by, and defaulting
# to, `MiniCPMVProcessingInfo`.
_I = TypeVar("_I",
             bound=MiniCPMVProcessingInfo,
             default=MiniCPMVProcessingInfo)


class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and multimodal data for MiniCPM-V."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder per requested item.

        All image placeholders come first, followed by the video ones.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_prompt_texts = self.info.image_pattern * num_images
        video_prompt_texts = self.info.video_pattern * num_videos

        return image_prompt_texts + video_prompt_texts

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the sizes with most features.

        Each dummy video is a list of dummy frames whose count comes from
        `get_num_frames_with_most_features`.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_width, image_height = \
            self.info.get_image_size_with_most_features()
        video_width, video_height = \
            self.info.get_video_frame_size_with_most_features()
        num_video_frames = \
            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        return {
            "image":
            self._get_dummy_images(width=image_width,
                                   height=image_height,
                                   num_images=num_images),
            # The same frame list object is repeated for every video.
            "video": [
                self._get_dummy_images(width=video_width,
                                       height=video_height,
                                       num_images=num_video_frames)
            ] * num_videos,
        }


511
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor for MiniCPM-V image and video inputs.

    The prompt is tokenized here, images and videos are run through the HF
    processor separately, and the placeholder patterns are expanded by the
    prompt updates returned from `_get_prompt_updates`.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return the parser that also accepts embedding dicts."""
        return MiniCPMVMultiModalDataParser()

    def get_image_prompt_texts(self,
                               image_size: ImageSize,
                               image_idx: int = 0) -> str:
        """Return the expanded placeholder text for one image."""
        return self.info.get_slice_image_placeholder(
            image_size,
            image_idx=image_idx,
        )

    def get_video_prompt_texts(self, image_size: ImageSize,
                               num_frames: int) -> str:
        """Return the per-frame placeholder text repeated `num_frames` times.

        Frames use the video slice limit and no image id.
        """
        return self.info.get_slice_image_placeholder(
            image_size=image_size,
            image_idx=0,
            max_slice_nums=self.info.get_video_max_slice_num(),
            use_image_id=False,
        ) * num_frames

    def process_images(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Process `mm_data["images"]` into model inputs.

        Returns an empty mapping when there are no images. Embedding items
        skip the HF processor; in both cases `image_token_id` is set to the
        id of the `<unk>` token.
        """
        if (images := mm_data.get("images")) is None:
            return {}

        parsed_images = (self._get_data_parser().parse_mm_data({
            "image": images
        }).get_items("image",
                     (MiniCPMVImageEmbeddingItems, ImageProcessorItems)))

        if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
            image_inputs = {}
        else:
            # One prompt per image, each paired with a single-image list.
            image_inputs = self._base_call_hf_processor(
                prompts=[self.info.image_pattern] * len(parsed_images),
                mm_data={"images": [[image] for image in parsed_images]},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
                out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
            )

        tokenizer = self.info.get_tokenizer()
        unk_token_id = tokenizer.get_vocab()["<unk>"]
        image_inputs["image_token_id"] = torch.tensor(unk_token_id)

        return image_inputs

    def process_videos(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Process `mm_data["videos"]` into model inputs.

        Returns an empty mapping when there are no videos. Frames are fed to
        the HF processor as images with the video slice limit, and the
        resulting keys are prefixed with `video_`. `video_token_id` is set
        to the id of the `<unk>` token.
        """
        if (videos := mm_data.get("videos")) is None:
            return {}

        parsed_videos = (self._get_data_parser().parse_mm_data({
            "video": videos
        }).get_items("video",
                     (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))

        if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
            video_inputs = {}
        else:
            # One prompt per video, with one image placeholder per frame.
            video_inputs = self._base_call_hf_processor(
                prompts=[
                    self.info.image_pattern * len(video)
                    for video in parsed_videos
                ],
                mm_data={"images": list(parsed_videos)},
                mm_kwargs={
                    **mm_kwargs,
                    "max_slice_nums":
                    self.info.get_video_max_slice_num(),
                },
                tok_kwargs=tok_kwargs,
                out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
            )

        video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}

        tokenizer = self.info.get_tokenizer()
        unk_token_id = tokenizer.get_vocab()["<unk>"]
        video_inputs["video_token_id"] = torch.tensor(unk_token_id)

        return video_inputs

    def process_mm_inputs(
        self,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Return the merged image and video inputs."""
        return {
            **self.process_images(mm_data, mm_kwargs, tok_kwargs),
            **self.process_videos(mm_data, mm_kwargs, tok_kwargs),
        }

    def _base_call_hf_processor(
        self,
        prompts: list[str],
        mm_data: Mapping[str, Sequence[object]],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
        *,
        out_keys: set[str],
    ) -> dict[str, NestedTensors]:
        """Run the HF processor over `prompts` and keep only `out_keys`.

        For version (2, 6) all prompts are passed in a single call. For
        other versions the processor is called once per prompt with the
        matching `mm_data` entry, and the single result of each call is
        collected into a list per key.
        """
        # This processor supports zipping prompt and mm_data together
        if self.info.get_model_version() == (2, 6):
            inputs = super()._call_hf_processor(
                prompt=prompts,  # type: ignore
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
        else:
            inputs = defaultdict[str, list[torch.Tensor]](list)

            for i, prompt in enumerate(prompts):
                inputs_one = super()._call_hf_processor(
                    prompt=prompt,
                    mm_data={
                        k: v[i]
                        for k, v in mm_data.items()
                    },
                    mm_kwargs=mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )

                for k, v in inputs_one.items():
                    assert len(v) == 1, (k, len(v))
                    inputs[k].append(v[0])

        return {k: inputs[k] for k in out_keys}

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize `prompt` and attach the processed multimodal inputs."""
        tokenizer = self.info.get_tokenizer()

        input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])
        mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs)

        return BatchFeature({
            "input_ids": input_ids,
            **mm_inputs,
        })

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always return `False`."""
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return replacements for the image and video placeholder patterns.

        Each replacement is the expanded slice text for the item, with
        `"<unk>"` passed to `PromptUpdateDetails.select_text`.
        """
        placeholder = {
            "image": self.info.image_pattern,
            "video": self.info.video_pattern,
        }

        def get_image_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems))

            image_size = images.get_image_size(item_idx)

            return PromptUpdateDetails.select_text(
                self.get_image_prompt_texts(image_size, item_idx),
                "<unk>",
            )

        def get_video_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))

            frame_size = videos.get_frame_size(item_idx)
            num_frames = videos.get_num_frames(item_idx)

            return PromptUpdateDetails.select_text(
                self.get_video_prompt_texts(frame_size, num_frames),
                "<unk>",
            )

        get_replacement = {
            "image": get_image_replacement,
            "video": get_video_replacement,
        }

        return [
            PromptReplacement(modality=modality,
                              target=placeholder[modality],
                              replacement=get_replacement[modality])
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configuration for `hf_inputs`."""
        return _minicpmv_field_config(hf_inputs)
730

731
732

class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    The abstract class of MiniCPMV can only be inherited, but cannot be
    instantiated.

    Subclasses must implement `init_llm`, `init_vision_module`,
    `init_resampler` and `get_vision_hidden_states`; the versions in this
    class raise `NotImplementedError`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model, vision module and resampler.

        Args:
            vllm_config: Source of the HF config, multimodal config and
                quantization config.
            prefix: Prefix joined (via `maybe_prefix`) with "llm", "vpm"
                and "resampler" for the respective submodules.
        """
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings` but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
        # and config class
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "llm"))
        self.vpm = self.init_vision_module(config,
                                           quant_config,
                                           prefix=maybe_prefix(prefix, "vpm"))
        # The 2.0 vision module exposes `embed_dim` directly; the other
        # versions expose it on their `embeddings` submodule.
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size

        self.resampler = self.init_resampler(self.embed_dim,
                                             self.vision_dim,
                                             quant_config=quant_config,
                                             prefix=maybe_prefix(
                                                 prefix, "resampler"))

        # Placeholder token ids seen so far; filled while parsing inputs and
        # consumed by `get_input_embeddings`.
        self.mm_token_ids = set[int]()
        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

    def _parse_and_validate_vision_input(
        self,
        modality: str,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        """Validate image-style keyword inputs and flatten them.

        Reads `pixel_values`, `image_embeds`, `image_token_id` and
        `tgt_sizes` from `kwargs`. Video inputs are expected to have been
        renamed to these keys by the caller.

        Args:
            modality: Name used in error messages only.

        Returns:
            `None` if neither `pixel_values` nor `image_embeds` is given,
            otherwise an embedding-input or pixel-input dict.

        Raises:
            ValueError: If an input has an unexpected type (including a
                missing `tgt_sizes`), or if the flattened `pixel_values`
                and `tgt_sizes` differ in length.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        # A missing token id is treated like an explicit None instead of
        # raising KeyError.
        image_token_id = kwargs.pop("image_token_id", None)
        if image_token_id is not None:
            assert isinstance(image_token_id, torch.Tensor)
            # Record every distinct id; `.item()` would raise if the batch
            # contained more than one.
            unique_token_ids = image_token_id.flatten().unique()
            self.mm_token_ids.update(unique_token_ids.tolist())

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image_embeds for {modality=}. "
                    f"Got type: {type(image_embeds)}")

            image_embeds_flat = flatten_bn(image_embeds)

            return MiniCPMVImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds_flat,
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel_values for {modality=}. "
                f"Got type: {type(pixel_values)}")

        # Default to None so that a missing key is reported through the
        # ValueError below rather than as a bare KeyError.
        tgt_sizes = kwargs.pop("tgt_sizes", None)
        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
                             f"Got type: {type(tgt_sizes)}")

        # Number of slices per item, kept so that the flattened features
        # can be split back per item in `_process_vision_input`.
        num_slices = [[len(p) for p in ps] for ps in pixel_values]
        num_slices_flat = flatten_bn(torch.tensor(num_slices))

        pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
        tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)

        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        return MiniCPMVImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values_flat,
            tgt_sizes=tgt_sizes_flat,
            num_slices=num_slices_flat,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs out of `kwargs`.

        Returns:
            A dict with an "images" and/or "videos" entry, inserted in the
            order in which the corresponding keys first appear in `kwargs`.
        """
        modalities = {}

        def _image_key(video_key: str) -> str:
            # The shared parser reads `image_token_id` and `image_embeds`,
            # so these two need the `image_` prefix rather than a bare
            # prefix strip (which would yield "token_id" / "embeds").
            if video_key in ("video_token_id", "video_embeds"):
                return "image_" + video_key.removeprefix("video_")

            return video_key.removeprefix("video_")

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_vision_input(
                    "images", **kwargs)
            if input_key in ("video_pixel_values",
                             "video_embeds") and "videos" not in modalities:
                # Forward only the video keys: image keys passed through
                # unchanged would collide with the renamed video keys, and
                # which one wins would depend on the order of kwargs.
                video_kwargs = {
                    _image_key(k): v
                    for k, v in kwargs.items() if k.startswith("video_")
                }
                modalities["videos"] = self._parse_and_validate_vision_input(
                    "videos", **video_kwargs)

        return modalities

    def _process_vision_input(
        self,
        image_input: MiniCPMVImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Turn a parsed vision input into per-item embeddings.

        Pre-computed embeddings are returned as-is. Pixel inputs are encoded
        with `get_vision_hidden_states`, split by `num_slices`, and each
        chunk has its first two dimensions flattened together.
        """
        if image_input["type"] == "image_embeds":
            return image_input["image_embeds"]

        image_features_flat = self.get_vision_hidden_states(image_input)

        num_slices = image_input["num_slices"]
        return [
            e.flatten(0, 1)
            for e in image_features_flat.split(num_slices.tolist())
        ]

    def _process_multimodal_inputs(self, modalities: dict):
        """Embed every parsed modality and concatenate the results.

        Returns:
            A tuple of tensors, in the iteration order of `modalities`.
        """
        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            # Images and videos share the same processing path.
            if modality in ("images", "videos"):
                vision_input = modalities[modality]
                vision_features = self._process_vision_input(vision_input)
                multimodal_embeddings += tuple(vision_features)

        return multimodal_embeddings

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model (`self.llm`)."""
        return self.llm

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Parse multimodal kwargs and return their embeddings.

        Returns an empty list when `kwargs` contains no image or video
        input.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        return self._process_multimodal_inputs(modalities)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in any multimodal embeddings.

        The merge targets the token ids recorded in `self.mm_token_ids`,
        which must be non-empty when multimodal embeddings are supplied.
        """
        inputs_embeds = self.llm.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            assert len(self.mm_token_ids) > 0
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                list(self.mm_token_ids),
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model and return its hidden states.

        If `intermediate_tensors` is given, `inputs_embeds` is discarded.
        Otherwise, if `inputs_embeds` is missing, it is computed here from
        `input_ids` and the multimodal `kwargs`, and `input_ids` is then
        cleared before calling the language model.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.llm.compute_logits(hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `weights` through `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(language_model="llm",
                                                connector="resampler",
                                                tower_model="vpm")

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Create the language model. Must be overridden by subclasses."""
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Create the vision module. Must be overridden by subclasses."""
        raise NotImplementedError

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Create the resampler. Must be overridden by subclasses."""
        raise NotImplementedError

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs. Must be overridden by subclasses."""
        raise NotImplementedError


992
class MiniCPMV2_0(MiniCPMVBaseModel):
    """MiniCPM-V 2.0: `MiniCPMForCausalLM` with a timm vision model and
    `Resampler2`."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 0)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Create the `MiniCPMForCausalLM` language model."""
        return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Create the timm vision model.

        `config`, `quant_config` and `prefix` are accepted for interface
        compatibility but are not used by this implementation.

        Raises:
            ImportError: If `timm` is not installed; the original import
                failure is attached as the cause.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as e:
            # Chain from the caught exception (not the ImportError class) so
            # that the real reason for the failure stays in the traceback.
            raise ImportError("Please install timm==0.9.10") from e

        # The model is constructed with float16 as the default dtype and
        # then cast to the ambient default dtype.
        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        model = model.to(dtype=torch.get_default_dtype())

        # Replace the attention pooling head, if any, with a no-op.
        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Create the `Resampler2` connector.

        The grid size is the integer square root of `config.query_num` and
        the head count is `embed_dim // 128`.
        """
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(embed_dim=embed_dim,
                                   num_heads=embed_dim // 128,
                                   grid_size=int(
                                       math.sqrt(self.config.query_num)),
                                   kv_dim=vision_dim,
                                   adaptive=False,
                                   do_post_projection=True,
                                   quant_config=quant_config,
                                   prefix=prefix)

        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Encode each entry of `data["pixel_values"]` separately, resample
        it, and stack the results with `torch.vstack`."""
        pixel_values = data["pixel_values"]

        P_h, P_w = self.vpm.patch_embed.patch_size
        dtype: torch.dtype = self.vpm.pos_embed.data.dtype
        num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)

        res = list[torch.Tensor]()
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            # Patch-grid size, rounded up in each dimension.
            tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))

            # Drop the leading prefix tokens before resampling.
            if num_prefix_tokens > 0:
                vision_embedding = vision_embedding[:, num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))

        return torch.vstack(res)
Jee Jee Li's avatar
Jee Jee Li committed
1076
1077


1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.5: `LlamaForCausalLM` with `Idefics2VisionTransformer`
    and `Resampler2_5`."""

    # Fused projection name -> names of the projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 5)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Create the `LlamaForCausalLM` language model."""
        llm = LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
        return llm

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Create the vision transformer, dropping its last encoder layer
        when `config.drop_vision_last_layer` is set."""
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if not self.config.drop_vision_last_layer:
            return vision_tower

        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Create the `Resampler2_5` connector with `embed_dim // 128`
        heads and `config.query_num` queries."""
        resampler_kwargs = dict(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            quant_config=quant_config,
            prefix=prefix,
        )
        # Constructed with float16 as the default dtype, then moved/cast.
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(**resampler_kwargs)

        target_dtype = torch.get_default_dtype()
        return resampler.to(device=current_platform.device_type,
                            dtype=target_dtype)

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Pad the slices to a common width, encode them in one batch with
        a patch attention mask, and resample the result."""
        slices = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        first_slice = slices[0]
        device = first_slice.device
        batch_size = len(slices)
        height = first_slice.shape[-2]
        max_width = max(s.shape[-1] for s in slices)

        # Zero-padded batch: each slice is copied into the left part of
        # its row along the last dimension.
        padded_pixels = torch.zeros((batch_size, 3, height, max_width),
                                    dtype=first_slice.dtype,
                                    device=device)
        for row, slice_pixels in enumerate(slices):
            width = slice_pixels.shape[-1]
            padded_pixels[row, ..., :width] = slice_pixels

        patch_counts = tgt_sizes.prod(-1)
        max_patches = patch_counts.max().item()
        assert isinstance(max_patches, int)

        # True for the first `patch_counts[row]` positions of each row.
        attn_mask = torch.zeros((batch_size, max_patches),
                                dtype=torch.bool,
                                device=device)
        for row, count in enumerate(patch_counts):
            attn_mask[row, :count] = True

        vision_embedding = self.vpm(
            padded_pixels,
            patch_attention_mask=attn_mask.unsqueeze(1),
            tgt_sizes=None,
        )
        return self.resampler(vision_embedding, tgt_sizes)
Jee Jee Li's avatar
Jee Jee Li committed
1166
1167


1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.6: `Qwen2ForCausalLM` with `Idefics2VisionTransformer`
    and `Resampler2_5`."""

    # Fused projection name -> names of the projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 6)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Create the `Qwen2ForCausalLM` language model."""
        llm = Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)
        return llm

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the vision transformer, dropping its last encoder layer
        when `config.drop_vision_last_layer` is set."""
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if not self.config.drop_vision_last_layer:
            return vision_tower

        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Create the `Resampler2_5` connector with `embed_dim // 128`
        heads and `config.query_num` queries."""
        # The resampler in 2.6 remains consistent with the one in 2.5.
        resampler_kwargs = dict(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            quant_config=quant_config,
            prefix=prefix,
        )
        # Constructed with float16 as the default dtype, then moved/cast.
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(**resampler_kwargs)

        target_dtype = torch.get_default_dtype()
        return resampler.to(device=current_platform.device_type,
                            dtype=target_dtype)

    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        """Pad the slices to a common width, encode them in one batch with
        a patch attention mask and target sizes, and resample the result."""
        slices = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        first_slice = slices[0]
        device = first_slice.device
        batch_size = len(slices)
        height = first_slice.shape[-2]
        max_width = max(s.shape[-1] for s in slices)

        # Zero-padded batch: each slice is copied into the left part of
        # its row along the last dimension.
        padded_pixels = torch.zeros((batch_size, 3, height, max_width),
                                    dtype=first_slice.dtype,
                                    device=device)
        for row, slice_pixels in enumerate(slices):
            width = slice_pixels.shape[-1]
            padded_pixels[row, ..., :width] = slice_pixels

        patch_counts = tgt_sizes.prod(-1)
        max_patches = patch_counts.max().item()
        assert isinstance(max_patches, int)

        # True for the first `patch_counts[row]` positions of each row.
        attn_mask = torch.zeros((batch_size, max_patches),
                                dtype=torch.bool,
                                device=device)
        for row, count in enumerate(patch_counts):
            attn_mask[row, :count] = True

        # Unlike 2.5, the target sizes are also passed to the vision module.
        vision_embedding = self.vpm(
            padded_pixels,
            patch_attention_mask=attn_mask.unsqueeze(1),
            tgt_sizes=tgt_sizes,
        )
        return self.resampler(vision_embedding, tgt_sizes)


1259
1260
1261
# Dispatch table from a parsed (major, minor) MiniCPM-V version tuple to the
# model class implementing that version.
_SUPPORT_VERSION = dict((
    ((2, 0), MiniCPMV2_0),
    ((2, 5), MiniCPMV2_5),
    ((2, 6), MiniCPMV2_6),
))


1266
1267
1268
1269
1270
@MULTIMODAL_REGISTRY.register_processor(
    MiniCPMVMultiModalProcessor,
    info=MiniCPMVProcessingInfo,
    dummy_inputs=MiniCPMVDummyInputsBuilder)
class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
    """
    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in vLLM. Therefore, it is necessary to separate them.

    Constructing this class returns an instance of the version-specific
    class selected from `_SUPPORT_VERSION`, not of `MiniCPMV` itself.
    """

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        """Pick the implementation class for the configured version.

        If the HF config has no `version` attribute, version 2.0 is assumed
        when `hidden_size == 2304` and `query_num == 64`, otherwise 2.5.
        A present `version` is parsed as dot-separated integers.

        Raises:
            ValueError: If the resolved version is not in
                `_SUPPORT_VERSION`.
        """
        config = vllm_config.model_config.hf_config
        if not hasattr(config, "version"):
            if config.hidden_size == 2304 and config.query_num == 64:
                version = (2, 0)
            else:
                version = (2, 5)
        else:
            version = str(config.version).split(".")
            version = tuple([int(x) for x in version])
        # Dispatch class based on version
        instance_cls = _SUPPORT_VERSION.get(version)
        if instance_cls is None:
            raise ValueError(
                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")

        # quant_config references base class members,
        # so update values before init is called
        #
        # Not every version class declares the LoRA attributes in this file
        # (MiniCPMV2_0 declares none), so fall back to empty defaults
        # instead of assuming they exist.
        cls.packed_modules_mapping.update(
            getattr(instance_cls, "packed_modules_mapping", {}))
        cls.embedding_modules.update(
            getattr(instance_cls, "embedding_modules", {}))
        # Only append entries that are not already present, so repeated
        # construction does not keep growing the list with duplicates.
        # NOTE(review): assumes `embedding_padding_modules` is a list.
        cls.embedding_padding_modules += [
            name for name in getattr(instance_cls,
                                     "embedding_padding_modules", [])
            if name not in cls.embedding_padding_modules
        ]
        return instance_cls(vllm_config=vllm_config, prefix=prefix)