Unverified Commit 179a6a36 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Model]Refactor MiniCPMV (#7020)


Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 83c644fe
...@@ -220,7 +220,7 @@ Vision Language Models ...@@ -220,7 +220,7 @@ Vision Language Models
- Phi-3-Vision - Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc. - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
- -
* - :code:`MiniCPM-V` * - :code:`MiniCPMV`
- MiniCPM-V - MiniCPM-V
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- -
......
# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""
from typing import Optional
import torch
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
class Idefics2VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings
` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision
Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the
need to resize them to the same fixed size. In particular, we start from the
original pre-trained SigLIP model(which uses images of fixed-size square
images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
def forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
1 / self.num_patches_per_side)
position_ids = torch.full(size=(batch_size,
max_nb_patches_h * max_nb_patches_w),
fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
boundaries,
right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w,
boundaries,
right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj(
hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale,
)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output
class Idefics2VisionMLP(nn.Module):
def __init__(
self,
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class Idefics2EncoderLayer(nn.Module):
def __init__(self, config: Idefics2Config):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention
layers. Each layer is a
[`Idefics2EncoderLayer`].
Args:
config: Idefics2Config
"""
def __init__(self, config: Idefics2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
])
def forward(
self,
inputs_embeds: torch.Tensor,
) -> torch.Tensor:
r"""
Args:
inputs_embeds (torch.Tensor):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation.
This is useful if you want more control over how to convert
`input_ids` indices into associated vectorsthan the model's
internal embedding lookup matrix.
"""
hidden_states = inputs_embeds
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs
return hidden_states
class Idefics2VisionTransformer(nn.Module):
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
import math import math
import re import re
from functools import partial from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict,
Union)
import numpy as np import numpy as np
import torch import torch
...@@ -38,11 +39,14 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -38,11 +39,14 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
...@@ -54,12 +58,45 @@ from vllm.multimodal.image import (cached_get_image_processor, ...@@ -54,12 +58,45 @@ from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer) cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
"llm.model": "llm", "llm.model": "llm",
} }
class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that the image size may vary, so we pass it as a list
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
tgt_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
MiniCPMVImageInputs = MiniCPMVImagePixelInputs
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C # abs_pos: L, C
# tgt_size: (H, W) # tgt_size: (H, W)
...@@ -68,19 +105,21 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): ...@@ -68,19 +105,21 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# tgt_size = int(math.sqrt(tgt_size)) # tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype dtype = abs_pos.dtype
return F.interpolate( return (F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size[0], tgt_size[1]), size=(tgt_size[0], tgt_size[1]),
mode="bicubic", mode="bicubic",
align_corners=False, align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim: int, def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]], grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False, cls_token: bool = False,
version: Tuple[int, int] = (2, 0)): version: Tuple[int, int] = (2, 0),
):
""" """
grid_size: int of the grid height and width grid_size: int of the grid height and width
return: return:
...@@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, ...@@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int,
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: Union[int, Tuple[int, int]], grid: np.ndarray,
version: Tuple[int, int] = (2, 0)): version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
...@@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, ...@@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)): version: Tuple[int, int] = (2, 0)):
""" """
embed_dim: output dimension for each position embed_dim: output dimension for each position
...@@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, ...@@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
""" """
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32) omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2. omega /= embed_dim / 2.0
omega = 1. / 10000**omega # (D/2,) omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0): if version == (2, 0):
pos = pos.reshape(-1) # (M,) pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2) emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else: else:
out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2) emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2) emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb return emb
class Resampler(nn.Module): class BaseResampler(nn.Module):
""" """
A 2D perceiver-resampler network with one cross attention layers by A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb (grid_size**2) learnable queries and 2d sincos pos_emb
...@@ -161,89 +200,151 @@ class Resampler(nn.Module): ...@@ -161,89 +200,151 @@ class Resampler(nn.Module):
A tensor with the shape of (grid_size**2, embed_dim) A tensor with the shape of (grid_size**2, embed_dim)
""" """
default_norm_layer = partial(nn.LayerNorm, eps=1e-6) def __init__(
self,
def __init__(self,
num_queries: int, num_queries: int,
grid_size: int,
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
kv_dim: Optional[int] = None, kv_dim: Optional[int] = None,
norm_layer: nn.Module = default_norm_layer, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False, ) -> None:
max_size: Tuple[int, int] = (70, 70),
version: Tuple[int, int] = (2, 0)):
super().__init__() super().__init__()
self.version = version
if self.version == (2, 0):
self.num_queries = grid_size**2
else:
self.num_queries = num_queries self.num_queries = num_queries
self.max_size = max_size
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
self.adaptive = adaptive
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02) trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim: if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else: else:
self.kv_proj = nn.Identity() # Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: (
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim) self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim) self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter( self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
if self.version == (2, 0): def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2(BaseResampler):
def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
) -> None:
super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
norm_layer)
self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
grid_size,
version=(2, 0))
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
torch.from_numpy( torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
get_2d_sincos_pos_embed(
embed_dim, grid_size, self.apply(self._init_weights)
version=self.version)).float()).requires_grad_(False)
def forward(
self,
x: torch.Tensor,
tgt_sizes: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
):
if self.adaptive:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes,
version=(2, 0))
pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
dtype=x.dtype)
else: else:
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
x, _ = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask,
)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
class Resampler2_5(BaseResampler):
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
) -> None:
super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
self.max_size = max_size
self._set_2d_pos_cache(self.max_size) self._set_2d_pos_cache(self.max_size)
self.apply(self._init_weights) self.apply(self._init_weights)
def _set_2d_pos_cache(self, def _set_2d_pos_cache(self,
max_size: Tuple[int, int], max_size: Tuple[int, int],
device: torch.types.Device = 'cpu'): device: torch.types.Device = "cpu") -> None:
pos_embed = torch.from_numpy( pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
get_2d_sincos_pos_embed(self.embed_dim,
max_size, max_size,
version=self.version)).float().to(device) version=(2, 5))
pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False) self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
device: torch.types.Device): device: torch.types.Device) -> None:
max_h = torch.max(tgt_sizes[:, 0]) max_h = tgt_sizes[:, 0].max().item()
max_w = torch.max(tgt_sizes[:, 1]) max_w = tgt_sizes[:, 1].max().item()
assert isinstance(max_h, int) and isinstance(max_w, int)
if max_h > self.max_size[0] or max_w > self.max_size[1]: if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = [ self.max_size = (
max(max_h, self.max_size[0]), max(max_h, self.max_size[0]),
max(max_w, self.max_size[1]) max(max_w, self.max_size[1]),
] )
self._set_2d_pos_cache(self.max_size, device) self._set_2d_pos_cache(self.max_size, device)
def _init_weights(self, m: nn.Module): def forward(self, x: torch.Tensor,
if isinstance(m, nn.Linear): tgt_sizes: torch.Tensor) -> torch.Tensor:
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_2_5(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None):
assert x.shape[0] == tgt_sizes.shape[0] assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0] bs = x.shape[0]
...@@ -254,25 +355,25 @@ class Resampler(nn.Module): ...@@ -254,25 +355,25 @@ class Resampler(nn.Module):
self._adjust_pos_cache(tgt_sizes, device=device) self._adjust_pos_cache(tgt_sizes, device=device)
max_patch_len = torch.max(patch_len) max_patch_len = patch_len.max().item()
assert isinstance(max_patch_len, int)
key_padding_mask = torch.zeros((bs, max_patch_len), key_padding_mask = torch.zeros((bs, max_patch_len),
dtype=torch.bool, dtype=torch.bool,
device=device) device=device)
pos_embed = [] pos_embed = []
for i in range(bs): for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i] tgt_h, tgt_w = tgt_sizes[i].tolist()
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
(tgt_h * tgt_w, -1)).to(dtype)) # patches * D (tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i]:] = True key_padding_mask[i, patch_len[i]:] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
batch_first=True, batch_first=True,
padding_value=0.0).permute( padding_value=0.0).permute(
1, 0, 1, 0,
2) # BLD => L * B * D 2) # BLD => L * B * D
x, _ = self.kv_proj(x) # B * L * D
x = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D q = self.ln_q(self.query) # Q * D
...@@ -281,7 +382,8 @@ class Resampler(nn.Module): ...@@ -281,7 +382,8 @@ class Resampler(nn.Module):
self._repeat(q, bs), # Q * B * D self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D x + pos_embed, # L * B * D + L * B * D
x, x,
key_padding_mask=key_padding_mask)[0] key_padding_mask=key_padding_mask,
)[0]
# out: Q * B * D # out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D x = out.permute(1, 0, 2) # B * Q * D
...@@ -289,45 +391,6 @@ class Resampler(nn.Module): ...@@ -289,45 +391,6 @@ class Resampler(nn.Module):
x = x @ self.proj x = x @ self.proj
return x return x
def forward_2(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.adaptive:
pos_embed = torch.Tensor(
get_2d_sincos_pos_embed(self.embed_dim,
tgt_sizes)).float().to(device=x.device,
dtype=x.dtype)
else:
pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
def forward(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.version == (2, 0):
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
else:
return self.forward_2_5(x, tgt_sizes=tgt_sizes)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
def get_max_minicpmv_image_tokens(ctx: InputContext): def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config(PretrainedConfig)
...@@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig): ...@@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PretrainedConfig) hf_config = ctx.get_hf_config(PretrainedConfig)
# image_feature_size = get_max_minicpmv_image_tokens(ctx)
seq_data = dummy_seq_data_for_minicpmv(seq_len) seq_data = dummy_seq_data_for_minicpmv(seq_len)
mm_data = dummy_image_for_minicpmv(hf_config) mm_data = dummy_image_for_minicpmv(hf_config)
return seq_data, mm_data return seq_data, mm_data
...@@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
pattern = "(<image>./</image>)" pattern = "(<image>./</image>)"
image = multi_modal_data["image"] image = multi_modal_data["image"]
image_tags = re.findall(pattern, prompt) image_tags = re.findall(pattern, prompt)
assert len(image_tags) <= 1
if len(image_tags) == 0:
new_token_ids = token_ids
new_prompt = prompt
else:
if len(image_tags) > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
text_chunks = prompt.split(pattern) text_chunks = prompt.split(pattern)
new_prompt = text_chunks[0] \ new_prompt = (text_chunks[0] +
+ image_processor.get_slice_image_placeholder(image.size) \ image_processor.get_slice_image_placeholder(image.size) +
+ text_chunks[1] "".join(text_chunks[1:]))
new_token_ids = tokenizer.encode(new_prompt) new_token_ids = tokenizer.encode(new_prompt)
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, llm_inputs = LLMInputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
)
return llm_inputs return llm_inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper() class MiniCPMVBaseModel(nn.Module, SupportsVision):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) """
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) The abstract class of MiniCPMV can only be inherited, but cannot be
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) instantiated.
class MiniCPMV(nn.Module, SupportsVision): """
def __init__( def __init__(
self, self,
...@@ -419,8 +490,8 @@ class MiniCPMV(nn.Module, SupportsVision): ...@@ -419,8 +490,8 @@ class MiniCPMV(nn.Module, SupportsVision):
self.vpm = self.init_vision_module() self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype() param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype) self.vpm.to(dtype=param_dtype)
self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
else self.vpm.embeddings.embed_dim self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype) self.resampler.to(device="cuda", dtype=param_dtype)
...@@ -430,248 +501,100 @@ class MiniCPMV(nn.Module, SupportsVision): ...@@ -430,248 +501,100 @@ class MiniCPMV(nn.Module, SupportsVision):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
def init_llm(self, def get_embedding(
config: PretrainedConfig, self,
cache_config: Optional[CacheConfig] = None, input_ids: torch.Tensor,
quant_config: Optional[QuantizationConfig] = None): image_inputs: Optional[MiniCPMVImageInputs],
if self.version == (2, 0): ) -> Tuple[torch.Tensor, torch.Tensor]:
return MiniCPMModel(config, vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
cache_config=cache_config, if hasattr(self.config, "scale_emb"):
quant_config=quant_config) vlm_embedding *= self.config.scale_emb
elif self.version == (2, 5):
return LlamaModel(config, if image_inputs is None: # No image
cache_config=cache_config, vision_hidden_states = torch.tensor([], device=input_ids.device)
quant_config=quant_config)
else:
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self):
if self.version == (2, 0):
try:
import timm
except ImportError:
raise ImportError(
'Please install timm==0.9.10') from ImportError
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
model = timm.create_model('vit_so400m_patch14_siglip_384.webli',
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True)
torch.set_default_dtype(default_dtype)
if isinstance(model, timm.models.VisionTransformer
) and model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
elif self.version == (2, 5):
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
else:
from vllm.model_executor.models.na_vit import (
SiglipVisionTransformer)
if self.config._attn_implementation == 'flash_attention_2':
self.config.vision_config._attn_implementation \
= 'flash_attention_2'
else: else:
# not support sdpa vision_hidden_states = self.get_vision_hidden_states(image_inputs)
self.config.vision_config._attn_implementation = 'eager'
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim: int, vision_dim: int): # See NOTE in _parse_and_validate_inputs
default_dtype = torch.get_default_dtype() image_bounds = image_inputs["image_bounds"]
torch.set_default_dtype(torch.float16) if len(image_bounds) > 0:
if self.version == (2, 0): image_indices = torch.stack([
resampler = Resampler(grid_size=int( torch.arange(start, end, dtype=torch.long)
math.sqrt(self.config.query_num)), for start, end in image_bounds.tolist()
num_queries=None, ]).to(vlm_embedding.device)
embed_dim=embed_dim, vlm_embedding.scatter_(
num_heads=embed_dim // 128, 0,
kv_dim=vision_dim, image_indices.view(-1, 1).repeat(1,
adaptive=True, vlm_embedding.shape[-1]),
version=self.version) vision_hidden_states.view(-1,
else: vision_hidden_states.shape[-1]),
resampler = Resampler(num_queries=self.config.query_num, )
grid_size=None,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
adaptive=True,
version=self.version)
torch.set_default_dtype(default_dtype)
return resampler
def get_vision_embedding(self, return vlm_embedding, vision_hidden_states
pixel_values: List[List[torch.Tensor]],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
version: Tuple[int, int] = (2, 0)):
if version == (2, 0):
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
# V2.0 start
H, W = pixel_value[0].shape[-2:]
tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]))
# V2.0 end
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if hasattr(self.vpm, 'num_prefix_tokens'
) and self.vpm.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
elif version == (2, 5):
vision_embedding = self.vpm(
pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
else:
vision_embedding = self.vpm(pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
def get_image_bounds(self, input_ids: torch.Tensor): def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
tokenizer = cached_get_tokenizer(self.config._name_or_path, tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True) trust_remote_code=True)
if not hasattr(tokenizer, "slice_start_id"):
start_cond = input_ids == tokenizer.im_start_id start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id end_cond = input_ids == tokenizer.im_end_id
else: if hasattr(tokenizer, "slice_start_id"):
start_cond = (input_ids == tokenizer.im_start_id) | ( start_cond |= (input_ids == tokenizer.slice_start_id)
input_ids == tokenizer.slice_start_id) end_cond |= (input_ids == tokenizer.slice_end_id)
end_cond = (input_ids == tokenizer.im_end_id) | (
input_ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0] image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1 image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0] image_end_tokens, = torch.where(end_cond)
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0: if valid_image_nums == 0:
return [] return torch.zeros((0, 2), device=input_ids.device)
image_bound = torch.hstack([
return torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1), image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1), image_end_tokens[:valid_image_nums].unsqueeze(-1),
]) ])
return image_bound def _parse_and_validate_inputs(
self,
def get_vision_hidden_states(self, data: Dict[str, input_ids: torch.Tensor,
Union[List[torch.Tensor], **kwargs: object,
torch.Tensor]]): ) -> Optional[MiniCPMVImageInputs]:
if "vision_hidden_states" not in data: pixel_values = kwargs.pop("pixel_values", [])
pixel_values = data["pixel_values"] tgt_sizes = kwargs.pop("tgt_sizes", [])
tgt_sizes = data["tgt_sizes"]
vision_hidden_states = [] if not isinstance(pixel_values, (torch.Tensor, list)):
if self.version == (2, 0): raise ValueError("Incorrect type of pixel values. "
if pixel_values is not None and len(pixel_values) > 0: f"Got type: {type(pixel_values)}")
vision_hidden_states = self.get_vision_embedding(
pixel_values) if not isinstance(tgt_sizes, (torch.Tensor, list)):
else: raise ValueError("Incorrect type of target sizes. "
vision_hidden_states = torch.tensor([]).to( f"Got type: {type(tgt_sizes)}")
data["input_ids"].device)
else: if len(pixel_values) != len(tgt_sizes):
device = self.vpm.embeddings.position_embedding.weight.device raise ValueError("Inconsistent batch lengths, found: "
dtype = self.vpm.embeddings.position_embedding.weight.dtype f"{len(pixel_values)} vs. {len(tgt_sizes)}")
all_pixel_values = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values pixel_values_flat: List[torch.Tensor] = []
] tgt_sizes_flat: List[torch.Tensor] = []
if all_pixel_values: for b in range(len(pixel_values)):
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) pixel_values_flat += pixel_values[b]
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) tgt_sizes_flat += tgt_sizes[b]
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values, batch_first=True, padding_value=0.0) # NOTE: Input IDs does not contain image tokens during memory profiling,
B, L, _ = all_pixel_values.shape # so we allow it to be empty
all_pixel_values = all_pixel_values.permute( if len(pixel_values_flat) != len(tgt_sizes_flat):
0, 2, 1).reshape(B, 3, -1, L) raise ValueError("Inconsistent flattened lengths, found: "
patch_attn_mask = torch.zeros((B, 1, max_patches), f"{len(pixel_values_flat)} vs. "
dtype=torch.bool, f"{len(tgt_sizes_flat)}")
device=device)
if self.version == (2, 5): if len(pixel_values_flat) == 0:
for i in range(B): return None
patch_attn_mask[i, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True return MiniCPMVImageInputs(
vision_embedding = self.vpm( image_bounds=self._get_image_bounds(input_ids),
all_pixel_values.type(dtype), pixel_values=pixel_values_flat,
patch_attention_mask=patch_attn_mask tgt_sizes=torch.stack(tgt_sizes_flat),
).last_hidden_state )
else:
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
vision_hidden_states = self.resampler(
vision_embedding, tgt_sizes)
else: # no image
dummy_feature = []
vision_hidden_states = dummy_feature
else:
vision_hidden_states = data["vision_hidden_states"]
return vision_hidden_states
def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
torch.Tensor]]):
input_ids = data["input_ids"]
vision_hidden_states = self.get_vision_hidden_states(data)
if vision_hidden_states is not None and len(vision_hidden_states) > 0:
image_bounds = self.get_image_bounds(input_ids)
else:
image_bounds = []
if hasattr(self.config, 'scale_emb'):
vlm_embedding = self.llm.embed_tokens(
input_ids) * self.config.scale_emb
else:
vlm_embedding = self.llm.embed_tokens(input_ids)
vision_hidden_states = [
i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
for i in vision_hidden_states
]
if len(vision_hidden_states) > 0 and len(image_bounds) > 0:
vision_hidden_states = torch.cat(vision_hidden_states, dim=0)
image_indices = torch.stack([
torch.arange(r[0], r[1], dtype=torch.long)
for r in image_bounds
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
return vlm_embedding, vision_hidden_states
def process_multimodal_inputs(self, inputs: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
pixel_values = []
tgt_sizes = []
for b in range(len(inputs["pixel_values"])):
pixel_values += inputs["pixel_values"][b]
tgt_sizes += inputs["tgt_sizes"][b]
return {
"pixel_values": pixel_values,
"input_ids": inputs["input_ids"],
"tgt_sizes": tgt_sizes
}
def forward( def forward(
self, self,
...@@ -680,23 +603,20 @@ class MiniCPMV(nn.Module, SupportsVision): ...@@ -680,23 +603,20 @@ class MiniCPMV(nn.Module, SupportsVision):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: Any,
): ) -> torch.Tensor:
inputs = { image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
"pixel_values": kwargs.pop("pixel_values", []),
"input_ids": input_ids,
"tgt_sizes": kwargs.pop("tgt_sizes", None),
}
inputs = self.process_multimodal_inputs(inputs)
vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
output = self.llm(input_ids=None, output = self.llm(
input_ids=None,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings) inputs_embeds=vlm_embeddings,
)
return output return output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
...@@ -735,13 +655,10 @@ class MiniCPMV(nn.Module, SupportsVision): ...@@ -735,13 +655,10 @@ class MiniCPMV(nn.Module, SupportsVision):
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
use_default_weight_loading = False use_default_weight_loading = False
if "vpm" in name or 'resampler' in name: if self.is_default_weight_loading(name):
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True use_default_weight_loading = True
else: else:
for (param_name, weight_name, for param_name, weight_name, shard_id in stacked_params_mapping:
shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name.replace(weight_name, param_name)]
...@@ -755,3 +672,341 @@ class MiniCPMV(nn.Module, SupportsVision): ...@@ -755,3 +672,341 @@ class MiniCPMV(nn.Module, SupportsVision):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
raise NotImplementedError
def init_vision_module(self) -> nn.Module:
raise NotImplementedError
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
raise NotImplementedError
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
raise NotImplementedError
class MiniCPMV2(MiniCPMVBaseModel):
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 0)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self) -> nn.Module:
# TODO :refactor this vision model
try:
import timm
except ImportError:
raise ImportError("Please install timm==0.9.10") from ImportError
with set_default_torch_dtype(torch.float16):
model = timm.create_model(
"vit_so400m_patch14_siglip_384.webli",
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True,
)
if (isinstance(model, timm.models.VisionTransformer)
and model.attn_pool is not None):
model.attn_pool = torch.nn.Identity()
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
return model
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2(
embed_dim=embed_dim,
num_heads=embed_dim // 128,
grid_size=int(math.sqrt(self.config.query_num)),
kv_dim=vision_dim,
adaptive=True,
)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
H, W = pixel_value[0].shape[-2:]
tgt_size = (
math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]),
)
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if (hasattr(self.vpm, "num_prefix_tokens")
and self.vpm.num_prefix_tokens > 0):
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
return self.get_vision_embedding(pixel_values)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel):
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 5)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(pixel_values,
patch_attention_mask=patch_attn_mask)
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
# to be modified in the future.
class MiniCPMVQwen2(MiniCPMVBaseModel):
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
from vllm.model_executor.models.na_vit import SiglipVisionTransformer
if self.config._attn_implementation == "flash_attention_2":
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not support sdpa
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
def __new__(
cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0)
else:
version = (2, 5)
else:
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
if version == (2, 0):
instance_class = MiniCPMV2
elif version == (2, 5):
instance_class = MiniCPMV2_5
else:
instance_class = MiniCPMVQwen2
return instance_class(config, multimodal_config, cache_config,
quant_config)
...@@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask): ...@@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask):
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad( cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return ( return (
indices, indices,
cu_seqlens, cu_seqlens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment