"vscode:/vscode.git/clone" did not exist on "08275ec0a2ba48fbc1054bdbdda2f1e0dfcb20b3"
Unverified commit 8e60afa1, authored by Jee Jee Li, committed by GitHub
Browse files

[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)


Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b6d73925
......@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
def forward(
self,
def forward(self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
) -> torch.Tensor:
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
......@@ -84,6 +83,11 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
......@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor:
tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask)
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
......@@ -31,17 +31,15 @@ import torch
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2,
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
......@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross-attention layer, using
    (grid_size**2) learnable queries and a 2D sincos positional embedding.

    This base class only builds the shared parameters and layers; it does
    not define ``forward`` itself.

    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
    ) -> None:
        """Build the learnable queries, attention and normalization layers.

        Args:
            num_queries: Number of learnable query vectors.
            embed_dim: Width of the queries and of the attention layer.
            num_heads: Number of heads passed to ``nn.MultiheadAttention``.
            kv_dim: Width of the incoming key/value features. A linear
                projection to ``embed_dim`` is created only when this is
                given and differs from ``embed_dim``.
            norm_layer: Factory called with ``embed_dim`` to build each of
                the three normalization layers.
        """
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Queries start at zero and are then filled in place.
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
        else:
            # Maintain the same return value with ReplicatedLinear.forward.
            # The identity module is built once instead of on every call.
            identity = nn.Identity()
            self.kv_proj = lambda *args, **kwargs: (
                identity(*args, **kwargs),
                None,
            )

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize a single submodule in place.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and a zero
        bias when a bias exists; ``nn.LayerNorm`` gets zero bias and unit
        weight. Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            # The enclosing branch already guarantees m is an nn.Linear.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        """Insert a new dimension at index 1 and tile ``query`` N times along it."""
        return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2_5(BaseResampler):
def __init__(
......@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel):
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
......@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
name="model")
def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
from vllm.model_executor.models.na_vit import SiglipVisionTransformer
if self.config._attn_implementation == "flash_attention_2":
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not support sdpa
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
......@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
)
return vision_embedding
def get_vision_hidden_states(
......@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
)
return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
return "resampler" in name
_SUPPORT_VERSION = {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment