"tests/vscode:/vscode.git/clone" did not exist on "b6c16cf8ff8d558ec943f1f17342c2c081f3f5af"
Unverified Commit 8e60afa1 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)


Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent b6d73925
...@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_positions, self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim) self.embed_dim)
def forward( def forward(self,
self,
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor, patch_attention_mask: torch.BoolTensor,
) -> torch.Tensor: tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values) patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2) embeddings = patch_embeds.flatten(2).transpose(1, 2)
...@@ -84,6 +83,11 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -84,6 +83,11 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value=0) fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask): for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum() nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
...@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self, self,
pixel_values, pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None, patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor: tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings( hidden_states = self.embeddings(
pixel_values=pixel_values, pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask) patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states) encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state return last_hidden_state
...@@ -31,17 +31,15 @@ import torch ...@@ -31,17 +31,15 @@ import torch
import torch.types import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed) get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class BaseResampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
) -> None:
super().__init__()
self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: (
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Resampler2_5(BaseResampler): class Resampler2_5(BaseResampler):
def __init__( def __init__(
...@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -869,7 +815,35 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return "resampler" in name return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel): class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
...@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -894,15 +868,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
name="model") name="model")
def init_vision_module(self) -> nn.Module: def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
from vllm.model_executor.models.na_vit import SiglipVisionTransformer
if self.config._attn_implementation == "flash_attention_2": model = Idefics2VisionTransformer(self.config.vision_config)
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not support sdpa
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
return model return model
...@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -928,7 +895,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
pixel_values, pixel_values,
patch_attention_mask=patch_attn_mask, patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
).last_hidden_state )
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_vision_hidden_states(
...@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -960,12 +927,12 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
all_pixel_values.type(dtype), all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask, patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
).last_hidden_state )
return self.resampler(vision_embedding, tgt_sizes) return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool: def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name return "resampler" in name
_SUPPORT_VERSION = { _SUPPORT_VERSION = {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment