Commit 1d9ad5d4 authored by chenzk

v1.0

{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 77,
"width": 2304,
"head_width": 144,
"mlp_ratio": 10.9722,
"patch_size": 14,
"eva_model_name": "eva-clip-10b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": false,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}
{
"embed_dim": 1024,
"vision_cfg": {
"drop_path_rate": 0,
"eva_model_name": "eva-clip-E-14-plus",
"head_width": 112,
"image_size": 448,
"layers": 64,
"mlp_ratio": 8.571428571428571,
"patch_size": 14,
"postnorm": true,
"qkv_bias": true,
"width": 1792,
"xattn": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}
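# Illustrative sketch (not part of the original commit): the two JSON files above are
# EVA-CLIP vision/text tower configs. The derived sizes below mirror how the code later
# in this commit consumes the vision_cfg fields (head_dim = dim // num_heads in Attention,
# so num_heads is typically width // head_width, and mlp_hidden_dim = int(dim * mlp_ratio)
# in Block).
def _derived_vision_sizes(vision_cfg: dict) -> dict:
    width = vision_cfg["width"]
    return {
        "num_heads": width // vision_cfg["head_width"],
        "mlp_hidden_dim": int(width * vision_cfg["mlp_ratio"]),
        "num_patches": (vision_cfg["image_size"] // vision_cfg["patch_size"]) ** 2,
    }

# For eva-clip-E-14-plus: 1792 // 112 = 16 heads, int(1792 * 8.5714...) = 15360-wide MLP,
# and (448 // 14) ** 2 = 1024 patch tokens.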
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
try:
from timm.models.layers import drop_path, to_2tuple
except ImportError:
from timm.layers import drop_path, to_2tuple
try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
print(f"os.getenv('RoPE')={os.getenv('RoPE')}")
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
if self.training and os.getenv('RoPE') == '1':
return x, patch_indices_keep
return x
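# Illustrative usage (not part of the original commit): in training mode PatchDropout keeps
# the CLS token plus a random subset of patch tokens; in eval mode it is a no-op. Assumes
# the RoPE environment variable is not set to "1" (otherwise a tuple is returned).
def _patch_dropout_example():
    drop = PatchDropout(prob=0.5, exclude_first_token=True)
    drop.train()
    tokens = torch.randn(2, 1 + 256, 768)  # [batch, 1 + num_patches, dim]
    kept = drop(tokens)                    # -> [2, 1 + 128, 768]; CLS is always kept
    return kept.shape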
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
drop=0.,
subln=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
        # x = self.drop(x)
        # uncomment the dropout above to match the original BERT implementation
x = self.ffn_ln(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SwiGLU(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
norm_layer=nn.LayerNorm, subln=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.w3 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x1 = self.w1(x)
x2 = self.w2(x)
hidden = self.act(x1) * x2
x = self.ffn_ln(hidden)
x = self.w3(x)
x = self.drop(x)
return x
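# Illustrative shape check (not part of the original commit): SwiGLU uses two input
# projections with a gated activation, SiLU(w1(x)) * w2(x), before the output projection w3.
def _swiglu_example():
    ffn = SwiGLU(in_features=64, hidden_features=128)
    out = ffn(torch.randn(4, 10, 64))  # -> [4, 10, 64]
    return out.shape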
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.subln = subln
if self.subln:
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
else:
if qkv_bias:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
else:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
# if qkv_bias:
# self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
# self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
# else:
# self.q_bias = None
# self.v_bias = None
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
# self.proj = nn.Linear(all_head_dim, all_head_dim)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
self.rope = rope
def forward(self, x, rel_pos_bias=None, attn_mask=None):
B, N, C = x.shape
if self.subln:
            # NOTE: the separate q_bias / v_bias parameters from the original EVA
            # implementation are not defined in this variant, so apply the bias-free
            # projection layers directly.
            q = self.q_proj(x)
            k = self.k_proj(x)
            v = self.v_proj(x)
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
else:
# qkv_bias = None
# if self.q_bias is not None:
# qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
q, k, v = qkv[0], qkv[1], qkv[2]
if self.rope:
q_t = q[:, :, 1:, :]
ro_q_t = self.rope(q_t)
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
k_t = k[:, :, 1:, :]
ro_k_t = self.rope(k_t)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
if self.xattn:
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = xops.memory_efficient_attention(
q, k, v,
p=self.xattn_drop,
scale=self.scale,
)
x = x.reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias.type_as(attn)
if attn_mask is not None:
attn_mask = attn_mask.bool()
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
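# Illustrative shape check (not part of the original commit): the plain (non-xformers)
# attention path with fused QKV bias and no relative position bias.
def _attention_example():
    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    out = attn(torch.randn(2, 197, 768))  # -> [2, 197, 768]
    return out.shape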
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
subln=False, naiveswiglu=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
if naiveswiglu:
self.mlp = SwiGLU(
in_features=dim,
hidden_features=mlp_hidden_dim,
subln=subln,
norm_layer=norm_layer,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
subln=subln,
drop=drop
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
self.postnorm = postnorm
def forward(self, x, rel_pos_bias=None, attn_mask=None):
if self.gamma_1 is None:
if self.postnorm:
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
if self.postnorm:
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
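# Illustrative shape check (not part of the original commit): a 224x224 image with
# patch_size=14 yields (224 // 14) ** 2 = 256 patch tokens.
def _patch_embed_example():
    embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=768)
    out = embed(torch.randn(1, 3, 224, 224))  # -> [1, 256, 768]
    return out.shape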
class EVAVisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False,
):
super().__init__()
self.image_size = img_size
# self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
self.rel_pos_bias = None
self.rope = None
self.naiveswiglu = naiveswiglu
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
for i in range(depth)])
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.grad_checkpointing = grad_checkpointing
def get_num_layers(self):
return len(self.blocks)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
if os.getenv('RoPE') == '1':
if self.training and not isinstance(self.patch_dropout, nn.Identity):
x, patch_indices_keep = self.patch_dropout(x)
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
else:
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
x = self.patch_dropout(x)
else:
x = self.patch_dropout(x)
rel_pos_bias = None
for blk in self.blocks:
if self.grad_checkpointing:
                x = checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias=rel_pos_bias)
return x
def forward(self, x):
"""
:return:
forward_features function returns raw features of ViT,
forward with return_all_features returns normalized features of ViT
:param x:
:param return_all_features:
"""
features = self.forward_features(x) # [B, n_patch, C]
return features
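# Illustrative sketch (an assumption, not the original EVA-CLIP factory code): how the
# "vision_cfg" block of the EVA-CLIP-E-14-plus JSON at the top of this commit could map
# onto the constructor above. Calling this allocates a multi-billion-parameter module,
# and xattn=True additionally requires xformers at forward time.
def _build_eva_e_plus_vision_tower():
    return EVAVisionTransformer(
        img_size=448,
        patch_size=14,
        embed_dim=1792,
        depth=64,
        num_heads=1792 // 112,  # width // head_width = 16
        mlp_ratio=8.571428571428571,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        xattn=True,
        postnorm=True,
    )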
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
try:
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.data import load_and_transform_audio_data
except ImportError:
pass
class ImageBindWrapper(nn.Module):
def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = select_layer
self.select_feature = select_feature
if not delay_load:
self.load_model()
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
for p in self.vision_tower.parameters():
p.requires_grad = False
self.vision_tower.eval()
self.is_loaded = True
def train(self, mode=True):
self.training = mode
if self.is_loaded:
self.vision_tower.eval()
@torch.no_grad()
def forward(self, x):
if type(x) == dict:
if x["audios"] is not None:
inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
embeddings = self.vision_tower(inputs)
audio_embedding = embeddings[ModalityType.AUDIO]
return audio_embedding.unsqueeze(1)
else:
inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
embeddings = self.vision_tower(inputs)
vision_embedding = embeddings[ModalityType.VISION]
if vision_embedding.ndim == 2:
return vision_embedding.unsqueeze(1)
if vision_embedding.shape[1] == 257:
return vision_embedding[:, 1:]
raise ValueError(f"Unexpected shape: {vision_embedding.shape}")
@property
def dummy_feature(self):
return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.modality_preprocessors.vision.cls_token.dtype
@property
def device(self):
return self.vision_tower.modality_preprocessors.vision.cls_token.device
@property
def hidden_size(self):
return 1024
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
try:
import open_clip
import torchvision
from open_clip.transformer import _expand_token
except ImportError:
print("OpenCLIP not installed")
open_clip = None
HIDDEN_SIZE_DICT = {
"ViT-H-14-378-quickgelu": 1280,
}
class OpenCLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.model_name = vision_tower.replace("open_clip_hub:", "")
self.pretrained = args.vision_tower_pretrained
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
if not delay_load:
print(f"Loading vision tower: {vision_tower}")
self.load_model()
elif getattr(args, "unfreeze_mm_vision_tower", False):
# TODO: better detector is needed.
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
self.load_model()
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
self.load_model()
def load_model(self, device_map="auto"):
print(f"Loading OpenCLIP model: {self.model_name}")
print(f"Pretrained: {self.pretrained}")
vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda")
resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
self.resize_transform_size = resize_transform.size # 224 or 384
self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16
self.image_processor = CLIPImageProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
crop_size=resize_transform.size,
size={"shortest_edge": resize_transform.size},
image_mean=list(normalize_transform.mean),
image_std=list(normalize_transform.std),
)
print(f"Loaded image processor: {self.image_processor}")
self.vision_tower = vision_tower.visual
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
elif self.select_feature == "conv_flatten":
image_features = image_features.flatten(2).transpose(1, 2)
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def forward_visual(self, x, output_hidden_states=False):
if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"):
return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer))
else:
def forward_openclip(self, x: torch.Tensor):
features = []
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
dim=1,
)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
for r in self.transformer.resblocks:
x = r(x, attn_mask=None)
features.append(x)
return features
return forward_openclip(self.vision_tower, x)
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
if hasattr(self.vision_tower, "conv1"):
return self.vision_tower.conv1.weight.dtype
if hasattr(self.vision_tower, "trunk"):
return self.vision_tower.trunk.patch_embed.proj.weight.dtype
raise NotImplementedError
@property
def device(self):
if hasattr(self.vision_tower, "conv1"):
return self.vision_tower.conv1.weight.device
if hasattr(self.vision_tower, "trunk"):
return self.vision_tower.trunk.patch_embed.proj.weight.device
raise NotImplementedError
@property
def config(self):
return None
@property
def hidden_size(self):
if self.model_name in HIDDEN_SIZE_DICT:
return HIDDEN_SIZE_DICT[self.model_name]
else:
raise NotImplementedError
@property
def num_patches(self):
image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0]
_num_patches = (image_size // self.patch_size) ** 2
if "cls_patch" in self.select_feature:
_num_patches += 1
return _num_patches
@property
def image_size(self):
return self.resize_transform_size
@property
def num_patches_per_side(self):
return self.resize_transform_size // self.patch_size
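# Illustrative usage sketch (assumptions: open_clip is installed, a CUDA device is
# available, and the "dfn5b" pretrained tag used here resolves for this model name;
# both are illustrative). With the default "patch" selection, a 378x378 ViT-H/14 input
# yields (378 // 14) ** 2 = 729 patch features of width 1280 (see HIDDEN_SIZE_DICT above).
def _open_clip_tower_example():
    from types import SimpleNamespace
    args = SimpleNamespace(
        vision_tower_pretrained="dfn5b",
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    return OpenCLIPVisionTower("open_clip_hub:ViT-H-14-378-quickgelu", args)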
from typing import Optional, Tuple, Union, Dict
from dataclasses import dataclass
from functools import partial, reduce
from PIL import Image
import torch
import torch.utils.checkpoint
from torch import nn
import os
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig
from transformers.utils import ModelOutput
import numpy as np
class SigLipImageProcessor:
def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.image_mean = image_mean
self.image_std = image_std
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.data_format = data_format
self.crop_size = crop_size
def preprocess(self, images, return_tensors):
if isinstance(images, Image.Image):
images = [images]
else:
# to adapt video data
images = [to_numpy_array(image) for image in images]
assert isinstance(images, list)
try:
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
data = {"pixel_values": images}
except ValueError as e:
try:
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean[0], std=self.image_std[0], data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
processed_images = [np.repeat(img, repeats=3, axis=0) for img in images]
data = {"pixel_values": processed_images}
            except ValueError as e:
                print(f"Grayscale processing failed: {e}")
                raise
return BatchFeature(data=data, tensor_type=return_tensors)
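# Illustrative usage (not part of the original commit): the processor converts to RGB,
# resizes to 384x384, rescales to [0, 1], and normalizes with mean/std 0.5, returning
# channels-first pixel values.
def _siglip_processor_example():
    processor = SigLipImageProcessor()
    image = Image.new("RGB", (640, 480))
    batch = processor.preprocess(image, return_tensors="pt")
    return batch["pixel_values"].shape  # -> torch.Size([1, 3, 384, 384])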
class SigLipVisionConfig(PretrainedConfig):
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=1152,
image_mean=(0.5, 0.5, 0.5),
intermediate_size=4304,
num_hidden_layers=27,
num_attention_heads=16,
num_channels=3,
image_size=384,
patch_size=14,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.image_mean = image_mean
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from SigLipConfig
if config_dict.get("model_type") == "siglip":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
return cls.from_dict(config_dict, **kwargs)
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
class SigLipVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class SigLipVisionEmbeddings(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class SigLipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
class SigLipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
class SigLipEncoderLayer(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SigLipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SigLipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class SigLipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SigLipVisionConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
pass
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
class SigLipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SigLipEncoderLayer`].
Args:
config: SigLipVisionConfig
"""
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
class SigLipVisionTransformer(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SigLipVisionEmbeddings(config)
self.encoder = SigLipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = SigLipMultiheadAttentionPoolingHead(config)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(last_hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class SigLipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SigLipMLP(config)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
class SigLipVisionModel(SigLipPreTrainedModel):
config_class = SigLipVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["SigLipEncoderLayer"]
def __init__(self, config: SigLipVisionConfig):
super().__init__(config)
self.vision_model = SigLipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, SigLipVisionModel
>>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class SigLipVisionTower(nn.Module):
def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
super().__init__()
self.is_loaded = False
self.config = SigLipVisionConfig()
self.vision_tower_name = vision_tower
self.image_processor = SigLipImageProcessor()
if not delay_load:
# rank0_print(f"Loading vision tower: {vision_tower}")
self.load_model()
elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
# TODO: better detector is needed.
# rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
self.load_model()
elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
# rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
self.load_model()
else:
self.cfg_only = self.config
def load_model(self, device_map=None):
if self.is_loaded:
# rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
return
self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
# del self.vision_tower.vision_model.encoder.layers[-1:]
self.vision_tower.vision_model.head = nn.Identity()
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
# image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
image_feature = image_forward_out.last_hidden_state.to(image.dtype)
                assert image_feature.shape[-2] == 729
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
# image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
image_features = image_forward_outs.last_hidden_state.to(images.dtype)
assert image_features.shape[-2] == 729
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
for p in self.vision_tower.parameters():
return p.dtype
@property
def device(self):
for p in self.vision_tower.parameters():
return p.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
# return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
@property
def image_size(self):
return self.config.image_size
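# Illustrative usage (assumption: the checkpoint name below matches the default
# SigLipVisionConfig, i.e. a so400m/14 model at 384x384 resolution). With delay_load=True
# no weights are fetched, and num_patches is derived purely from the config:
# (384 // 14) ** 2 = 729, matching the asserts in forward() above.
def _siglip_tower_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(unfreeze_mm_vision_tower=False)
    tower = SigLipVisionTower("google/siglip-so400m-patch14-384", cfg, delay_load=True)
    return tower.num_patches  # -> 729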
import torch
import torch.nn as nn
import re
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
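# Illustrative usage (not part of the original commit; the config object below is a
# stand-in for the real model config): "mlp2x_gelu" builds Linear -> GELU -> Linear,
# projecting the vision hidden size onto the language model hidden size.
def _projector_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=3584)
    return build_vision_projector(cfg)  # Sequential(Linear(1152, 3584), GELU(), Linear(3584, 3584))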
def build_gen_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'gen_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.gen_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.gen_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
def build_down_projector(config, delay_load=False, **kwargs):
return IdentityMap()
from typing import Optional
import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from transformers import PretrainedConfig, PreTrainedModel
from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel
class NextDiTCrossAttnConfig(PretrainedConfig):
model_type = "nextdit-crossattn"
def __init__(
self,
input_size: int = 8,
patch_size: int = 1,
in_channels: int = 1792,
dim: int = 1792,
n_layers: int = 24,
n_heads: int = 28,
n_kv_heads: int = 28,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
latent_embedding_size: int = 3584,
learn_sigma: bool = False,
qk_norm: bool = True,
_gradient_checkpointing: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.input_size = input_size
self.patch_size = patch_size
self.in_channels = in_channels
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.multiple_of = multiple_of
self.ffn_dim_multiplier = ffn_dim_multiplier
self.norm_eps = norm_eps
self.learn_sigma = learn_sigma
self.qk_norm = qk_norm
self.latent_embedding_size = latent_embedding_size
self._gradient_checkpointing = _gradient_checkpointing
class NextDiTCrossAttn(PreTrainedModel):
config_class = NextDiTCrossAttnConfig
def __init__(
self,
config: NextDiTCrossAttnConfig,
) -> None:
super().__init__(config)
assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn"
self._gradient_checkpointing = config._gradient_checkpointing
self.model = LuminaNextDiT2DModel(
sample_size=config.input_size,
patch_size=config.patch_size,
in_channels=config.in_channels,
hidden_size=config.dim,
num_layers=config.n_layers,
num_attention_heads=config.n_heads,
num_kv_heads=config.n_kv_heads,
multiple_of=config.multiple_of,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.norm_eps,
learn_sigma=config.learn_sigma,
qk_norm=config.qk_norm,
cross_attention_dim=config.latent_embedding_size,
)
if self._gradient_checkpointing:
self.model.enable_gradient_checkpointing()
# self.model.requires_grad_(False)
self.freqs_cis = get_2d_rotary_pos_embed_lumina(
config.dim // config.n_heads,
384,
384,
)
def forward(self, x, timestep, z_latents, **kwargs):
model_pred = self.model(
hidden_states=x,
timestep=timestep,
encoder_hidden_states=z_latents,
encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device),
image_rotary_emb=self.freqs_cis,
cross_attention_kwargs=dict(),
).sample
return model_pred
from transformers import AutoConfig
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'blip3o' in config and 'blip3o' not in cfg.model_type:
assert cfg.model_type == 'llama'
print("You are using newer blip3o code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "blip3o")
cfg.architectures[0] = 'blip3oLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
import os
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Sampler
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
logger,
)
from typing import List, Optional
from packaging import version
from transformers.utils import is_torch_xla_available

# Minimum torch_xla version that ships FSDPv2 (mirrors the constant defined in
# transformers.trainer).
XLA_FSDPV2_MIN_VERSION = "2.2.0"
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla import __version__ as XLA_VERSION
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, "no ignore status")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
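# Illustrative sketch (not part of the original commit): indices are shuffled, grouped
# into megabatches of world_size * batch_size, sorted by length within each megabatch,
# and split into per-rank chunks of roughly equal total length.
def _length_grouping_example():
    lengths = [5, 1, 9, 3, 7, 2, 8, 4]
    generator = torch.Generator().manual_seed(0)
    return get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=generator)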
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
class blip3oTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
self.args.train_batch_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
# def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
# if not hasattr(self, "largest_loss"):
# self.largest_loss = tr_loss.item()
# self.largest_grad_norm = grad_norm
# self.latest_grad_norm = grad_norm
# else:
# if tr_loss.item() > 10 * self.largest_loss:
# print(f"Loss Spiked: {tr_loss.item()} -> {self.largest_loss}")
# self.control.should_training_stop = True
# if grad_norm > 10 * self.latest_grad_norm and grad_norm > 3:
# print(f"Grad Norm Spiked: {grad_norm} -> {self.latest_grad_norm}")
# self.control.should_training_stop = True
# self.largest_loss = max(tr_loss.item(), self.largest_loss)
# self.largest_grad_norm = max(grad_norm, self.largest_grad_norm)
# self.latest_grad_norm = grad_norm
# if np.isnan(grad_norm) or grad_norm > 1e6:
# print(f"NaN grad norm detected in process {self.args.process_index} on {os.uname().nodename}")
# self.control.should_training_stop = True
# print(f"Shut Down Training")
# if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
# if is_torch_xla_available():
# xm.mark_step()
# logs: Dict[str, float] = {}
# # all_gather + mean() to get average loss over all processes
# tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# # reset tr_loss to zero
# tr_loss -= tr_loss
# logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
# if grad_norm is not None:
# logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
# logs["learning_rate"] = self._get_learning_rate()
# self._total_loss_scalar += tr_loss_scalar
# self._globalstep_last_logged = self.state.global_step
# self.store_flos()
# self.log(logs, start_time)
# metrics = None
# if self.control.should_evaluate:
# metrics = self._evaluate(trial, ignore_keys_for_eval)
# is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
# if self.args.save_strategy == SaveStrategy.BEST:
# self.control.should_save = is_new_best_metric
# if self.control.should_save:
# self._save_checkpoint(model, trial)
# self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
if self.args.mm_projector_lr is not None:
projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
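# Hedged alternative (not used in this repo): the HF Trainer also accepts a pre-built
# (optimizer, scheduler) pair through its `optimizers` argument, which bypasses the
# parameter grouping above. `model`, `training_args`, and `data_module` are assumed to be
# built the same way as in the training entry point.
def _example_external_optimizer(model, training_args, data_module):
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # constant LR for the sketch
    return blip3oTrainer(model=model, args=training_args, optimizers=(optimizer, scheduler), **data_module)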
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
) # shape: (b, num_heads, s, head_dim)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
# reuse k, v
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
max_s = q_len
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = output.view(bsz, q_len, -1)
else:
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
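# Usage sketch (the checkpoint name is a placeholder): apply the monkey patch before the
# model is built, which is the usual pattern in these training scripts, so every
# LlamaAttention instance uses the patched forward.
def _example_flash_attn_patch(model_name_or_path: str = "meta-llama/Llama-2-7b-hf"):
    replace_llama_attn_with_flash_attn()
    return transformers.LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)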
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
import transformers.models.llama.modeling_llama
from torch import nn
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
def replace_llama_attn_with_xformers_attn():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
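# Usage sketch mirroring the flash-attention patch above (the checkpoint name is a placeholder):
def _example_xformers_patch(model_name_or_path: str = "meta-llama/Llama-2-7b-hf"):
    replace_llama_attn_with_xformers_attn()
    return transformers.LlamaForCausalLM.from_pretrained(model_name_or_path)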
import os
import io
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import time
import torch, gc
import glob
import transformers
import tokenizers
import random
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX
from torch.utils.data import Dataset
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o import conversation as conversation_lib
from blip3o.model import *
from blip3o.mm_utils import tokenizer_image_token
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
ImageFile.LOAD_TRUNCATED_IMAGES = True
transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])
set_verbosity_info()
tf_logging.set_verbosity_info()
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
from packaging import version
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=True)
tune_mm_mlp_adapter: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
gen_vision_tower: Optional[str] = field(default=None)
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
vision_tower_pretrained: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default="linear")
gen_projector_type: Optional[str] = field(default="linear")
mm_use_im_start_end: bool = field(default=False)
mm_use_im_patch_token: bool = field(default=True)
mm_patch_merge_type: Optional[str] = field(default="flat")
mm_vision_select_feature: Optional[str] = field(default="patch")
n_query: Optional[int] = field(default=729) # clip 576, siglip 729
n_und_query: Optional[int] = field(default=729) # clip 576, siglip 729
gen_pooling: Optional[str] = field(default="all") # options are: pool2d_3, pool2d_9, seq_3, seq_9, seq_27
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
shortcaption_image_folder: Optional[str] = field(default=None)
data_type: Optional[str] = field(default="mix")
image_aspect_ratio: str = "square"
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=512,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."},
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
)
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
mm_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_vision_tower_state_maybe_zero_3(named_params, keys_to_match=[""]):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
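# Sketch of how the collected module names are typically consumed (assumes `peft` is
# installed; the hyperparameters mirror the LoRA defaults in TrainingArguments above):
def _example_lora_config(model):
    from peft import LoraConfig
    return LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=find_all_linear_names(model),
    )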
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
"""Collects the state dict and dump to disk."""
# if getattr(trainer.args, "tune_vision_model", False):
if trainer.deepspeed:
torch.cuda.synchronize()
# Only save Adapter
keys_to_match = ["mm_projector"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save,
os.path.join(mm_projector_folder, f"{current_folder}.bin"),
)
else:
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
keys_to_match = ["gen_projector"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "gen_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save,
os.path.join(mm_projector_folder, f"{current_folder}.bin"),
)
else:
torch.save(weight_to_save, os.path.join(output_dir, f"gen_projector.bin"))
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
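# Small sanity sketch (assumes an HF tokenizer/model pair where the pad token is not yet
# defined): after the resize, each newly added input-embedding row equals the mean of the
# pre-existing rows, up to floating-point tolerance. Note that only the input embeddings
# are re-initialized above; the output embeddings are left untouched.
def _example_resize_check(tokenizer, model):
    smart_tokenizer_and_embedding_resize(dict(pad_token="<pad>"), tokenizer, model)
    emb = model.get_input_embeddings().weight.data
    return torch.allclose(emb[-1], emb[:-1].mean(dim=0), atol=1e-4)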
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
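# Illustrative call (the two-turn conversation is made up): every turn is wrapped as
# "### <role>: <text>\n" and a trailing "### " is appended, ready for the next turn.
def _example_add_speaker_and_signal():
    source = [
        {"from": "human", "value": "What is in the image?"},
        {"from": "gpt", "value": "A cat sitting on a chair."},
    ]
    return _add_speaker_and_signal("SYSTEM PROMPT\n\n", copy.deepcopy(source))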
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
gen_placeholder = ""
# "[IMG]" + "<image>" * data_args.n_query + "[/IMG]"
inst_type = None
for source in sources: # [instance]
for sentence in source:
if sentence["from"] == "human" and "<image>" in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
inst_type = "und"
elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
inst_type = "gen"
return sources, inst_type
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
roles = {"human": "user", "gpt": "assistant"}
tokenizer = copy.deepcopy(tokenizer)
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = chat_template
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
# New version, use apply chat template
# Build system message for each sentence
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
try:
role = conv["role"]
content = conv["content"]
except:
role = conv["from"]
content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
encode_id = tokenizer.apply_chat_template(conv)
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
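# A minimal sketch (the tokenizer checkpoint is an assumption; any tokenizer exposing the
# Qwen chat special tokens should behave the same): user/system spans come back as
# IGNORE_INDEX in the labels, while assistant spans keep their token ids.
def _example_preprocess_qwen(tokenizer_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
    tok = AutoProcessor.from_pretrained(tokenizer_name).tokenizer
    sources = [[
        {"from": "human", "value": "Describe the sky."},
        {"from": "gpt", "value": "It is blue."},
    ]]
    return preprocess_qwen(sources, tok)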
def preprocess_llama3(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
max_len=2048,
system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
# roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
roles = {"human": "user", "gpt": "assistant"}
# Add image tokens to the tokenizer as special tokens
# Use a deepcopy of the tokenizer so that we don't modify the original one
tokenizer = copy.deepcopy(tokenizer)
# When there is actually an image, we add the image tokens as a special token
if has_image:
tokenizer.add_tokens(["<image>"], special_tokens=True)
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
# After the update, the llama3 tokenizer automatically prepends the bos id
# when called, so we strip it back off below. ヽ(`⌒´)ノ
def safe_tokenizer_llama3(text):
input_ids = tokenizer(text).input_ids
if input_ids[0] == bos_token_id:
input_ids = input_ids[1:]
return input_ids
nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
# New version, use apply chat template
# Build system message for each sentence
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
try:
role = conv["role"]
content = conv["content"]
except:
role = conv["from"]
content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
# The first token is the bos token, which we don't need here
encode_id = tokenizer.apply_chat_template(conv)[1:]
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == image_token_index:
input_id[idx] = IMAGE_TOKEN_IDX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
# assert DEFAULT_IMAGE_TOKEN in source[0]['value'] or DEFAULT_IMAGE_TOKEN in source[1]['value']
conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
) -> Dict:
"""
Given a list of sources, each of which is a conversation list, this transform:
1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
2. Concatenates the conversations together;
3. Tokenizes the concatenated conversation;
4. Makes a deepcopy as the target and masks the human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version == "llama3":
return preprocess_llama3(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "qwen":
return preprocess_qwen(sources, tokenizer, has_image=has_image)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
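# Quick masking check (the tokenizer is assumed; the toy conversation is made up): the
# returned labels match input_ids except that prompt-side tokens are replaced with
# IGNORE_INDEX, so only the response contributes to the loss.
def _example_preprocess_masking(tokenizer):
    sources = [[
        {"from": "human", "value": "Hello"},
        {"from": "gpt", "value": "Hi there"},
    ]]
    out = preprocess(copy.deepcopy(sources), tokenizer, has_image=False)
    num_masked = int((out["labels"][0] == IGNORE_INDEX).sum())
    return out, num_masked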
class LazySupervisedMixDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments,
):
super(LazySupervisedMixDataset, self).__init__()
self.data_args = data_args
list_data_dict = []
###################################### text to image #######################################
data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
## text to image
train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=128)
train_dataset = train_dataset.rename_column("jpg", "image")
train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in ["image", "txt", "type", "image_path"]])
print(f"finished loading {len(train_dataset)} image-text pairs")
list_data_dict.append(train_dataset)
if len(list_data_dict) > 1:
list_data_dict = concatenate_datasets(list_data_dict)
else:
list_data_dict = list_data_dict[0]
list_data_dict = list_data_dict.shuffle(seed=42)
rank0_print(f"Totoal number of training instance: {len(list_data_dict)}")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
cur_len = cur_len if "image" in sample else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
while True:
sources = self.list_data_dict[i]
if sources["type"] == "T2I" or sources["type"] == "journeyDB_T2I":
sources["conversations"] = [
{"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
{"from": "gpt", "value": "<image>"},
]
elif sources["type"] == "I2I" or sources["type"] == "journeyDB_I2I":
sources["conversations"] = [
{
"from": "human",
"value": f"<image>\nPlease reconstruct the given image.",
},
{"from": "gpt", "value": ""},
]
else:
raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
if "image" in sources:
def img_process(images, processor, image_aspect_ratio):
if image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
images = [expand2square(img, tuple(int(x * 255) for x in processor.image_mean)) for img in images]
images = processor.preprocess(images, return_tensors="pt")["pixel_values"]
else:
images = processor.preprocess(images, return_tensors="pt")["pixel_values"]
return images
if sources["type"] == "T2I" or sources["type"] == "I2I":
image_files = self.list_data_dict[i]["image"]
else:
image_files = self.list_data_dict[i]["image_path"]
if not isinstance(image_files, list):
image_files = [image_files]
images = []
def read_bin_as_bytesio(bin_file_path):
with open(bin_file_path, "rb") as f:
return io.BytesIO(f.read())
for img in image_files:
try:
if sources["type"] == "T2I" or sources["type"] == "I2I":
img = img.convert("RGB")
elif sources["type"] == "journeyDB_T2I" or sources["type"] == "journeyDB_I2I":
if sources["type"] == "journeyDB_T2I" or sources["type"] == "journeyDB_I2I":
image_path = os.path.join('/fsx/sfr/data/jiuhai/hub/datasets--JourneyDB--JourneyDB/snapshots/e191aa61ca37e5e4418707ade4df5deb5c6d5d8f/data/train/imgs', img)
else:
raise ValueError("Unknown source type. Please check the 'type' in 'sources'.")
img = Image.open(image_path).convert("RGB")
images.append(img)
except Exception as e:
print(f"Error opening image {img}: {e}")
images = None
break # Abort this sample; a new index is drawn below when no valid image was loaded
if images is not None:
try:
temp = img_process(
images,
self.data_args.gen_image_processor,
self.data_args.image_aspect_ratio,
)
except Exception as e:
print(f"Error wrong number of channels: {e}")
images = None
# If no valid images were found, randomly pick another item
if images is None:
print(sources)
print(f"warning false image!!!!!!")
i = random.randint(0, len(self.list_data_dict) - 1)
continue
sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
else:
sources = copy.deepcopy([sources["conversations"]])
data_dict = preprocess(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# image exists in the data
if "image" in self.list_data_dict[i]:
if inst_type == "gen":
data_dict["gen_image"] = img_process(
images,
self.data_args.gen_image_processor,
self.data_args.image_aspect_ratio,
)
elif inst_type == "und":
resized_images = [transform_und_images(img) for img in images]
image_inputs = self.data_args.image_processor(resized_images, return_tensors="pt")
data_dict["und_image"] = image_inputs.pixel_values
data_dict["grid_thw"] = image_inputs.image_grid_thw
data_dict["gen_image"] = img_process(
resized_images,
self.data_args.gen_image_processor,
self.data_args.image_aspect_ratio,
)
elif self.data_args.is_multimodal:
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
multi_input_ids = []
multi_labels = []
i_s_pos = []
for input_id, label in zip(input_ids, labels):
input_id = input_id[: self.tokenizer.model_max_length - 65]
label = label[: self.tokenizer.model_max_length - 65]
i_s_pos.append(input_id.shape[0]+1)
img_id = torch.full((65,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
img_id[0] = 151665
input_id = torch.cat([input_id, img_id])
img_label = torch.full((65,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
img_label[0] = 151665
label = torch.cat([label, img_label])
multi_input_ids.append(input_id)
multi_labels.append(label)
input_ids = multi_input_ids
labels = multi_labels
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
if input_ids.shape[1] > self.tokenizer.model_max_length:
print(f"Warning input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
batch_gen_images = []
batch_und_images = []
batch_grid_thw = []
for instance in instances:
if "gen_image" in instance:
batch_gen_images.append(instance["gen_image"])
if len(batch_gen_images) > 0:
if all(x is not None and y.shape == batch_gen_images[0][0].shape for x in batch_gen_images for y in x):
batch["gen_image"] = torch.cat([images for images in batch_gen_images], dim=0)
else:
batch["gen_image"] = batch_gen_images
else:
batch["gen_image"] = None
for instance in instances:
if "und_image" in instance:
batch_und_images.append(instance["und_image"].unsqueeze(0)) ## 1*1024*1176
batch_grid_thw.append(instance["grid_thw"]) ## 1*3
# print(f"batch_und_images {batch_und_images}")
if len(batch_und_images) > 0:
batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0)
batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0)
else:
batch["und_image"] = None
batch["grid_thw"] = None
batch["ids"] = ids
batch["i_s_pos"] = i_s_pos
return batch
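# Usage sketch (the batch size is arbitrary): the collator pads input_ids/labels, appends a
# 65-token image slot to every example, and stacks any gen/und image tensors it finds.
def _example_collate(tokenizer, dataset):
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collator)
    return next(iter(loader))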
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
if data_args.data_type == "mix":
train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
else:
raise ValueError("Unknown data type. Please check the Dataloader type.")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def unlock_vit(training_args, model_args, vision_tower):
for n, p in vision_tower.named_parameters():
p.requires_grad = True
def train(attn_implementation=None):
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
print(model_args, data_args, training_args)
local_rank = training_args.local_rank
compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(
dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_skip_modules=["mm_projector"],
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
),
)
)
if model_args.vision_tower is not None:
model = blip3oLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args,
)
else:
if "Qwen" in model_args.model_name_or_path or "qwen" in model_args.model_name_or_path :
model = blip3oQwenForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args,
)
else:
model = transformers.LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args,
)
model.config.use_cache = False
if model_args.freeze_backbone:
for (n, p) in model.get_model().named_parameters():
p.requires_grad = False
for (n, p) in model.visual.named_parameters():
p.requires_grad = False
for (n, p) in model.lm_head.named_parameters():
p.requires_grad = False
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if "Qwen" in model_args.model_name_or_path or "qwen" in model_args.model_name_or_path:
tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
tokenizer.model_max_length = training_args.model_max_length
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
# tokenizer.pad_token = tokenizer.unk_token
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(
pad_token="<pad>",
additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
),
tokenizer=tokenizer,
model=model,
)
elif not "<image>" in tokenizer.get_added_vocab():
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
tokenizer=tokenizer,
model=model,
)
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")
# if model_args.vision_tower is not None:
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
## generation vision tower
gen_vision_tower = model.get_gen_vision_tower()
gen_vision_tower.to(
dtype=torch.bfloat16 if training_args.bf16 else torch.float16,
device=training_args.device,
)
gen_vision_tower.requires_grad_(False)
data_args.gen_image_processor = gen_vision_tower.image_processor
data_args.image_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct").image_processor
data_args.is_multimodal = True
data_args.n_query = model_args.n_query
data_args.n_und_query = model_args.n_und_query
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
# Calculate total parameters and trainable parameters
total_params = sum(p.numel() for p in model.get_model().parameters())
trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
model.config.pad_token_id = tokenizer.pad_token_id
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = blip3oTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)
from tabulate import tabulate
if trainer.is_world_process_zero():
stat = []
for i, (n, p) in enumerate(trainer.model.named_parameters()):
stat.append([i, n, p.shape, p.requires_grad])
print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(
trainer=trainer,
output_dir=training_args.output_dir,
vision_tower=model_args.vision_tower,
)
if __name__ == "__main__":
train()