minicpmv.py 36.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Alphi's avatar
Alphi committed
22
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
23
24
import math
import re
25
from functools import cached_property, partial
26
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
27
                    Set, Tuple, TypedDict, Union)
28
29

import torch
Alphi's avatar
Alphi committed
30
import torch.types
31
32
from PIL import Image
from torch import nn
33
from transformers import PretrainedConfig
34
from typing_extensions import NotRequired
35
36

from vllm.attention import AttentionMetadata
37
from vllm.config import VllmConfig
38
39
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
42
                                                  get_2d_sincos_pos_embed)
Joe Runde's avatar
Joe Runde committed
43
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Jee Jee Li's avatar
Jee Jee Li committed
44
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
45
46
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
47
from vllm.model_executor.models.module_mapping import MultiModelKeys
48
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
50
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
51
52
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
53
from vllm.sequence import IntermediateTensors, SequenceData
54

Jee Jee Li's avatar
Jee Jee Li committed
55
from .idefics2_vision_model import Idefics2VisionTransformer
56
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
57
from .utils import AutoWeightsLoader, maybe_prefix
58

59
# A single image as handed to the MiniCPM-V input pipeline: either a PIL image
# or a tensor. `input_mapper_for_minicpmv` forwards a tensor unchanged as
# `image_embeds` instead of running it through the image processor.
RawImageType = Union[Image.Image, torch.Tensor]
60

61
62

class MiniCPMVRawImageInput(TypedDict):
    """Input mapper input with auxiliary data for computing image bounds.

    Instances are built by ``_build_image_input``. ``image`` is the raw image
    (or a tensor of precomputed embeddings). The ``*_id`` fields carry the
    tokenizer's special token ids that mark where image / slice features
    start and end inside the prompt.
    """
    image: RawImageType

    # Image bounds token ids in 0-dim scalar tensor.
    im_start_id: torch.Tensor
    im_end_id: torch.Tensor
    # Only present when the tokenizer defines slice start/end markers.
    slice_start_id: NotRequired[torch.Tensor]
    slice_end_id: NotRequired[torch.Tensor]


Jee Jee Li's avatar
Jee Jee Li committed
73
class MiniCPMVImagePixelInputs(TypedDict):
    """Image inputs given as raw pixel values.

    Fields:
        type: always ``"pixel_values"``.
        data: one tensor per image. Taken together they are
            `(batch_size * num_images, num_channels, height, width)`, but
            image sizes may vary, so they are kept as a list.
        image_bounds: `(batch_size * num_images, 2)` in `(start, stop)`
            format.
        tgt_sizes: `(batch_size * num_images, 2)` in `(height, width)`
            format.
    """

    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # The image size may vary, so this is a list rather than a batched tensor.
    data: List[torch.Tensor]

    # Shape: `(batch_size * num_images, 2)`, in `(start, stop)` format.
    image_bounds: torch.Tensor

    # Shape: `(batch_size * num_images, 2)`, in `(height, width)` format.
    tgt_sizes: torch.Tensor


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class MiniCPMVImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed image embeddings.

    The embeddings are written directly into the language model's input
    embeddings at the positions described by ``image_bounds``.
    """
    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """


# Either pixel-value or precomputed-embedding image inputs. The "type" key
# ("pixel_values" / "image_embeds") tells the two apart at runtime.
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                            MiniCPMVImageEmbeddingInputs]

Jee Jee Li's avatar
Jee Jee Li committed
119
120
121
122
123
# Default `norm_layer` factory for Resampler2_5: nn.LayerNorm with eps=1e-6.
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class Resampler2_5(BaseResampler):
    """Resampler for MiniCPM-V 2.5 with a growable 2-D position table.

    A cached table of 2-D sine-cosine position embeddings (from
    ``get_2d_sincos_pos_embed(..., version=(2, 5))``) covers patch grids up
    to ``max_size``. The table is rebuilt larger whenever a batch needs a
    bigger grid. ``forward`` lets the learned queries attend over the
    projected patch features of each image, with position embeddings added
    to the keys and padded patches masked out.
    """

    def __init__(self,
                 num_queries: int,
                 embed_dim: int,
                 num_heads: int,
                 kv_dim: Optional[int] = None,
                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
                 max_size: Tuple[int, int] = (70, 70),
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(num_queries,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         quant_config=quant_config,
                         prefix=prefix)

        # Largest (height, width) patch grid the cached table covers.
        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        """(Re)build the non-persistent ``pos_embed`` buffer for ``max_size``."""
        table = torch.from_numpy(
            get_2d_sincos_pos_embed(self.embed_dim, max_size, version=(2, 5)))
        self.register_buffer("pos_embed",
                             table.float().to(device),
                             persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        """Grow the position table if any target grid exceeds ``max_size``."""
        needed_h = tgt_sizes[:, 0].max().item()
        needed_w = tgt_sizes[:, 1].max().item()
        assert isinstance(needed_h, int) and isinstance(needed_w, int)

        cached_h, cached_w = self.max_size
        if needed_h <= cached_h and needed_w <= cached_w:
            return  # the current table is already large enough

        self.max_size = (max(needed_h, cached_h), max(needed_w, cached_w))
        self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        """Resample each image's patch features into ``num_queries`` tokens.

        Args:
            x: patch features. The inline comments treat this as
                ``(batch, patches, dim)``, padded along ``patches``.
            tgt_sizes: per-image patch grid as ``(height, width)`` rows,
                with one row per batch entry.

        Returns:
            The attended queries in ``(batch, queries, ...)`` layout, after
            ``ln_post`` and multiplication by ``self.proj``.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        batch_size = x.shape[0]
        device, dtype = x.device, x.dtype

        # Number of real (unpadded) patches per image.
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)

        # True marks padded key positions that attention must ignore.
        key_padding_mask = torch.zeros((batch_size, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)
        per_image_pos = []
        for idx in range(batch_size):
            grid_h, grid_w = tgt_sizes[idx].tolist()
            grid_pos = self.pos_embed[:grid_h, :grid_w, :]
            per_image_pos.append(
                grid_pos.reshape((grid_h * grid_w, -1)).to(dtype))
            key_padding_mask[idx, patch_len[idx]:] = True

        # Zero-pad to the longest image, then B * L * D -> L * B * D.
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            per_image_pos, batch_first=True,
            padding_value=0.0).permute(1, 0, 2)

        kv, _ = self.kv_proj(x)  # B * L * D
        kv = self.ln_kv(kv).permute(1, 0, 2)  # L * B * D

        queries = self.ln_q(self.query)  # Q * D
        attn_out = self.attn(
            self._repeat(queries, batch_size),  # Q * B * D
            kv + pos_embed,  # position-aware keys
            kv,
            key_padding_mask=key_padding_mask,
        )[0]

        out = self.ln_post(attn_out.permute(1, 0, 2))  # B * Q * D
        return out @ self.proj


215
def _build_image_input(ctx: InputContext,
                       image: RawImageType) -> MiniCPMVRawImageInput:
    """Wrap ``image`` together with the tokenizer's image-bound token ids.

    The image start/end ids are always included. The slice start/end ids
    are added only when the tokenizer defines ``slice_start_id``.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    fields = dict(image=image,
                  im_start_id=torch.tensor(tokenizer.im_start_id),
                  im_end_id=torch.tensor(tokenizer.im_end_id))
    if hasattr(tokenizer, "slice_start_id"):
        fields["slice_start_id"] = torch.tensor(tokenizer.slice_start_id)
        fields["slice_end_id"] = torch.tensor(tokenizer.slice_end_id)
    return MiniCPMVRawImageInput(**fields)
232
233


234
235
236
237
238
239
240
241
242
243
244
245
246
247
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    """Return the MiniCPM-V version of ``config`` as a tuple of ints.

    ``config.version`` (e.g. ``2.6``) is split on ``"."``. Configs without a
    version are inferred as (2, 0) when ``hidden_size == 2304`` and
    ``query_num == 64``, and as (2, 5) otherwise.
    """
    version = getattr(config, "version", None)
    if version is not None:
        return tuple(map(int, str(version).split(".")))

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    is_v2_0 = config.hidden_size == 2304 and config.query_num == 64
    return (2, 0) if is_v2_0 else (2, 5)


248
def get_max_minicpmv_image_tokens(ctx: InputContext):
    """Return the HF config's ``query_num``, falling back to 64 if absent."""
    default_query_num = 64
    return getattr(ctx.get_hf_config(), "query_num", default_query_num)


253
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
    """Build profiling sequence data from the token counts ``(0, seq_len)``.

    ``num_images`` is accepted for signature compatibility but is not used.
    """
    token_counts = (0, seq_len)
    return SequenceData.from_prompt_token_counts(token_counts)
255
256


257
258
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
                             num_images: int):
    """Return ``num_images`` references to one black square profiling image.

    The image is ``hf_config.image_size`` pixels on each side and is wrapped
    by ``_build_image_input``.
    """
    side = hf_config.image_size
    blank = Image.new("RGB", (side, side), color=0)
    image = _build_image_input(ctx, image=blank)
    # `[image] * 1 == [image]`, so no special case is needed for one image.
    return {"image": [image] * num_images}
264
265


266
267
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    """Assemble profiling ``DummyData``: placeholder tokens plus dummy images.

    ``mm_counts["image"]`` sets how many dummy images are produced.
    """
    hf_config = ctx.get_hf_config()
    num_images = mm_counts["image"]
    return DummyData(dummy_seq_data_for_minicpmv(seq_len, num_images),
                     dummy_image_for_minicpmv(ctx, hf_config, num_images))
275
276


277
278
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand each image tag in the prompt into MiniCPM-V slice placeholders.

    Every occurrence of the literal tag ``(<image>./</image>)`` is replaced by
    the string returned by the image processor's
    ``get_slice_image_placeholder`` for the matching image. The expanded
    prompt is then re-tokenized, and each image is wrapped with its bound
    token ids via ``_build_image_input``. Inputs without image data are
    returned unchanged.

    ``multi_modal_data["image"]`` may be a single PIL image, a list of
    images, or a dict carrying ``image_embeds`` and ``image_size_list``.

    Raises:
        ValueError: if the prompt's image-tag count differs from the number
            of images, or if embeddings are passed without
            ``image_size_list`` while the prompt contains image tags.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs
    model_config = ctx.model_config
    version = get_version_by_config(model_config.hf_config)
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    # Pass `trust_remote_code` like the tokenizer above (and the input
    # mapper) so remote-code image processors load consistently.
    image_processor = cached_get_image_processor(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    def get_placeholder(image_size: Tuple[int, int], num_image: int):
        # The 2.0 / 2.5 processors take only the image size; other versions
        # also take the image index.
        if version == (2, 0) or version == (2, 5):
            return image_processor.get_slice_image_placeholder(image_size)
        return image_processor.get_slice_image_placeholder(
            image_size, num_image)

    prompt = inputs.get("prompt")
    token_ids = inputs.get("prompt_token_ids")
    if prompt is None:
        prompt = tokenizer.decode(token_ids)

    pattern = "(<image>./</image>)"
    images = multi_modal_data["image"]

    # Normalize to a list up front. The `_build_image_input` loop below runs
    # even when the prompt has no tags, and must not iterate over a bare PIL
    # image or over a dict's keys.
    from_embeds = isinstance(images, dict)
    image_size_list = None
    if from_embeds:
        image_size_list = images.get("image_size_list")
        images = [images.get("image_embeds")]
    elif isinstance(images, Image.Image):
        images = [images]

    # `pattern` is a literal tag, not a regex: its parentheses and "." must
    # match literally, exactly as `str.split` below treats them.
    if pattern not in prompt:
        new_token_ids = token_ids
        new_prompt = prompt
    else:
        if not from_embeds:
            image_size_list = [image.size for image in images]
        elif image_size_list is None:
            raise ValueError(
                "`image_size_list` is required when passing `image_embeds` "
                "together with image tags in the prompt")

        text_chunks = prompt.split(pattern)
        num_tags = len(text_chunks) - 1
        # A mismatch would otherwise drop prompt text or index past the end.
        if num_tags != len(image_size_list):
            raise ValueError(
                f"Found {num_tags} image tag(s) {pattern!r} in the prompt "
                f"but {len(image_size_list)} image(s) were provided")

        new_prompt_chunks: List[str] = []
        for i, image_size in enumerate(image_size_list):
            new_prompt_chunks.append(text_chunks[i])
            new_prompt_chunks.append(get_placeholder(image_size, i))
        new_prompt_chunks.append(text_chunks[-1])
        new_prompt = "".join(new_prompt_chunks)
        new_token_ids = tokenizer.encode(new_prompt)

    multi_modal_data["image"] = [
        _build_image_input(ctx, image) for image in images
    ]

    return token_inputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
334
335


336
337
338
339
340
341
342
343
344
345
346
347
def input_mapper_for_minicpmv(ctx: InputContext, data: object):
    """Convert processed image entries into batched model keyword arguments.

    ``data`` is expected to be a list of ``MiniCPMVRawImageInput`` dicts.

    - If the first entry's image is a tensor, it is forwarded as
      ``image_embeds``.
    - Otherwise every image is run through the HF image processor's
      ``preprocess``.

    The image-bound token ids are copied from the first entry.

    Raises:
        RuntimeError: if no HuggingFace image processor is available.
        ValueError: if ``data`` is not a list.
    """
    model_config = ctx.model_config

    image_processor = cached_get_image_processor(
        model_config.model, trust_remote_code=model_config.trust_remote_code)
    if image_processor is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    if not isinstance(data, list):
        # Exceptions do not apply logging-style %-arguments; format eagerly.
        raise ValueError("Image input must be list of MiniCPMVImageInput, "
                         f"got ({type(data).__name__})")

    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
        batch_data = {"image_embeds": data[0]['image']}
    else:
        batch_data = image_processor.preprocess(
            [img["image"] for img in data], return_tensors="pt").data

    if len(data) > 0:
        first = data[0]
        batch_data["im_start_id"] = first["im_start_id"]
        batch_data["im_end_id"] = first["im_end_id"]
        if "slice_start_id" in first:
            batch_data["slice_start_id"] = first["slice_start_id"]
            batch_data["slice_end_id"] = first["slice_end_id"]

    return MultiModalKwargs(batch_data)
366
367


368
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
    """
    The abstract class of MiniCPMV can only be inherited, but cannot be
    instantiated.

    Subclasses supply the version-specific parts through ``init_llm``,
    ``init_vision_module``, ``init_resampler``, ``get_vision_embedding`` and
    ``get_vision_hidden_states``. This base class wires them together, writes
    image features into the text embeddings, and delegates decoding, logits
    and sampling to the language model.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings` but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
        # and config class
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "llm"))
        self.vpm = self.init_vision_module(config,
                                           quant_config,
                                           prefix=maybe_prefix(prefix, "vpm"))
        # Version 2.0's vision tower exposes `embed_dim` directly; the other
        # versions expose it on `embeddings`.
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim,
                                             self.vision_dim,
                                             quant_config=quant_config,
                                             prefix=maybe_prefix(
                                                 prefix, "resampler"))

        self.make_empty_intermediate_tensors = (
            self.llm.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Use the language model's sampler if it has one, else a default."""
        if hasattr(self.llm, "sampler"):
            return self.llm.sampler

        return get_sampler()

    def get_embedding(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``input_ids`` and overwrite image positions with image features.

        Returns:
            ``(vlm_embedding, vision_hidden_states)``. ``vision_hidden_states``
            is an empty tensor when ``image_inputs`` is None.
        """
        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            if image_inputs["type"] == "image_embeds":
                vision_hidden_states = (image_inputs["data"].type(
                    vlm_embedding.dtype).to(vlm_embedding.device))
            else:
                vision_hidden_states = self.get_vision_hidden_states(
                    image_inputs)

            # Input IDs may contain no image tokens (e.g. during memory
            # profiling), so the bounds are allowed to be empty.
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                # Flat list of every token position covered by an image.
                image_indices = torch.stack([
                    torch.arange(start, end, dtype=torch.long)
                    for start, end in image_bounds.tolist()
                ]).to(vlm_embedding.device)
                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1,
                                                     vlm_embedding.shape[-1]),
                    vision_hidden_states.view(-1,
                                              vision_hidden_states.shape[-1]),
                )

        return vlm_embedding, vision_hidden_states

    def _get_image_bounds(
            self,
            input_ids: torch.Tensor,
            im_start_id: torch.Tensor,
            im_end_id: torch.Tensor,
            slice_start_id: Optional[torch.Tensor] = None,
            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Locate image / slice feature spans in ``input_ids``.

        Returns:
            An ``(n, 2)`` integer tensor of ``[start, end)`` rows. ``start`` is
            one past a start marker and ``end`` is the position of an end
            marker. Start and end markers are paired in order. If their counts
            differ, the unmatched extras are dropped.
        """
        # All the images in the batch should share the same special image
        # bound token ids.
        start_cond = input_ids == im_start_id[0]
        end_cond = input_ids == im_end_id[0]
        if slice_start_id is not None:
            start_cond |= (input_ids == slice_start_id[0])
            end_cond |= (input_ids == slice_end_id[0])

        image_start_tokens, = torch.where(start_cond)
        image_start_tokens += 1
        image_end_tokens, = torch.where(end_cond)
        # Use `min` so both columns have the same length. With `max` the
        # slices below would be no-ops, and `hstack` would fail whenever the
        # start/end counts differ.
        valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            # Same integer dtype as the non-empty result below.
            return torch.zeros((0, 2),
                               dtype=torch.long,
                               device=input_ids.device)

        return torch.hstack([
            image_start_tokens[:valid_image_nums].unsqueeze(-1),
            image_end_tokens[:valid_image_nums].unsqueeze(-1),
        ])

    def _parse_and_validate_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        """Pop image kwargs and turn them into typed image inputs.

        Returns:
            Embedding inputs when ``image_embeds`` is given. Otherwise
            flattened pixel inputs, or None when there are no pixel values or
            no ``im_start_id``.

        Raises:
            ValueError: on wrongly typed or inconsistently sized inputs.
        """
        pixel_values = kwargs.pop("pixel_values", [])
        tgt_sizes = kwargs.pop("tgt_sizes", [])
        im_start_id = kwargs.pop("im_start_id", None)
        im_end_id = kwargs.pop("im_end_id", None)
        slice_start_id = kwargs.pop("slice_start_id", None)
        slice_end_id = kwargs.pop("slice_end_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError(f"Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")
            if isinstance(image_embeds, list):
                image_embeds = torch.concat(image_embeds)

            return MiniCPMVImageEmbeddingInputs(
                image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                    im_end_id, slice_start_id,
                                                    slice_end_id),
                data=image_embeds,
                type="image_embeds",
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError("Incorrect type of target sizes. "
                             f"Got type: {type(tgt_sizes)}")

        if len(pixel_values) != len(tgt_sizes):
            raise ValueError("Inconsistent batch lengths, found: "
                             f"{len(pixel_values)} vs. {len(tgt_sizes)}")

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
            if len(pixel_b) != len(tgt_b):
                raise ValueError("Inconsistent N lengths, found: "
                                 f"{len(pixel_b)} vs {len(tgt_b)}")

            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
                pixel_values_flat += pixel_n
                tgt_sizes_flat += tgt_n

        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        if len(pixel_values_flat) == 0:
            return None

        if im_start_id is None:
            return None

        # NOTE: Input IDs does not contain image tokens during memory
        # profiling, so the image bounds computed here may be empty.
        return MiniCPMVImagePixelInputs(
            image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                im_end_id, slice_start_id,
                                                slice_end_id),
            data=pixel_values_flat,
            tgt_sizes=torch.stack(tgt_sizes_flat),
            type="pixel_values",
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the language model on merged text/image embeddings.

        On non-first pipeline stages (``intermediate_tensors`` given) no
        embeddings are computed here.
        """
        if intermediate_tensors is not None:
            vlm_embeddings = None
        else:
            image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

            vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        output = self.llm.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )
        return output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.llm.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(language_model="llm",
                                                connector="resampler",
                                                tower_model="vpm")

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the language model (implemented by subclasses)."""
        raise NotImplementedError

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the vision tower (implemented by subclasses)."""
        raise NotImplementedError

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build the vision-to-LM resampler (implemented by subclasses)."""
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode pixel values into resampled features (subclass hook)."""
        raise NotImplementedError

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Compute image features from pixel inputs (subclass hook)."""
        raise NotImplementedError


641
class MiniCPMV2_0(MiniCPMVBaseModel):
    """MiniCPM-V 2.0 model.

    It uses ``MiniCPMForCausalLM`` as the language model, the timm
    ``vit_so400m_patch14_siglip_384.webli`` ViT as the vision tower, and
    ``Resampler2`` as the connector.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (2, 0)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the MiniCPM language model."""
        return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the timm SigLIP ViT vision tower.

        The model is created under a float16 default dtype, then cast to the
        current default dtype. Its attention pool is replaced with
        ``Identity``, and its last block is dropped when
        ``drop_vision_last_layer`` is set.

        Raises:
            ImportError: if ``timm`` is not installed.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain the actual import failure, not the ImportError class.
            raise ImportError("Please install timm==0.9.10") from err

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        model = model.to(dtype=torch.get_default_dtype())

        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the language model's token embeddings.

        This class has no ``self.model``; the language model lives in
        ``self.llm``, so delegate to it as ``get_embedding`` does.
        """
        return self.llm.get_input_embeddings(input_ids)

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build a ``Resampler2`` whose query grid side is sqrt(query_num)."""
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(embed_dim=embed_dim,
                                   num_heads=embed_dim // 128,
                                   grid_size=int(
                                       math.sqrt(self.config.query_num)),
                                   kv_dim=vision_dim,
                                   adaptive=False,
                                   do_post_projection=True,
                                   quant_config=quant_config,
                                   prefix=prefix)

        # NOTE(review): device is hard-coded to "cuda"; kept as-is.
        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode each image separately and resample it to query tokens.

        ``patch_attn_mask`` and ``tgt_sizes`` are accepted for interface
        compatibility and ignored. Each image's patch grid is derived from
        its own height and width.
        """
        res = []
        dtype = self.vpm.pos_embed.data.dtype
        patch_size = self.vpm.patch_embed.patch_size
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            # Patch-grid (rows, cols); height uses the patch height and width
            # the patch width.
            tgt_size = (
                math.ceil(H / patch_size[0]),
                math.ceil(W / patch_size[1]),
            )
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))
            # Drop prefix tokens (e.g. class tokens) before resampling.
            if (hasattr(self.vpm, "num_prefix_tokens")
                    and self.vpm.num_prefix_tokens > 0):
                vision_embedding = vision_embedding[:, self.vpm.
                                                    num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))
        return torch.vstack(res)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Compute image features from pixel inputs."""
        pixel_values = data["data"]

        return self.get_vision_embedding(pixel_values)


737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.5: Idefics2 vision encoder, ``LlamaForCausalLM`` language
    model and a ``Resampler2_5`` projector."""

    # Fused module name -> the separate checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    # Modules that LoRA adapters may target, grouped by sub-network.
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model ("qkv_proj" is also the name of a vision-encoder
        # module)
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # This class only implements the 2.5 layout.
        assert self.version == (2, 5)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the Llama language-model backbone."""
        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the Idefics2 vision tower from ``config.vision_config``.

        The final encoder layer is dropped when
        ``self.config.drop_vision_last_layer`` is set.
        """
        vision_tower = Idefics2VisionTransformer(config.vision_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix)
        if self.config.drop_vision_last_layer:
            vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]
        return vision_tower

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build the ``Resampler2_5`` projector (one head per 128 dims)."""
        head_count = embed_dim // 128
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(num_queries=self.config.query_num,
                                     embed_dim=embed_dim,
                                     num_heads=head_count,
                                     kv_dim=vision_dim,
                                     quant_config=quant_config,
                                     prefix=prefix)
        # Deliberately outside the ``with``: the default dtype queried here is
        # the one in effect after the float16 context has been exited.
        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the vision tower, then resample its output using ``tgt_sizes``.

        Although annotated as a list, ``get_vision_hidden_states`` passes a
        single padded batch tensor here.
        """
        encoded = self.vpm(pixel_values, patch_attention_mask=patch_attn_mask)
        return self.resampler(encoded, tgt_sizes)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Pad the per-image pixel tensors into one batch and encode it.

        ``data["tgt_sizes"]`` holds one row per image whose first two columns
        are multiplied to give that image's patch count.
        """
        pixel_values = data["data"]
        tgt_sizes = data["tgt_sizes"]

        pos_weight = self.vpm.embeddings.position_embedding.weight
        device, dtype = pos_weight.device, pos_weight.dtype

        # Merge the first two dims of each image tensor and move the last
        # axis to the front so ``pad_sequence`` pads along it.
        # NOTE(review): the exact per-image layout is not visible here.
        flattened = [
            image.flatten(end_dim=1).permute(1, 0) for image in pixel_values
        ]

        # Largest patch count over all images.
        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        padded = torch.nn.utils.rnn.pad_sequence(flattened,
                                                 batch_first=True,
                                                 padding_value=0.0)
        batch_size, seq_len, _ = padded.shape
        # Back to feature-first and split the features into 3 leading
        # channels (presumably RGB).
        padded = padded.permute(0, 2, 1).reshape(batch_size, 3, -1, seq_len)

        patch_attn_mask = torch.zeros((batch_size, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for idx in range(batch_size):
            n_valid = tgt_sizes[idx][0] * tgt_sizes[idx][1]
            # NOTE(review): the mask is (B, 1, max_patches), so
            # ``[idx, :n_valid]`` slices the size-1 middle axis and sets the
            # whole row ``[idx, 0, :]`` to True whenever n_valid >= 1, i.e.
            # padded patches are not masked. MiniCPMV2_6 in this file uses
            # ``[i, 0, :n]`` instead. Kept unchanged to preserve current
            # outputs; confirm against the reference implementation first.
            patch_attn_mask[idx, :n_valid] = True

        return self.get_vision_embedding(padded.type(dtype), patch_attn_mask,
                                         tgt_sizes)


847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
    """MiniCPM-V 2.6: Idefics2 vision encoder, ``Qwen2ForCausalLM`` language
    model and a ``Resampler2_5`` projector (the same resampler as 2.5)."""

    # Fused module name -> the separate checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes: modules that adapters may target.
    supported_lora_modules = [
        # vision encoder
        "fc1",
        "fc2",
        "out_proj",
        # language model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # This class only implements the 2.6 layout.
        assert self.version == (2, 6)

    def init_llm(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Build the Qwen2 language-model backbone."""
        return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)

    def init_vision_module(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> nn.Module:
        """Build the Idefics2 vision tower from ``config.vision_config``.

        The final encoder layer is dropped when
        ``self.config.drop_vision_last_layer`` is set.
        """
        model = Idefics2VisionTransformer(config.vision_config,
                                          quant_config=quant_config,
                                          prefix=prefix)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self,
                       embed_dim: int,
                       vision_dim: int,
                       quant_config: Optional[QuantizationConfig] = None,
                       prefix: str = "") -> nn.Module:
        """Build the ``Resampler2_5`` projector (one head per 128 dims)."""
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(num_queries=self.config.query_num,
                                     embed_dim=embed_dim,
                                     num_heads=embed_dim // 128,
                                     kv_dim=vision_dim,
                                     quant_config=quant_config,
                                     prefix=prefix)

        # Outside the ``with``: the default dtype queried here is the one in
        # effect after the float16 context has been exited.
        return resampler.to(device="cuda", dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the vision tower only; no resampling is applied here.

        Unlike 2.5, the 2.6 tower also receives ``tgt_sizes``. Although
        annotated as a list, ``get_vision_hidden_states`` passes a single
        padded batch tensor.
        """
        vision_embedding = self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        """Pad the per-image pixel tensors into one batch, encode, resample.

        ``data["tgt_sizes"]`` holds one row per image whose first two columns
        are multiplied to give that image's patch count.
        """
        pixel_values = data["data"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        # Merge the first two dims of each image tensor and move the last
        # axis to the front so ``pad_sequence`` pads along it.
        # NOTE(review): the exact per-image layout is not visible here.
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        # Largest patch count over all images.
        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        # Back to feature-first and split the features into 3 leading
        # channels (presumably RGB).
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        # True for each image's real patches, False for padding.
        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        # Reuse the single vision-tower entry point instead of duplicating the
        # ``self.vpm(...)`` call, mirroring MiniCPMV2_5's structure.
        vision_embedding = self.get_vision_embedding(
            all_pixel_values.type(dtype), patch_attn_mask, tgt_sizes)

        return self.resampler(vision_embedding, tgt_sizes)


964
965
966
967
968
969
970
# (major, minor) version tuple -> concrete implementation class; looked up
# by ``MiniCPMV.__new__``.
_SUPPORT_VERSION = {
    (2, 0): MiniCPMV2_0,
    (2, 5): MiniCPMV2_5,
    (2, 6): MiniCPMV2_6,
}


971
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
    """Version-dispatching entry point for MiniCPM-V checkpoints.

    The MiniCPM-V releases pair different vision encoders with different
    LLMs, which does not fit vLLM's current LoRA and bitsandbytes integration
    behind one class. ``__new__`` therefore reads the HF config and returns an
    instance of the matching class from ``_SUPPORT_VERSION`` rather than an
    instance of ``MiniCPMV`` itself.
    """

    # Empty placeholders so the LoRA support check passes on the class
    # itself, before any concrete version has been selected.
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config

        if hasattr(config, "version"):
            # e.g. "2.6" -> (2, 6)
            version = tuple(
                int(part) for part in str(config.version).split("."))
        else:
            # Configs without a ``version`` field: tell 2.0 apart from 2.5
            # by its hidden size and query count.
            looks_like_v2_0 = (config.hidden_size == 2304
                               and config.query_num == 64)
            version = (2, 0) if looks_like_v2_0 else (2, 5)

        instance_class = _SUPPORT_VERSION.get(version)
        if instance_class is None:
            raise ValueError(
                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
        return instance_class(vllm_config=vllm_config, prefix=prefix)