from .lavit import LavitTokenCompressor
from .evo import EVOTokenCompressor
from .avgpool import AvgPoolTokenCompressor
from .roipool import ROIPoolTokenCompressor
from .minicpm_resampler import MiniCPMResampler
from torch import nn
class TokenCompressorStream(nn.Module):
def __init__(self, compressor_list, compressor_type_list) -> None:
super(TokenCompressorStream, self).__init__()
self.compressor_list = nn.ModuleList(compressor_list)
self.compressor_type_list = compressor_type_list
def has_type(self, target):
return target in self.compressor_type_list
def forward(self, x):
# x can be tensor(B, N, C) or [tensor(N1, C), tensor(N2, C), ...]
        for compressor in self.compressor_list:
x = compressor(x)
return x
def build_token_compressor(config) -> TokenCompressorStream:
token_compressor_config = config.token_compressor_config
compressor_list = []
compressor_type_list = []
for item in token_compressor_config:
print(item)
compressor_type = item["type"]
compressor_params = item["params"]
# build token compressor by compressor type
if compressor_type == "lavit":
compressor = LavitTokenCompressor(embed_dim=config.hidden_size, **compressor_params)
elif compressor_type == "evo":
compressor = EVOTokenCompressor(embed_dim=config.hidden_size, **compressor_params)
elif compressor_type == "avgpool":
compressor = AvgPoolTokenCompressor(**compressor_params)
elif compressor_type == "roipool":
compressor = ROIPoolTokenCompressor(**compressor_params)
elif compressor_type == "minicpm_resampler":
assert config.mm_projector_type == "identity_patch"
compressor = MiniCPMResampler(embed_dim=config.hidden_size,
num_heads=config.hidden_size // 128,
kv_dim=config.mm_hidden_size,
**compressor_params)
else:
raise ValueError("Unspported Compressor type!")
compressor_list.append(compressor)
compressor_type_list.append(compressor_type)
print(f"building token compressor done. using: {compressor_type_list}")
return TokenCompressorStream(compressor_list, compressor_type_list)
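# Minimal configuration sketch: `config.token_compressor_config` is expected to be a list of
# {"type": ..., "params": ...} dicts applied in order; the values below are illustrative, and
# embed_dim / kv_dim are filled in from the model config by the builder itself.
#
#     token_compressor_config = [
#         {"type": "roipool", "params": {"max_vision_token": 64, "mode": "single"}},
#         {"type": "evo", "params": {"inner_dim": 64, "prune_ratio": 0.25}},
#     ]
#     stream = build_token_compressor(config)  # TokenCompressorStream: roipool -> evo
#     compressed = stream(image_features)      # tensor(B, N, C) or [tensor(N1, C), ...]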
import torch
from torch import nn
from timm.models.layers import trunc_normal_
import torch.nn.functional as F
class EVOTokenCompressor(nn.Module):
"""
A PyTorch module for compressing tokens using EVO.
Reference: https://github.com/YifanXu74/Evo-ViT/blob/main/deit/evo_deit.py
This module compresses input tokens by reducing their spatial dimensions according to a specified prune ratio.
It includes normalization, a 2-layer MLP, and a pruning mechanism.
Attributes:
embed_dim (int): The input tensor's embedding dimension. Default is 2048.
inner_dim (int): The inner dimension for the 2-layer MLP. Default is 64.
prune_ratio (float): The ratio of tokens to prune. Default is 0.25.
Example:
>>> compressor = EVOTokenCompressor(prune_ratio=0.25)
        >>> input_tensor = torch.randn(1, 256, 4096)  # Shape: [B, N, dim]
>>> output_tensor = compressor(input_tensor)
>>> print(output_tensor.shape) # Expected shape: [1, 64, 4096]
"""
def __init__(self, embed_dim=2048, inner_dim=64, prune_ratio=0.25, **kwargs):
super(EVOTokenCompressor, self).__init__()
self.embed_dim = embed_dim
self.inner_dim = inner_dim
if type(prune_ratio) is str:
prune_ratio = eval(prune_ratio)
self.prune_ratio = prune_ratio
self.norm = nn.LayerNorm(embed_dim, eps=1e-5)
self.out_conv = nn.Sequential(
nn.Linear(embed_dim, inner_dim),
nn.GELU(),
nn.Linear(inner_dim, 1),
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x):
x = self.norm(x)
x = self.out_conv(x)
return F.softmax(x.squeeze(-1), dim=-1)
def easy_gather(self, x, indices):
# x: B,N,C; indices: B,N
B, N, C = x.shape
N_new = indices.shape[1]
offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
indices = indices + offset
out = x.reshape(B * N, C)[indices.view(-1)].reshape(B, N_new, C)
return out
def _inner_forward(self, x):
B, N, C = x.shape
N_prune = int(N * self.prune_ratio)
pred_score = self.forward_features(x)
        _, indices = torch.sort(pred_score, dim=1, descending=True)  # torch.sort is differentiable w.r.t. values
x = self.easy_gather(x, indices)
image_embeds = x[:, :N_prune]
return image_embeds
def forward(self, x):
if type(x) is list:
x = [self._inner_forward(item.unsqueeze(0)).squeeze(0) for item in x]
else:
x = self._inner_forward(x)
return x
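# Minimal usage sketch (illustrative shapes): the compressor accepts either a batched tensor
# or a list of per-image token tensors with different lengths.
def _example_evo_usage():
    compressor = EVOTokenCompressor(embed_dim=4096, inner_dim=64, prune_ratio=0.25)
    batched = torch.randn(2, 256, 4096)  # [B, N, dim]
    pruned = compressor(batched)  # -> [2, 64, 4096]: the 64 highest-scoring tokens per image
    per_image = [torch.randn(256, 4096), torch.randn(576, 4096)]
    pruned_list = compressor(per_image)  # -> shapes [64, 4096] and [144, 4096]
    return pruned, pruned_list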
import torch
from torch import nn
from timm.models.layers import trunc_normal_
import torch.nn.functional as F
class LavitTokenCompressor(nn.Module):
"""
A PyTorch module for compressing tokens using LaVIT.
Reference: https://github.com/jy0205/LaVIT/blob/main/LaVIT/models/modeling_visual_tokenzier.py
This module compresses input tokens by reducing their spatial dimensions.
It uses Gumbel-Softmax sampling to select the tokens to keep.
The number of tokens to keep in each image is UNCERTAIN.
Attributes:
embed_dim (int): The input tensor's embedding dimension. Default is 2048.
inner_dim (int): The inner dimension for the 2-layer MLP. Default is 64.
Example:
>>> compressor = LavitTokenCompressor(embed_dim=4096, inner_dim=64)
>>> input_tensor = torch.randn(2, 256, 4096) # Shape: [B, N, dim]
>>> output_tokens = compressor(input_tensor)
        >>> print([token.shape for token in output_tokens])  # Example output: [(114, 4096), (98, 4096)]
"""
def __init__(self, embed_dim=2048, inner_dim=64, **kwargs):
super(LavitTokenCompressor, self).__init__()
self.embed_dim = embed_dim
self.inner_dim = inner_dim
self.norm = nn.LayerNorm(embed_dim, eps=1e-5)
self.out_conv = nn.Sequential(
nn.Linear(embed_dim, inner_dim),
nn.GELU(),
nn.Linear(inner_dim, 2),
nn.LogSoftmax(dim=-1)
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x, policy):
x = self.norm(x)
B, N, C = x.size()
local_x = x[:,:, :C // 2]
global_x = (x[:,:, C // 2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
x = torch.cat([local_x, global_x.expand(B, N, C // 2)], dim=-1)
return self.out_conv(x)
def _inner_forward(self, x):
B, N, C = x.shape
mask = torch.ones((B, N, 1), dtype=x.dtype, device=x.device)
pred_score = self.forward_features(x, mask).reshape(B, -1, 2)
# Sample from the score distribution
hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0] # [N, num_patches]
token_num = hard_keep_decision.long().sum(dim=-1)
index_select = hard_keep_decision.bool()
# get remained token list
remained_token = torch.masked_select(x, index_select[:,:,None])
remained_token = remained_token.reshape(-1, C) # (sum_n, dim)
remained_token_list = torch.split(remained_token, token_num.tolist()) # [(n1, dim), (n2, dim), ...]
remained_token_list = list(remained_token_list)
return remained_token_list
def forward(self, x):
if type(x) is list:
            # _inner_forward returns a single-element list for a batch of one image
            x = [self._inner_forward(item.unsqueeze(0))[0] for item in x]
else:
x = self._inner_forward(x)
return x
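# Minimal usage sketch (illustrative shapes): LaVIT keeps a data-dependent number of tokens per
# image, so the output is a list of variable-length tensors that is typically concatenated downstream.
def _example_lavit_usage():
    compressor = LavitTokenCompressor(embed_dim=4096, inner_dim=64)
    tokens = torch.randn(2, 256, 4096)  # [B, N, dim]
    kept = compressor(tokens)  # list of tensors [(n1, 4096), (n2, 4096)] with n1, n2 <= 256
    flat = torch.cat(kept, dim=0)  # (n1 + n2, 4096)
    return flat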
from functools import partial
from typing import List, Optional, Tuple
import math
import numpy as np
import warnings
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.functional import *
from torch.nn.modules.activation import *
from torch.nn.init import trunc_normal_, constant_, xavier_normal_, xavier_uniform_
from transformers.integrations import is_deepspeed_zero3_enabled
class MiniCPMResampler(nn.Module):
"""
A PyTorch module for resampling tokens using MiniCPM Resampler.
Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/resampler.py
This module uses cross-attention mechanism to condense the information of input tokens into the query tokens.
The number of query tokens is determined by num_queries.
Attributes:
num_queries (int): The number of query tokens.
embed_dim (int): The input tensor's embedding dimension.
num_heads (int): The number of attention heads.
kv_dim (int): The dimension of the key and value vectors. Default is None, which means using embed_dim.
norm_layer (nn.Module): The normalization layer. Default is nn.LayerNorm.
adaptive (bool): Whether to use adaptive resampling. Default is False.
max_size (Tuple[int, int]): The maximum size of the input image. Default is (70, 70).
ckpt_path (str): The path to the checkpoint file. Default is None.
Example:
>>> resampler = MiniCPMResampler(num_queries=64, embed_dim=4096, num_heads=32)
>>> input_tensor = torch.randn(1, 256, 4096) # Shape: [B, N, dim]
>>> output_tensor = resampler(input_tensor)
>>> print(output_tensor.shape) # Expected shape: [1, 64, 4096]
"""
def __init__(
self,
num_queries,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
adaptive=False,
max_size=(70, 70),
ckpt_path=None
):
super(MiniCPMResampler, self).__init__()
self.resampler = Resampler(num_queries, embed_dim, num_heads, kv_dim, norm_layer, adaptive, max_size)
if ckpt_path is not None:
try:
resampler_weights = torch.load(ckpt_path, map_location='cpu')
self.resampler.load_state_dict({k.split("resampler.")[1]: v for k, v in resampler_weights.items()})
            except Exception:
                print("Failed to load resampler weights, falling back to random init.")
self.resampler.apply(self.resampler._init_weights)
else:
self.resampler.apply(self.resampler._init_weights)
def forward(self, x, tgt_sizes=None):
if tgt_sizes is None: # default mode, input square image
assert type(x) is torch.Tensor, "only support tensor input"
H = W = int(x.shape[1] ** 0.5)
tgt_sizes = torch.tensor((H, W)).unsqueeze(0).expand(x.shape[0], -1).to(dtype=torch.long, device=x.device)
# print(x.shape)
return self.resampler(x, tgt_sizes)
else: # use the whole minicpm model
return self.resampler(x, tgt_sizes)
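# Minimal usage sketch with non-square inputs (all sizes are illustrative): each row of
# `tgt_sizes` is the (h, w) patch grid of one image, and `x` is padded along dim 1 to the
# longest patch sequence in the batch; padded positions are masked via the key padding mask.
def _example_resampler_usage():
    resampler = MiniCPMResampler(num_queries=64, embed_dim=4096, num_heads=32, kv_dim=1152)
    x = torch.randn(2, 256, 1152)  # [B, max_patch_len, kv_dim]; image 1 has 16 * 16 = 256 patches
    tgt_sizes = torch.tensor([[16, 16], [12, 20]], dtype=torch.long)  # image 2 has 12 * 20 = 240
    out = resampler(x, tgt_sizes)  # -> [2, 64, 4096]
    return out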
def get_2d_sincos_pos_embed(embed_dim, image_size):
"""
image_size: image_size or (image_height, image_width)
return:
pos_embed: [image_height, image_width, embed_dim]
"""
if isinstance(image_size, int):
grid_h_size, grid_w_size = image_size, image_size
else:
grid_h_size, grid_w_size = image_size[0], image_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0]) # (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1]) # (H, W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (H, W)
out: (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
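# Shape note: get_2d_sincos_pos_embed(embed_dim, (H, W)) returns an (H, W, embed_dim) array;
# the Resampler below caches it up to max_size and slices [:tgt_h, :tgt_w, :] per image.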
class Resampler(nn.Module):
"""
    A 2D perceiver-resampler network with one cross-attention layer, driven by
    learnable queries and a 2D sin-cos positional embedding.
Outputs:
A tensor with the shape of (batch_size, num_queries, embed_dim)
"""
def __init__(
self,
num_queries,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
adaptive=False,
max_size=(70, 70),
):
super().__init__()
self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads
self.adaptive = adaptive
self.max_size = max_size
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
self._set_2d_pos_cache(self.max_size)
def _set_2d_pos_cache(self, max_size, device='cpu'):
if is_deepspeed_zero3_enabled():
device = 'cuda'
pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes, device):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
self._set_2d_pos_cache(self.max_size, device)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, tgt_sizes=None):
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
device = x.device
dtype = x.dtype
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
self._adjust_pos_cache(tgt_sizes, device=device)
max_patch_len = torch.max(patch_len)
key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)
pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i]
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i]:] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(
pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D
x = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
out = self.attn(
self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
key_padding_mask=key_padding_mask)[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
x = self.ln_post(x)
x = x @ self.proj
return x
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class MultiheadAttention(nn.MultiheadAttention):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn,
kdim, vdim, batch_first, device, dtype)
        # rewrite the out_proj layer with a plain nn.Linear
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
why_not_fast_path = ''
if ((attn_mask is not None and torch.is_floating_point(attn_mask))
or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
why_not_fast_path = "floating-point masks are not supported for fast path."
is_batched = query.dim() == 3
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=F._none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias" \
+ f" ({self.in_proj_bias.dtype}) don't match"
elif self.in_proj_weight is None:
why_not_fast_path = "in_proj_weight was None"
elif query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight" \
+ f" ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif (self.num_heads % 2) != 0:
why_not_fast_path = "self.num_heads is not even"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
is not supported with NestedTensor input"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif _is_make_fx_tracing():
why_not_fast_path = "we are running make_fx tracing"
elif not all(_check_arg_device(x) for x in tensor_args):
why_not_fast_path = ("some Tensor argument's device is neither one of "
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
if self.in_proj_bias is not None and self.in_proj_weight is not None:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
merged_mask,
need_weights,
average_attn_weights,
mask_type)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, "MultiheadAttention does not support NestedTensor outside of its fast path. " + \
f"The fast path was not hit because {why_not_fast_path}"
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = (x.transpose(1, 0) for x in (query, key))
value = key
else:
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = self.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
else:
attn_output, attn_output_weights = self.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
# tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
# batch dimension so that the output doesn't carry this temporary batch dimension.
if not is_batched:
# unsqueeze if the input is unbatched
query = query.unsqueeze(1)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.unsqueeze(0)
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
if is_causal and attn_mask is None:
raise RuntimeError(
"Need attn_mask if specifying the is_causal hint. "
"You may use the Transformer module method "
"`generate_square_subsequent_mask` to create this mask."
)
if is_causal and key_padding_mask is None and not need_weights:
            # When we have a key_padding_mask or need weights, we need an attn_mask.
            # Otherwise, we pass the is_causal hint through as the is_causal
            # indicator to SDPA.
attn_mask = None
else:
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
# prep attention mask
if attn_mask is not None:
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
# adjust dropout probability
if not training:
dropout_p = 0.0
#
# (deep breath) calculate attention and out projection
#
if need_weights:
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = self.out_proj(attn_output)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None
def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
# and returns if the input is batched or not.
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
# Shape check.
if query.dim() == 3:
# Batched Inputs
is_batched = True
assert key.dim() == 3 and value.dim() == 3, \
("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.dim() == 2, \
("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
assert attn_mask.dim() in (2, 3), \
("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.dim()}-D tensor instead")
elif query.dim() == 2:
# Unbatched Inputs
is_batched = False
assert key.dim() == 2 and value.dim() == 2, \
("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.dim() == 1, \
("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
assert attn_mask.dim() in (2, 3), \
("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.dim()}-D tensor instead")
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert attn_mask.shape == expected_shape, \
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
return is_batched
def _canonical_mask(
mask: Optional[Tensor],
mask_name: str,
other_type: Optional[DType],
other_name: str,
target_type: DType,
check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = torch.is_floating_point(mask)
if _mask_dtype != torch.bool and not _mask_is_float:
raise AssertionError(
f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
f"Support for mismatched {mask_name} and {other_name} "
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
mask = (
torch.zeros_like(mask, dtype=target_type)
.masked_fill_(mask, float("-inf"))
)
return mask
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
if input is None:
return None
elif isinstance(input, torch.Tensor):
return input.dtype
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
def _in_projection_packed(
q: Tensor,
k: Tensor,
v: Tensor,
w: Tensor,
b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
Output is a triple containing projection tensors for query, key and value.
Args:
q, k, v: query, key and value tensors to be projected. For self-attention,
these are typically the same tensor; for encoder-decoder attention,
k and v are typically the same tensor. (We take advantage of these
identities for performance if they are present.) Regardless, q, k and v
must share a common embedding dimension; otherwise their shapes may vary.
w: projection weights for q, k and v, packed into a single tensor. Weights
are packed along dimension 0, in q, k, v order.
b: optional projection biases for q, k and v, packed into a single tensor
in q, k, v order.
Shape:
Inputs:
- q: :math:`(..., E)` where E is the embedding dimension
- k: :math:`(..., E)` where E is the embedding dimension
- v: :math:`(..., E)` where E is the embedding dimension
- w: :math:`(E * 3, E)` where E is the embedding dimension
- b: :math:`E * 3` where E is the embedding dimension
Output:
- in output list :math:`[q', k', v']`, each output tensor will have the
same shape as the corresponding input tensor.
"""
E = q.size(-1)
if k is v:
if q is k:
# self-attention
proj = linear(q, w, b)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
q_proj = linear(q, w_q, b_q)
kv_proj = linear(k, w_kv, b_kv)
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def _in_projection(
q: Tensor,
k: Tensor,
v: Tensor,
w_q: Tensor,
w_k: Tensor,
w_v: Tensor,
b_q: Optional[Tensor] = None,
b_k: Optional[Tensor] = None,
b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Performs the in-projection step of the attention operation. This is simply
a triple of linear projections, with shape constraints on the weights which
ensure embedding dimension uniformity in the projected outputs.
Output is a triple containing projection tensors for query, key and value.
Args:
q, k, v: query, key and value tensors to be projected.
w_q, w_k, w_v: weights for q, k and v, respectively.
b_q, b_k, b_v: optional biases for q, k and v, respectively.
Shape:
Inputs:
- q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
number of leading dimensions.
- k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
number of leading dimensions.
- v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
number of leading dimensions.
- w_q: :math:`(Eq, Eq)`
- w_k: :math:`(Eq, Ek)`
- w_v: :math:`(Eq, Ev)`
- b_q: :math:`(Eq)`
- b_k: :math:`(Eq)`
- b_v: :math:`(Eq)`
Output: in output triple :math:`(q', k', v')`,
- q': :math:`[Qdims..., Eq]`
- k': :math:`[Kdims..., Eq]`
- v': :math:`[Vdims..., Eq]`
"""
Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
from torch import nn
class ROIPoolTokenCompressor(nn.Module):
"""
    A PyTorch module for compressing tokens using RoI Pooling.
This module performs RoI Pooling on the input tensor to reduce its spatial dimensions
by specified max_vision_token and mode.
Attributes:
max_vision_token (int): The max vision token number.
mode (str): The mode for RoI Pooling. Default is "single". Options: "single" or "multiple".
Note:
When mode is "single", max_vision_token means the max vision token number of
one image (no cropping) or one tile (cropping).
When mode is "multiple", max_vision_token means the max vision token number of
all tiles (only for cropping).
Example:
>>> compressor = ROIPoolTokenCompressor(max_vision_token=64, mode="single")
>>> input_tensor = torch.randn(1, 256, 4096) # Shape: [B, N, dim], B means the number of images
>>> output_tensor = compressor(input_tensor)
>>> print(output_tensor.shape) # Expected shape: [1, 64, 4096]
>>> compressor = ROIPoolTokenCompressor(max_vision_token="4*64", mode="multiple")
>>> input_tensor = torch.randn(4, 256, 4096) # Shape: [B, N, dim], B means the number of tiles of one image
>>> output_tensor = compressor(input_tensor)
>>> print(output_tensor.shape) # Expected shape: [4, 64, 4096]
"""
def __init__(self, max_vision_token, mode="single") -> None:
super(ROIPoolTokenCompressor, self).__init__()
        assert mode in ["single", "multiple"], "Unsupported mode for ROIPoolTokenCompressor"
if type(max_vision_token) is str:
max_vision_token = eval(max_vision_token)
self.max_vision_token = max_vision_token
self.mode = mode
def _inner_forward(self, x):
B, N, dim = x.shape
H = W = int(N ** 0.5)
if self.mode == "single" and N > self.max_vision_token:
H_new = W_new = int(self.max_vision_token ** 0.5)
x = x.view(B, H, W, dim).permute(0, 3, 1, 2)
            # not exactly RoI pooling, but for square inputs it appears equivalent
x = nn.AdaptiveAvgPool2d((H_new, W_new))(x)
x = x.permute(0, 2, 3, 1).view(B, -1, dim)
elif self.mode == "multiple" and (B * N) > self.max_vision_token:
H_new = W_new = int((self.max_vision_token / B) ** 0.5)
x = x.view(B, H, W, dim).permute(0, 3, 1, 2)
            # not exactly RoI pooling, but for square inputs it appears equivalent
x = nn.AdaptiveAvgPool2d((H_new, W_new))(x)
x = x.permute(0, 2, 3, 1).view(B, -1, dim)
return x
def forward(self, x):
if type(x) is list:
x = [self._inner_forward(item.unsqueeze(0)).squeeze(0) for item in x]
else:
x = self._inner_forward(x)
return x
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
from torch import nn
import numpy as np
from ..util.config import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
DEFAULT_VI_END_TOKEN,
DEFAULT_VI_START_TOKEN,
GANDALF_TOKEN_INDEX,
IGNORE_INDEX,
IMAGE_TOKEN_INDEX,
COR_START_TOKEN,
COR_END_TOKEN
)
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector
from .token_compressor.builder import build_token_compressor
from ..util.mm_utils import get_anyres_image_grid_shape, unpad_image
class ValleyMetaModel:
def __init__(self, config):
super(ValleyMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower, self.qwen2vl_vision_tower = build_vision_tower(config, delay_load=False)
self.mm_projector = build_vision_projector(config)
if hasattr(config, "token_compressor_config") and config.token_compressor_config is not None:
self.token_compressor = build_token_compressor(config)
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
qwen2vl_vision_tower = getattr(self, "qwen2vl_vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower, qwen2vl_vision_tower
def get_token_compressor(self):
token_compressor = getattr(self, "token_compressor", None)
return token_compressor
def initialize_token_compressor(self, model_args, logger):
self.config.token_compressor_config = model_args.token_compressor_config
if getattr(self, "token_compressor", None) is None and model_args.token_compressor_config is not None:
logger.warning("initializing token compressor weights...")
self.token_compressor = build_token_compressor(self.config)
def initialize_vision_modules(self, model_args, logger):
""" Initialize thevision modules and save the model config args
when first train multimodal model. The function should after model init
in train script.
Args:
model_args (_type_): model arguments from train config.
"""
self.config.mm_vision_tower = model_args.vision_tower # model_args.vision_tower is string
self.config.eagle_vision_tower = model_args.eagle_vision_tower
self.vision_tower, self.qwen2vl_vision_tower = build_vision_tower(model_args)
self.vision_tower.load_model() # vision_tower is an instance
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
self.config.pool_out_size = model_args.pool_out_size
self.config.mm_hidden_size = self.vision_tower.hidden_size
self.config.mm_vision_select_layer = model_args.mm_vision_select_layer
self.config.mm_vision_select_feature = model_args.mm_vision_select_feature
self.config.pixelshuffle_downsample_ratio = model_args.pixelshuffle_downsample_ratio
self.config.mlp_hidden_dim = model_args.mlp_hidden_dim
self.config.tokenize_function = model_args.tokenize_function
# valley-video projector has no mm_projector_type attribute
if (getattr(self, "mm_projector", None) is None) or (
getattr(self.mm_projector, "mm_projector_type", None) != self.config.mm_projector_type
):
logger.warning("initializing projector weights...")
self.mm_projector = build_vision_projector(self.config)
if model_args.pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
weight_keys = list(mm_projector_weights.keys())
def get_w(weights, keyword):
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
try:
logger.warning('Loading projector weight, and projector weight keys have prefix "mm_projector". ')
self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
            except Exception:
assert "mm_projector" not in weight_keys[0]
self.mm_projector.load_state_dict(mm_projector_weights)
class ValleyMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def get_token_compressor(self):
return self.get_model().get_token_compressor()
def split_by_instance(self, original_list, split_sizes):
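        # Example: original_list = [f0, f1, f2, f3, f4] with split_sizes = [2, 3]
        # -> [[f0, f1], [f2, f3, f4]], with each tensor moved to self.device.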
start = 0
sub_lists = []
for size in split_sizes:
end = start + size
sub_list = original_list[start:end]
# sub_lists.append(torch.stack(sub_list, dim=0))
sub_lists.append([x.to(self.device) for x in sub_list])
start = end
return sub_lists
def encode_images(self, images=None, split_sizes=None, pixel_values=None, grid_thw=None):
"""
        images: (if not anyres) images.shape = [n, 3, 336, 336], n = number of images + (number of videos) * 8
        images: (if anyres) images.shape = [n, 3, 336, 336], n = number of tiles * number of images
"""
siglip_vision_tower, qwen2vl_vision_tower = self.get_model().get_vision_tower()
if images is not None:
image_features = siglip_vision_tower(images)
image_features = self.get_model().mm_projector(image_features)
qwen2vl_image_features = None
if pixel_values is not None:
qwen2vl_image_features = qwen2vl_vision_tower(pixel_values, grid_thw)
qwen2vl_image_split_sizes = torch.prod(grid_thw[:, 1:3] // 2, dim=1)
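            # Each grid_thw row is (t, h, w); the split sizes are (h // 2) * (w // 2) tokens per
            # image, matching the 2x2 spatial patch merging in the Qwen2-VL vision tower.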
qwen2vl_image_features = torch.split(qwen2vl_image_features, qwen2vl_image_split_sizes.tolist(), dim=0)
qwen2vl_image_features = self.split_by_instance(qwen2vl_image_features, split_sizes)
if images is None:
return qwen2vl_image_features
if getattr(self.config,'anyres', False) and getattr(self.config, 'max_vision_token', None) is not None:
assert split_sizes is not None
image_features = list(torch.split(image_features, split_sizes, dim=0))
for i,image_feature in enumerate(image_features):
hidden_dim = image_feature.shape[-1]
image_tokens = image_feature.shape[0] * image_feature.shape[1]
# the max_vision_token will be processed in the unpad image token part
if False:
if image_tokens > self.config.max_vision_token:
intput_shape = int((image_feature.shape[1])**0.5)
output_shape = int((self.config.max_vision_token / image_feature.shape[0])**0.5)
image_feature = image_feature.view(image_feature.shape[0],intput_shape, intput_shape, -1) \
.permute(0,3,1,2)
                        # not exactly RoI pooling, but for square inputs it appears equivalent
m = nn.AdaptiveAvgPool2d(output_shape)
pooling_feature = m(image_feature).permute(0,2,3,1)
image_features[i] = pooling_feature.view(image_feature.shape[0], -1, hidden_dim)
split_sizes = None # have already split, set the flag
if getattr(self.config, 'model_class', None) in ['valley-video','valley_video']:
            # Since we mix video and image data in a batch, and in the valley-video structure
            # both have the same dimension, we need to split them for processing.
if split_sizes is not None:
image_features = torch.split(image_features, split_sizes, dim=0)
if getattr(self.config, 'mm_use_im_start_end', False):
video_start_end_image_features = []
for feature in image_features:
temporal_features = feature[:,0,:]
video_features = torch.mean(feature[:,1:,:],dim=0)
special_token_ids = torch.tensor(
self.tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN, DEFAULT_VI_END_TOKEN]
)
).to(video_features.device)
special_token_feature = self.get_model().embed_tokens(special_token_ids)
# add special sep feature as [<im_start><video_feature><im_end><vi_start><temporal_feature><vi_end>]
new_image_feature = torch.cat([
special_token_feature[0].unsqueeze(0),
video_features,
special_token_feature[1].unsqueeze(0),
special_token_feature[2].unsqueeze(0),
temporal_features,
                        special_token_feature[3].unsqueeze(0)  # <vi_end>
])
video_start_end_image_features.append(new_image_feature.unsqueeze(0))
return video_start_end_image_features, qwen2vl_image_features
else:
image_features_new = []
for feature in image_features:
temporal_features = feature[:,0,:]
video_features = torch.mean(feature[:,1:,:],dim=0)
new_image_feature = torch.cat([video_features, temporal_features])
image_features_new.append(new_image_feature.unsqueeze(0)) # increase batch dim
return image_features_new, qwen2vl_image_features
elif getattr(self.config, 'model_class', None) in ['valley-product','valley_product', 'tinyvalley']:
if getattr(self.config, 'mm_use_im_start_end', False):
                raise ValueError('mm_use_im_start_end is not supported in valley_product')
if split_sizes is not None:
image_features = torch.split(image_features, split_sizes, dim=0)
return image_features, qwen2vl_image_features
elif getattr(self.config, 'model_class', None) == 'valley-product-gandalf':
            raise ValueError('valley-product-gandalf is not supported in this version.')
else:
raise ValueError('No model class specified')
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images,
image_sizes, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, pack_ids):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
if past_key_values is not None and vision_tower is not None and \
images is not None and input_ids.shape[1] == 1:
target_shape = past_key_values[-1][-1].shape[-2] + 1
attention_mask = torch.cat((attention_mask, torch.ones(
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
dtype=attention_mask.dtype,
device=attention_mask.device
)), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if not getattr(self.config,'anyres', False):
concat_images = torch.cat([image for image in images], dim=0) # to do batch compute
split_sizes = [image.shape[0] for image in images]
if pixel_values is not None:
image_features, qwen2vl_image_features = self.encode_images(
concat_images,
split_sizes,
pixel_values,
image_grid_thw
)
image_features = [x.to(self.device) for x in image_features]
elif pixel_values_videos is not None:
image_features, qwen2vl_image_features = self.encode_images(
concat_images,
split_sizes,
pixel_values_videos,
video_grid_thw
)
image_features = [x.to(self.device) for x in image_features]
else:
image_features, _ = self.encode_images(concat_images, split_sizes)
# image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
# token compress
if self.get_token_compressor() is not None:
image_features = [self.get_token_compressor()(x) for x in image_features]
else:
                # If anyres is enabled, each image becomes several sub-image tiles, so an extra
                # list level is added:
                # images = [
                #     [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...],
                #     [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...], ...
                # ]
split_sizes = [len(image) for image in images]
# get qwen2vl features
qwen2vl_image_features = self.encode_images(None, split_sizes, pixel_values, image_grid_thw)
image_features = []
for batch_images in images:
concat_images = torch.cat([image for image in batch_images], dim=0) # to do batch compute
split_sizes = [image.shape[0] for image in batch_images]
batch_image_features, _ = self.encode_images(
concat_images,
split_sizes,
pixel_values,
image_grid_thw
)
# token compress
if self.get_token_compressor() is not None:
# x is tensor(n_tiles, T, d) or [tensor(T1, d), tensor(T2, d), ...]
batch_image_features = [self.get_token_compressor()(x) for x in batch_image_features]
if type(batch_image_features[0]) is list:
batch_image_features = [torch.cat(x).to(self.device) for x in batch_image_features]
else:
                            # tile features need to be flattened along the token dimension: [n_tiles, T, d] -> [n_tiles * T, d]
batch_image_features = [x.view(-1,x.shape[-1]).to(self.device) for x in batch_image_features]
image_features.append(batch_image_features)
# unpad image tokens
height = width = self.config.num_patches_per_side
new_image_features = []
for batch_image_features, batch_image_sizes in zip(image_features, image_sizes):
batch_image_features_list = []
for cur_image_feature, cur_image_size in zip(batch_image_features, batch_image_sizes):
base_image_feature = cur_image_feature[:width * height, :]
image_feature = cur_image_feature[width * height:, :]
if image_feature.shape[0] != 0:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
cur_image_size,
self.config.grid_pinpoints,
self.config.vit_crop_size
)
# (num_patch_H, num_patch_W, H, W, C)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
# (C, num_patch_H, H, num_patch_W, W)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
# (C, num_token_H, num_token_W)
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
# (C, num_token_H_unpad, num_token_W_unpad)
image_feature = unpad_image(image_feature, cur_image_size)
input_shape = (image_feature.shape[-2], image_feature.shape[-1])
subimage_tokens = np.prod(input_shape)
# adaptive avg 2d pool for reducing token num
max_subimage_tokens = self.config.max_vision_token - width * height
if subimage_tokens > max_subimage_tokens:
aspect_ratio = input_shape[0] / input_shape[1]
output_shape = (
int((max_subimage_tokens / aspect_ratio) ** 0.5 * aspect_ratio),
int((max_subimage_tokens / aspect_ratio) ** 0.5)
)
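                        # e.g. input_shape = (24, 36), max_subimage_tokens = 576 gives
                        # aspect_ratio = 2/3 and output_shape = (19, 29), i.e. 551 <= 576 tokens.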
m = nn.AdaptiveAvgPool2d(output_shape)
image_feature = m(image_feature)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = cur_image_feature
batch_image_features_list.append(image_feature)
new_image_features.append(batch_image_features_list)
image_features = new_image_features
else:
image_features = self.encode_images(images).to(self.device)
# token compress
if self.get_token_compressor() is not None:
image_features = self.get_token_compressor()(image_features)
# TODO: image start / end is not implemented here to support pretraining.
# if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
# raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
elif getattr(self, "use_pack", False) is False:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- TODO: double check
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask.bool())
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask.bool())
]
attention_mask = [cur_attention_mask[cur_attention_mask.bool()] for cur_attention_mask in attention_mask]
new_input_embeds = []
new_labels = []
new_attention_mask = []
for batch_idx, cur_input_ids in enumerate(input_ids):
cur_batch_image_idx = 0
            # for image
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if getattr(self.config, 'model_class', None) in ['valley-video','valley_video']:
                assert num_images <= 1, 'valley-video does not support multi-image input'
if num_images == 0:
                # If this piece of data is pure text, concat a zero-length dummy image feature
                # to keep the compute graph identical across devices.
                # cur_image_features = image_features[batch_idx][cur_batch_image_idx]
siglip_feat = image_features[batch_idx][cur_batch_image_idx]
try:
qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
                except Exception:
print("only siglip feature:", siglip_feat.shape)
cur_image_features = siglip_feat
# print("num_images = 0: ", siglip_feat.shape, qwen2vl_feat.shape, cur_image_features.shape)
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
if getattr(self.config, "use_special_start_end_token", False) \
and getattr(self.config, "training_stage", None) == 'stage1':
cur_input_embeds_1 = cur_input_embeds_1.detach()
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features.squeeze(0)[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
new_attention_mask.append(attention_mask[batch_idx])
cur_batch_image_idx += 1
continue
image_token_indices = \
[-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = [] # this list is to keep text input_ids
cur_labels = labels[batch_idx]
cur_labels_noim = []
cur_attention_mask = attention_mask[batch_idx]
cur_img_attention_mask = [
attention_mask[batch_idx][i].item()
for i in torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
]
cur_attention_mask_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]])
cur_attention_mask_noim.append(
cur_attention_mask[image_token_indices[i] + 1: image_token_indices[i + 1]]
)
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = list(torch.split(cur_input_embeds, split_sizes, dim=0)) # get text features
if getattr(self.config, "use_special_start_end_token", False) and \
getattr(self.config, "training_stage", None) == 'stage1':
                # For every text segment split around image tokens:
                # the first segment's last token (<im_start> or <vi_start>) needs its embedding updated,
                # the last segment's first token (<im_end> or <vi_end>) needs its embedding updated,
                # and every other segment's first and last tokens need their embeddings updated;
                # all remaining text-token embeddings are detached.
cur_input_embeds_no_im[0] = torch.cat(
[cur_input_embeds_no_im[0][:-1,:].detach(),cur_input_embeds_no_im[0][-1,:].unsqueeze(0)],
dim=0
)
cur_input_embeds_no_im[-1] = torch.cat(
[cur_input_embeds_no_im[-1][0,:].unsqueeze(0), cur_input_embeds_no_im[-1][1:,:].detach()],
dim=0
)
for i in range(1,len(cur_input_embeds_no_im) - 1):
                    # in this branch, <image> tokens are assumed not to appear consecutively
cur_input_embeds_no_im[i] = torch.cat(
[
cur_input_embeds_no_im[i][0,:].unsqueeze(0), # for im_end token
cur_input_embeds_no_im[i][1:-1,:].detach(), # for text token
cur_input_embeds_no_im[i][-1,:].unsqueeze(0) # for im_start token
], dim=0
)
elif getattr(self.config, "training_stage", None) == 'special-token-sft':
for i in range(len(cur_input_embeds_no_im)):
special_token_idx = torch.where(cur_input_ids_noim[i] > self.config.eos_token_id)[0].tolist()
cur_input_embeds_no_im[i] = torch.cat([
cur_input_embeds_no_im[i][j,:].unsqueeze(0) if j in special_token_idx
else cur_input_embeds_no_im[i][j,:].detach().unsqueeze(0)
for j in range(len(cur_input_embeds_no_im[i]))
], dim=0)
cur_new_input_embeds = []
cur_new_labels = []
cur_new_attention_mask = []
            for i in range(num_images + 1):  # interleave the multimodal features with the text features
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
cur_new_attention_mask.append(cur_attention_mask_noim[i])
if i < num_images:
# print(num_images, f"({len(image_features)}, {len(image_features[batch_idx])})", \
# f"({len(qwen2vl_image_features)}, {len(qwen2vl_image_features[batch_idx])})", \
# f"({batch_idx}, {cur_batch_image_idx})")
siglip_feat = image_features[batch_idx][cur_batch_image_idx]
try:
qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
# print(siglip_feat.shape, qwen2vl_feat.shape, cur_image_features.shape)
                    except Exception:
                        # the qwen2vl branch may be absent for this sample; fall back to the siglip feature only
                        print("only siglip feature:", siglip_feat.shape)
cur_image_features = siglip_feat
# cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
# print(siglip_feat.shape, qwen2vl_feat.shape, cur_image_features.shape)
cur_batch_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype
)
)
# build attention_mask for pack
if getattr(self, "use_pack", False) is False:
cur_new_attention_mask.append(
torch.full(
(cur_image_features.shape[0],),
True,
device=cur_attention_mask.device,
dtype=cur_attention_mask.dtype
)
)
else:
cur_new_attention_mask.append(
torch.full(
(cur_image_features.shape[0],),
cur_img_attention_mask[i],
device=cur_attention_mask.device,
dtype=cur_attention_mask.dtype
)
)
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
cur_new_attention_mask = torch.cat(cur_new_attention_mask)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
new_attention_mask.append(cur_new_attention_mask)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
new_attention_mask = [x[:tokenizer_model_max_length] for x in new_attention_mask]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device
)
new_attention_mask_padded = torch.zeros(
(batch_size, max_len),
dtype=new_attention_mask[0].dtype,
device=new_attention_mask[0].device
)
# attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels, cur_attention_mask) \
in enumerate(zip(new_input_embeds, new_labels, new_attention_mask)):
cur_len = cur_new_embed.shape[0]
if not self.training: # for inference
new_input_embeds_padded.append(torch.cat((
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device
),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
new_attention_mask_padded[i, -cur_len:] = cur_attention_mask
# attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0,
cur_len,
dtype=position_ids.dtype,
device=position_ids.device
)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device
)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
new_attention_mask_padded[i, :cur_len] = cur_attention_mask
# attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0,
cur_len,
dtype=position_ids.dtype,
device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
new_attention_mask = None
else:
new_attention_mask = new_attention_mask_padded
if _position_ids is None:
position_ids = None
if getattr(self, "use_pack", False) is True:
# new_attention_mask = new_attention_mask.bool()
new_attention_mask = self._prepare_4d_causal_attention_mask_for_pack(
new_attention_mask,
dtype=new_input_embeds.dtype
) # only for pack
return None, position_ids, new_attention_mask, past_key_values, new_input_embeds, new_labels
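    # Illustrative note on the return contract above (a sketch for readers; names and layouts are
    # paraphrased from the code, not an additional API):
    # the method returns (None, position_ids, attention_mask, past_key_values, inputs_embeds, labels),
    # with every sequence padded to the longest one in the batch. Padding side depends on the mode:
    #
    #   # training  -> right padding:  [tok tok img img tok PAD PAD]
    #   # inference -> left padding:   [PAD PAD tok tok img img tok]
    #
    # so that generation can simply append new tokens on the right at inference time.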
def _prepare_4d_causal_attention_mask_for_pack(self, attention_mask, dtype):
"""
Prepares a 4D causal attention mask for packed sequences.
This function generates a 4D causal attention mask for sequences that are packed together.
The mask ensures that each token can only attend to previous tokens within the same sequence
and not across different sequences.
Args:
            attention_mask (torch.Tensor): A 2D tensor of shape (batch_size, max_len) where each element
                indicates whether the corresponding token is valid (non-zero) or padding (zero).
                Tokens sharing the same non-zero value belong to the same packed sequence,
                e.g. [1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 0], where 0 marks padding tokens.
dtype (torch.dtype): The data type to use for the resulting mask.
Returns:
torch.Tensor: A 4D tensor of shape (bs, 1, max_len, max_len) representing the causal attention mask.
The mask is filled with `torch.finfo(dtype).min` where tokens cannot attend and 0 where they can.
"""
batch_size, max_len = attention_mask.shape
tril_mask = torch.tril(
torch.ones(
(batch_size, 1, max_len, max_len),
dtype=torch.bool,
device=attention_mask.device
)
)
tril_mask = tril_mask \
& (attention_mask[:, None, None, :] == attention_mask[:, None, :, None]) \
& (attention_mask[:, None, None, :] != 0)
tril_mask = tril_mask.to(dtype=dtype)
tril_mask[tril_mask == 0] = torch.finfo(dtype).min
tril_mask[tril_mask == 1] = 0
return tril_mask
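    # Hedged usage sketch for the packed mask above (illustrative only; `model` stands for any module
    # instance exposing this method, which is an assumption made for the example):
    #
    #   >>> mask = torch.tensor([[1, 1, 2, 2, 0]])
    #   >>> out = model._prepare_4d_causal_attention_mask_for_pack(mask, dtype=torch.float32)
    #   >>> out.shape
    #   torch.Size([1, 1, 5, 5])
    #   >>> (out[0, 0] == 0)          # positions a query (row) may attend to
    #   tensor([[ True, False, False, False, False],
    #           [ True,  True, False, False, False],
    #           [False, False,  True, False, False],
    #           [False, False,  True,  True, False],
    #           [False, False, False, False, False]])
    #
    # i.e. causal attention within each packed sequence (values 1 and 2), while the padding row/column
    # (value 0) is fully masked with torch.finfo(dtype).min.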
def initialize_vision_tokenizer(self, model_args, tokenizer, logger):
if model_args.mm_use_im_patch_token:
logger.info('Model is using image patch token placeholder. Adding <im_patch> to tokenizer...')
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
logger.info(
                'Model is using im_start/im_end token placeholders. '
                'Adding <im_start>, <im_end>, <vi_start> and <vi_end> to tokenizer...'
)
num_new_tokens = tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN, DEFAULT_VI_END_TOKEN],
special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 4
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. "
f"Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
)
elif getattr(model_args, "use_special_start_end_token", False):
logger.info(
                'Model is using special tokens for video frames, images and grounding boxes. '
'Adding <im_start>/<im_end>/<vi_start>/<vi_end>/<cor>/</cor> to tokenizer...'
)
num_new_tokens = tokenizer.add_tokens(
[
DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN,
DEFAULT_VI_END_TOKEN, COR_START_TOKEN, COR_END_TOKEN
],
special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter and self.config.training_stage == 'stage1':
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
                # if the model's word embeddings are tied to the lm head,
                # do not freeze the lm head (it shares the word-embedding weights)
if not getattr(self.config, "tie_word_embeddings", True):
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 6
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. "
f"Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
)
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
GANDALF_TOKEN_INDEX = -300
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"
DEFAULT_GANDALF_TOKEN = "<gandalf>"
DEFAULT_EOC_TOKEN = "<eoc>"
COR_START_TOKEN = "<cor>"
COR_END_TOKEN = "</cor>"
SEQ_MAX_LEN = 50000
import copy
import re
from typing import Dict, Sequence
import torch
import transformers
from PIL import Image
from transformers import CLIPImageProcessor, StoppingCriteria
from .. import conversation as conversation_lib
import ast
import math
# from valley.constants import *
# from valley.util.config import *
from .config import (
DEFAULT_GANDALF_TOKEN,
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
DEFAULT_VI_END_TOKEN,
DEFAULT_VI_START_TOKEN,
DEFAULT_VIDEO_TOKEN,
GANDALF_TOKEN_INDEX,
IGNORE_INDEX,
IMAGE_TOKEN_INDEX,
SEQ_MAX_LEN
)
SPLIT_TOKEN = "<SPLIT_TOKEN>"
def collate_wrapper(batch):
try:
image_list = [b[0] for b in batch]
prompt_list = [b[2] for b in batch]
# input_ids = pad_sequence(prompt_list, padding_value = 0, batch_first = True)
conv_list = [b[3] for b in batch]
save_id_list = [b[4] for b in batch]
label_list = [b[1] for b in batch]
except Exception as e:
prompt_list, image_list, conv_list, label_list, save_id_list = None, None, None, None, None
print(f"error in collate_wrapper: {e} ||| all set to None")
return prompt_list, image_list, conv_list, label_list, save_id_list
def collate_process_image_text(batch, tokenizer, image_processor):
batch_input_ids, batch_image, conv_list, label_list, save_id_list = batch
input_ids = torch.stack(batch_input_ids, dim=0)
videos = []
for this_batch_images in batch_image:
if (
".mp4" not in save_id_list[0] and ".avi" not in save_id_list[0]
): # if not a video file, do image list process func
video = image_processor.preprocess(this_batch_images, return_tensors="pt")["pixel_values"]
videos.append(video)
else:
videos.append(this_batch_images)
return input_ids, videos, conv_list, label_list, save_id_list
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
# for finetune
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
if trainer.args.should_save:
if getattr(trainer.args, "lora", None):
trainer.model.save_pretrained(output_dir)
if trainer.args.tune_mm_mlp_adapter:
trainer.model.base_model.model.save_pretrained(output_dir)
else:
state_dict = trainer.model.state_dict()
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def tokenizer_image_token(
prompt,
tokenizer,
image_token_index=IMAGE_TOKEN_INDEX,
gandalf_token_index=GANDALF_TOKEN_INDEX,
return_tensors=None,
):
def split_with_token(string, token):
result = string.split(token)
for i in range(len(result) - 1):
result.insert(i * 2 + 1, token)
return result
if len(prompt) > SEQ_MAX_LEN:
# This error will be caught by the __getitem__ method in LazySupervisedDataset within valley/data/dataset.py,
# and it will then randomly select another valid data item to return.
raise ValueError("sequence is too long !!!")
prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
prompt_chunks = sum([split_with_token(chunk, DEFAULT_GANDALF_TOKEN) for chunk in prompt_chunks], [])
input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer,'bos_token',None) else ([], 0)
token2index = {DEFAULT_IMAGE_TOKEN: image_token_index, DEFAULT_GANDALF_TOKEN: gandalf_token_index}
for chunk in prompt_chunks:
if chunk in token2index:
input_ids.append(token2index[chunk])
else:
chunk_ids = tokenizer(chunk).input_ids
            # For Qwen2-7B, a bos token exists but is not prepended to the tokenized sequence
if chunk_ids[0] != getattr(tokenizer,'bos_token_id', None):
offset = 0
input_ids.extend(chunk_ids[offset:])
# prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
# def insert_separator(X, sep):
# return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
# input_ids = []
# offset = 0
# if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
# offset = 1
# input_ids.append(prompt_chunks[0][0])
# for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
# input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
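# Hedged usage sketch (assumes a HuggingFace tokenizer instance named `tokenizer`; the prompt text is
# illustrative and not taken from this repo):
#
#   >>> prompt = "USER: " + DEFAULT_IMAGE_TOKEN + "\nDescribe the image. ASSISTANT:"
#   >>> ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
#   >>> (ids == IMAGE_TOKEN_INDEX).sum().item()   # the single <image> placeholder becomes one -200 id
#   1
#
# Text chunks are tokenized normally, while each <image> / <gandalf> placeholder is replaced by its
# negative index so that prepare_inputs_labels_for_multimodal can splice the visual features in later.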
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
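# Hedged usage sketch (the `tokenizer` and `model` names are placeholders, not objects defined here):
#
#   >>> special_tokens_dict = dict(pad_token="[PAD]")
#   >>> smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)
#
# Newly added rows of the input/output embedding matrices are initialized to the mean of the
# pre-existing rows, which tends to be a more stable starting point than random initialization.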
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers, only_mask_system=False):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
if not only_mask_system:
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2: cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"].strip()
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
def preprocess_multimodal(
conversations: Sequence[dict],
img_num,
data_args,
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return conversations
for sentence in conversations:
if data_args.model_class in ["valley-product", "valley-gandalf", "tinyvalley", "valley-product-mistral"]:
if DEFAULT_VIDEO_TOKEN in sentence["value"]:
if data_args.use_special_start_end_token:
video_replace_token = (
DEFAULT_VI_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_VI_END_TOKEN
) * img_num
else:
video_replace_token = DEFAULT_IMAGE_TOKEN * img_num
# video_replace_token = ' '.join(f'Frame {i}: {DEFAULT_IMAGE_TOKEN}' for i in range(img_num))
sentence["value"] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN, '').strip()
sentence["value"] = video_replace_token + '\n' + sentence["value"]
else:
segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["value"])
if data_args.use_special_start_end_token:
sentence["value"] = (
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
).join(segs[: img_num + 1]) + "".join(segs[img_num + 1:])
else:
sentence["value"] = DEFAULT_IMAGE_TOKEN.join(segs[: img_num + 1]) + "".join(
segs[img_num + 1:]
)
elif data_args.model_class in ["valley-video", "valley-video-mistral"]:
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
sentence["value"] = sentence["value"].replace(DEFAULT_VIDEO_TOKEN, "").strip()
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
)
else:
raise Exception("unknown model class")
return conversations
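# Illustrative before/after for the replacement above (assuming use_special_start_end_token=True and
# img_num=2; the sentence text is made up for the example):
#
#   before: {"from": "human", "value": "<video>\nWhat happens in the clip?"}
#   after:  {"from": "human", "value": "<vi_start><image><vi_end><vi_start><image><vi_end>\nWhat happens in the clip?"}
#
# For plain images, each of the first img_num <image> occurrences is wrapped as <im_start><image><im_end> instead.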
def preprocess_llama_2(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
inference: bool = False,
only_mask_system: bool = False,
) -> Dict:
'''
FIXME: support only_mask_system=True; check tokenizer; unwrap sources
'''
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
sources = [sources]
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
if inference:
conv.append_message(conv.roles[1], None)
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack(
[tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids.squeeze(0),
labels=targets.squeeze(0),
)
def preprocess_mistral(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
inference: bool = False,
only_mask_system: bool = False,
) -> Dict:
"""
FIXME: support only_mask_system=True; check tokenizer; unwrap sources
"""
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
sources = [sources]
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
if inference:
source.pop(-1)
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
if inference:
conv.append_message(conv.roles[1], None)
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack(
[tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# assert (input_ids == 1).sum() == 2 and input_ids.shape[0] ==1
# input_ids = input_ids[:,1:]
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MISTRAL
# Mask targets
sep = "[/INST]"
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
if not only_mask_system:
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length and not only_mask_system and not inference:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_v0(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
inference: bool = False,
only_mask_system: bool = False,
) -> Dict:
"""
FIXME: check tokenizer; unwrap sources
"""
sources = [sources]
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers, only_mask_system=only_mask_system)
return dict(input_ids=input_ids, labels=targets)
def preprocess_v1(
source,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
inference: bool = False,
only_mask_system: bool = False,
) -> Dict:
"""
FIXME: support only_mask_system=True
"""
conv = conversation_lib.default_conversation.copy()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{j}"
conv.append_message(role, sentence["value"])
if inference:
conv.append_message(conv.roles[1], None)
conversation = conv.get_prompt()
# Mask targets
rounds = conversation.split(conv.sep2)
input_ids_ = torch.tensor([1], dtype=torch.int64)
targets_ = torch.tensor([-100], dtype=torch.int64)
for i, rou in enumerate(rounds):
if rou == "":
continue
if (not inference) or (i < (len(rounds) - 1)):
rou += conv.sep2
if has_image:
cur_input_ids_ = tokenizer_image_token(rou, tokenizer, return_tensors="pt")[1:]
input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
if only_mask_system:
mask_len = len(
tokenizer_image_token(re.sub(rf"{conv.roles[0]}:[\s\S]*", f"{conv.roles[0]}:", rou), tokenizer)[1:]
)
else:
mask_len = len(
tokenizer_image_token(re.sub(rf"{conv.roles[1]}:[\s\S]*", f"{conv.roles[1]}:", rou), tokenizer)[1:]
)
# targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
else:
cur_input_ids_ = tokenizer(rou, return_tensors="pt")["input_ids"][0, 1:]
input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
mask_len = len(tokenizer(re.sub(rf"{conv.roles[1]}:[\s\S]*", f"{conv.roles[1]}:", rou))["input_ids"][1:])
# targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
return {"input_ids": input_ids_, "labels": targets_}
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess_uninstruct_text_image(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
content = sources["content"]
input_ids_ = torch.tensor([1], dtype=torch.int64) if tokenizer.bos_token else torch.tensor([], dtype=torch.int64)
targets_ = torch.tensor([-100], dtype=torch.int64) if tokenizer.bos_token else torch.tensor([], dtype=torch.int64)
cur_input_ids_ = tokenizer_image_token(content, tokenizer, return_tensors="pt")[1:]
input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
targets_ = torch.cat([targets_, cur_input_ids_[:]], dim=0)
return {"input_ids": input_ids_, "labels": targets_}
def preprocess_text(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
content = sources["content"]
if len(content) > SEQ_MAX_LEN:
# This error will be caught by the __getitem__ method in LazySupervisedDataset within valley/data/dataset.py,
# and it will then randomly select another valid data item to return.
raise ValueError("sequence is too long !!!")
input_tokens = []
    bos_token = [tokenizer.bos_token] if tokenizer.bos_token else []  # support qwen2
for sub_text in content.split(SPLIT_TOKEN):
input_tokens.extend(bos_token + tokenizer.tokenize(sub_text) + [tokenizer.eos_token])
input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(input_tokens))
targets = input_ids.clone()
return {"input_ids": input_ids, "labels": targets}
def preprocess_qwen2(
source,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
inference: bool = False,
only_mask_system: bool = False,
):
'''
"chat_template":
"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}
{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' +
message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
'''
conv = conversation_lib.default_conversation.copy()
assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN2
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{j}"
messages.append({"role":role, "content":sentence["value"]})
conversation = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=inference)
# Mask targets
rounds = conversation.split(conv.sep2)
input_ids_ = torch.tensor([], dtype=torch.int64)
targets_ = torch.tensor([], dtype=torch.int64)
for i, rou in enumerate(rounds):
if rou == "":
continue
if (not inference) or (i < (len(rounds) - 1)):
rou += conv.sep2
if has_image:
cur_input_ids_ = tokenizer_image_token(rou, tokenizer, return_tensors='pt')
input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
if only_mask_system:
mask_len = len(tokenizer_image_token(re.sub(rf'{conv.roles[0]}\n[\s\S]*', f'{conv.roles[0]}:', rou),
tokenizer))
else:
mask_len = len(tokenizer_image_token(re.sub(rf'{conv.roles[1]}\n[\s\S]*', f'{conv.roles[1]}:', rou),
tokenizer))
targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
else:
cur_input_ids_ = tokenizer(rou, return_tensors='pt')["input_ids"][0, :]
input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
mask_len = len(tokenizer(re.sub(rf'{conv.roles[1]}\n[\s\S]*', rf'{conv.roles[1]}:', rou))["input_ids"][:])
# targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
return {"input_ids": input_ids_, "labels": targets_}
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
only_mask_system: bool = False,
inference: bool = False,
) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
assert conversation_lib.default_conversation.version in [
"v0", "v1", "mistral", "llama_2", "plain", 'qwen2','gemma2'
]
# v0 is for vicuna-v0, sep is '###'
# v1 is for vicuna-v1.x, sep is ' ', sep2 is '</s>'
# mistral is for mistral, sep is [INST]
# llama_2 is for llama2, sep is [INST]
    # plain is for pretraining, with no chat template
# please refer to file examples/valleyproduct/valley/conversation.py for details
if isinstance(sources, dict):
if sources["preprocess_mode"] == "uninstruct_text_image":
return preprocess_uninstruct_text_image(sources, tokenizer)
elif sources["preprocess_mode"] == "puretext":
return preprocess_text(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
return preprocess_llama_2(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.QWEN2:
return preprocess_qwen2(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
if conversation_lib.default_conversation.version == "v0":
return preprocess_v0(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
if conversation_lib.default_conversation.version == "v1":
return preprocess_v1(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
if conversation_lib.default_conversation.version == "mistral":
return preprocess_mistral(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
if conversation_lib.default_conversation.version.startswith("v1"):
print(
f"you'd better change your conversation version, current version is "
f"{conversation_lib.default_conversation.version}"
)
return preprocess_v1(
sources, tokenizer, has_image=has_image, inference=inference, only_mask_system=only_mask_system
)
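# Hedged usage sketch (a single-turn conversation; the exact token ids depend on the tokenizer and the
# conversation template in use, so none are shown):
#
#   >>> source = [
#   ...     {"from": "human", "value": DEFAULT_IMAGE_TOKEN + "\nWhat is in the picture?"},
#   ...     {"from": "gpt", "value": "A cat on a sofa."},
#   ... ]
#   >>> sample = preprocess(source, tokenizer, has_image=True)
#   >>> sorted(sample.keys())
#   ['input_ids', 'labels']
#
# Human turns (and the system prompt) are masked to IGNORE_INDEX in `labels`, so only assistant tokens
# contribute to the loss.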
def find_closest_aspect_ratio(aspect_ratio, min_tile_num, max_tile_num, width, height, tiled_image_size):
"""
    Find the aspect ratio, among tile grids using between min_tile_num and max_tile_num tiles,
    that is closest to the current image's aspect ratio.
An example usage:
find_closest_aspect_ratio(1.5, 1, 6, 1200, 800, 1024)
    This will return the aspect ratio that is closest to 1.5, considering the image dimensions and preferring a larger
    area relative to the total tile area.
In case of a tie, the ratio that results in a larger relative area compared to the original image size is chosen.
Args:
aspect_ratio (float): The current image's aspect ratio (width divided by height), e.g., 1.5.
        min_tile_num (int): The minimum number of tiles to crop the image into.
        max_tile_num (int): The maximum number of tiles to crop the image into.
width (int): The width of the current image, e.g., 1200.
height (int): The height of the current image, e.g., 800.
        tiled_image_size (int): The size of each square tile, e.g., 336.
Returns:
Tuple[int, int]: The aspect ratio closest to the current image's aspect ratio
            based on the criteria, e.g., (3, 2).
"""
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_tile_num, max_tile_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_tile_num and i * j >= min_tile_num)
    # sort the candidate ratios by area
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the best ratio
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
            # If two ratios are equally close (e.g. 2:3 and 4:6), prefer the one covering the larger area:
            # when the image area exceeds half of the total tile area, pick the larger grid (4:6).
            # Since target_ratios is sorted by area, 4:6 comes after 2:3, so the final ratio becomes 4:6.
            all_tile_area_sum = tiled_image_size * tiled_image_size * ratio[0] * ratio[1]
            if area > 0.5 * all_tile_area_sum:
best_ratio = ratio
return best_ratio
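# Worked example for the tie-breaking rule above (the numbers are chosen purely for illustration):
#
#   >>> find_closest_aspect_ratio(1.0, 1, 6, 900, 900, 448)
#   (2, 2)
#
# Both (1, 1) and (2, 2) match the 1:1 aspect ratio exactly; because the 900x900 image area (810000)
# exceeds half of the 2x2 tile area (0.5 * 4 * 448 * 448 = 401408), the larger grid wins the tie.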
def dynamic_preprocess(image, min_num=1, max_num=6, tiled_image_size=448, use_thumbnail=False):
"""
Processes an image dynamically based on its aspect ratio and specified parameters,
splitting it into sub-images or creating a thumbnail as needed.
Example:
>>> from PIL import Image
>>> img = Image.open('example.jpg')
        >>> tiles, grid, w_ratio, h_ratio = dynamic_preprocess(img, min_num=1, max_num=6, tiled_image_size=448, use_thumbnail=True)
Args:
image (PIL.Image.Image): Input image to be processed.
        min_num (int): Minimum number of tiles to split the image into.
        max_num (int): Maximum number of tiles to split the image into.
        tiled_image_size (int): Target size of each tile.
use_thumbnail (bool): Whether to append a thumbnail of the original image if multiple sub-images are generated.
Returns:
        Tuple[List[PIL.Image.Image], Tuple[int, int], float, float]: The processed tiles (with an optional thumbnail),
            the chosen tile grid, and the width/height resize ratios.
"""
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, min_num, max_num, orig_width, orig_height, tiled_image_size)
# calculate the target width and height
target_width = tiled_image_size * target_aspect_ratio[0]
target_height = tiled_image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
# save_resize_ratio
width_resize_ratio = target_width / orig_width
height_resize_ratio = target_height / orig_height
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // tiled_image_size)) * tiled_image_size,
(i // (target_width // tiled_image_size)) * tiled_image_size,
((i % (target_width // tiled_image_size)) + 1) * tiled_image_size,
((i // (target_width // tiled_image_size)) + 1) * tiled_image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((tiled_image_size, tiled_image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio, width_resize_ratio, height_resize_ratio
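# Hedged usage sketch (a synthetic image keeps the example self-contained; sizes are illustrative):
#
#   >>> img = Image.new("RGB", (1200, 800))
#   >>> tiles, grid, w_ratio, h_ratio = dynamic_preprocess(img, min_num=1, max_num=6,
#   ...                                                    tiled_image_size=448, use_thumbnail=True)
#   >>> grid, len(tiles)
#   ((3, 2), 7)
#
# A 1200x800 image maps to a 3x2 grid of 448x448 tiles (6 crops) plus one appended thumbnail.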
from PIL import Image
from io import BytesIO
import base64
import math
import ast
import re
import torch
from transformers import StoppingCriteria
from .config import IMAGE_TOKEN_INDEX
def resize_and_center_crop(image, shortest_edge_length):
# Calculate new dimensions and resize
aspect_ratio = float(image.width) / float(image.height)
if aspect_ratio > 1:
new_width = int(shortest_edge_length * aspect_ratio)
new_height = shortest_edge_length
else:
new_width = shortest_edge_length
new_height = int(shortest_edge_length / aspect_ratio)
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
# Calculate the position and perform the center crop
left = (new_width - shortest_edge_length) / 2
top = (new_height - shortest_edge_length) / 2
right = (new_width + shortest_edge_length) / 2
bottom = (new_height + shortest_edge_length) / 2
cropped_image = resized_image.crop((left, top, right, bottom))
return cropped_image
def auto_pad_images(image, grid_params):
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
assert len(grid_params) > 0, "Grid parameters should not be empty"
# Step 1: Calculate and find the closest aspect ratio
input_width, input_height = image.size
input_aspect_ratio = input_width / input_height
candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
resize_width, resize_height = target_resolution
if input_width > input_height:
resize_height = int(resize_width / input_aspect_ratio)
else:
resize_width = int(resize_height * input_aspect_ratio)
    resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
# Step 5: Pad the resized image if necessary to match the target resolution
pad_width = target_resolution[0] - resize_width
pad_height = target_resolution[1] - resize_height
padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
return padded_image
def extract_patches(image, patch_size, overlap_ratio):
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
assert patch_size > 0, "Patch size should be greater than 0"
assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
W, H = image.size
patches = []
stride = int(patch_size * (1 - overlap_ratio))
num_patches_y = (H - patch_size) // stride + 1
num_patches_x = (W - patch_size) // stride + 1
y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
for y in range(y_start, y_start + num_patches_y * stride, stride):
for x in range(x_start, x_start + num_patches_x * stride, stride):
patch = image.crop((x, y, x + patch_size, y + patch_size))
patches.append(patch)
return patches
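# Hedged worked example for the stride/centering logic above (image size chosen for illustration):
#
#   >>> img = Image.new("RGB", (800, 600))
#   >>> len(extract_patches(img, patch_size=336, overlap_ratio=0))
#   2
#
# With no overlap the stride equals the patch size (336); an 800x600 image fits 2 patches horizontally
# and 1 vertically, and the leftover margin is split evenly so the patches stay centered.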
def process_highres_image_crop_split(image, data_args, processor=None):
crop_resolution = data_args.image_crop_resolution
split_resolution = data_args.image_split_resolution
if processor is None:
processor = data_args.image_processor
image_crop = resize_and_center_crop(image, crop_resolution)
image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
image_patches = [
processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
for image_patch in image_patches
]
return torch.stack(image_patches, dim=0)
def process_highres_image(image, processor, grid_pinpoints):
grid_params = [int(x) for x in grid_pinpoints.split(",")]
width_height = max(image.size)
fit_grid_params = [x for x in grid_params if x >= width_height]
if len(fit_grid_params) == 0:
select_size = max(grid_params)
else:
select_size = min(fit_grid_params)
# FIXME: always select the 448
select_size = max(grid_params)
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
# FIXME: this seems to be a bug that it always resizes instead of padding
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
image_padded = image_padded.resize((select_size, select_size))
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
image_patches = [image_original_resize] + image_patches
image_patches = [
processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
for image_patch in image_patches
]
return torch.stack(image_patches, dim=0)
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format
[(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
# Calculate effective and wasted resolutions
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or \
(effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
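# Worked example for the effective-vs-wasted-resolution trade-off above (candidate list is illustrative):
#
#   >>> select_best_resolution((1000, 600), [(672, 672), (1344, 336), (1344, 672)])
#   (1344, 672)
#
# (1344, 672) can contain the full 1000x600 image without downscaling, so its effective resolution
# equals the original area (600000 px), which beats both of the other candidates.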
def resize_and_pad_image(image, target_resolution):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
# Determine which dimension (width or height) to fill
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
# Width will be filled completely
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
# Height will be filled completely
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a new image with the target size and paste the resized image onto it
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
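# Hedged usage sketch (synthetic image; the numbers follow directly from the scaling rule above):
#
#   >>> img = Image.new("RGB", (1000, 600))
#   >>> resize_and_pad_image(img, (1344, 672)).size
#   (1344, 672)
#
# The 1000x600 image is scaled by min(1344/1000, 672/600) = 1.12 to 1120x672, then centered on a
# black 1344x672 canvas with 112 px of padding on the left and right.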
def divide_to_patches(image, patch_size):
"""
Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (tuple): The size of the input image in the format (width, height).
grid_pinpoints (str): A string representation of a list of possible resolutions.
patch_size (int): The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
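# Hedged usage sketch (list-style grid_pinpoints; the "(MxN)" string form is parsed into the same list):
#
#   >>> get_anyres_image_grid_shape((1000, 600), [(672, 672), (1344, 336), (1344, 672)], 336)
#   (4, 2)
#
# The best-fit resolution for a 1000x600 image among these candidates is 1344x672, i.e. a 4x2 grid of
# 336x336 patches.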
def process_anyres_image(image, processor, grid_pinpoints):
"""
Process an image with variable resolutions.
Args:
image (PIL.Image.Image): The input image to be processed.
processor: The image processor object.
grid_pinpoints (str): A string representation of a list of possible resolutions.
Returns:
torch.Tensor: A tensor containing the processed image patches.
"""
# Convert grid_pinpoints from string to list
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
try:
patch_size = processor.size["height"]
except Exception:
patch_size = processor.size["shortest_edge"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, processor.size["height"])
# FIXME: this seems to be a bug that it resizes instead of pad.
# but to keep it consistent with previous, i will keep it as it is
# TODO: uncomment below to ablate with the padding
if isinstance(processor.size, dict):
shortest_edge = processor.size["height"]
else:
shortest_edge = min(processor.size)
image_original_resize = image.resize((shortest_edge, shortest_edge))
# image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image_patches = [image_original_resize] + patches
image_patches = [
processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
for image_patch in image_patches
]
# return torch.stack(image_patches, dim=0)
return image_patches
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == "highres":
for image in images:
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
elif image_aspect_ratio == "crop_split":
for image in images:
image = process_highres_image_crop_split(image, model_cfg, image_processor)
new_images.append(image)
elif image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
else:
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
offset = min(output_ids.shape[1] - self.start_len, 3)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):  # elementwise == cannot be truth-tested directly
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image in the format (width, height).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
# Compute aspect ratios
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
# Determine padding size and direction
if original_aspect_ratio > current_aspect_ratio:
# Padding was added to the height
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding: current_height - padding, :]
else:
# Padding was added to the width
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding: current_width - padding]
return unpadded_tensor
from transformers import PretrainedConfig
siglip_config = PretrainedConfig.from_dict(
{
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
}
)
qwen2vl_vit_config = PretrainedConfig.from_dict(
{
"depth": 32,
"embed_dim": 1280,
"hidden_act": "quick_gelu",
"hidden_size": 3584,
"in_channels": 3,
"in_chans": 3,
"mlp_ratio": 4,
"model_type": "qwen2_vl",
"num_heads": 16,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"_attn_implementation": "flash_attention_2",
"_attn_implementation_internal": "flash_attention_2"
}
)
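# Rough token-count sanity check for the two towers above (informal, assuming the
# standard patchification): SigLIP at image_size=384 with patch_size=14 yields
# (384 // 14) ** 2 = 27 * 27 = 729 patch tokens of width 1152 per image, while the
# Qwen2-VL tower merges spatial_merge_size x spatial_merge_size = 2 x 2 patches,
# reducing its visual token count by a factor of 4 before it reaches the LLM.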
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from .constants import LOGDIR
import re
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
handler = None
import logging
import warnings
import torch.distributed as dist
from prettytable import PrettyTable
def check_model_config(config):
model_class = getattr(config, "model_class", None)
llm_name = getattr(config, "llm_name", None)
if model_class not in ["valley-video", "valley-product", "valley-gandalf"]:
if model_class == 'tinyvalley':
config.model_class = 'valley-product'
warnings.warn(
'"tinyvalley" belongs to "valley-product" model class, force set to "valley-product" here.',
category=None,
stacklevel=1,
source=None
)
elif model_class is None:
raise ValueError("Please specify 'model_class' in 'config.json' in model path")
else:
raise ValueError(
"Invalid model class. Only [ 'valley-video', 'valley-product', 'valley-gandalf'] is now supported."
)
    if llm_name not in ['llama', 'llama_2', 'mistral', 'qwen2']:
        if llm_name is None:
            raise ValueError("Please specify 'llm_name' in 'config.json' in model path")
        else:
            raise ValueError("Unknown LLM name. Only ['llama', 'llama_2', 'mistral', 'qwen2'] is now supported.")
return config
def print_trainable_params(model):
logger = get_logger('train') # get the logger while train
if dist.get_rank() == 0:
trainable_params = [k for k,v in model.named_parameters() if v.requires_grad]
trainable_params_group = {}
for para in trainable_params:
layer_num = re.findall(r'layers.(\d+)\.',para)
block_num = re.findall(r'blocks.(\d+)\.',para)
if layer_num:
cur_layer = int(layer_num[0])
if para.replace('layers.' + layer_num[0],'layers.*') not in trainable_params_group:
trainable_params_group[para.replace('layers.' + layer_num[0],'layers.*')] = layer_num[0]
elif cur_layer > int(trainable_params_group[para.replace('layers.' + layer_num[0],'layers.*')]):
trainable_params_group[para.replace('layers.' + layer_num[0],'layers.*')] = layer_num[0]
elif block_num:
cur_layer = int(block_num[0])
if para.replace('blocks.' + block_num[0],'blocks.*') not in trainable_params_group:
trainable_params_group[para.replace('blocks.' + block_num[0],'blocks.*')] = block_num[0]
elif cur_layer > int(trainable_params_group[para.replace('blocks.' + block_num[0],'blocks.*')]):
trainable_params_group[para.replace('blocks.' + block_num[0],'blocks.*')] = block_num[0]
else:
trainable_params_group[para] = '0'
table = PrettyTable(['Parameter Name','Max Layer Number'])
for key in trainable_params_group.keys():
table.add_row([key, str(int(trainable_params_group[key]) + 1)])
print(table)
total_num = sum([v.numel() for k,v in model.named_parameters()])
trainable_num = sum([v.numel() for k,v in model.named_parameters() if v.requires_grad])
logger.info('Total: {:.2f}M'.format(total_num / 1e6))
logger.info(' Trainable: {:.2f}M'.format(trainable_num / 1e6))
def rank_zero_info(content: str, logger, print_type: str = "info"):
output_method = getattr(logger, print_type)
if dist.get_rank() == 0:
output_method(content)
def get_logger(name: str):
# logger initialize
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# handler
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
# formatter
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
# add handler
logger.addHandler(handler)
return logger
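# Minimal usage sketch for the stream logger above:
#   logger = get_logger('train')
#   logger.info('step 100 done')   # printed to stdout with timestamp / name / level
# Note that each call to get_logger() with the same name attaches another
# StreamHandler to that logger, so repeated calls can duplicate messages.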
def build_logger(logger_name, logger_filename, logdir=LOGDIR):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(logdir, exist_ok=True)
filename = os.path.join(logdir, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ''
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException:
flagged = False
except KeyError:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
import torch
from PIL import Image
from ..base import BaseModel
from ...smp import *
from typing import Dict
import logging
from transformers import set_seed
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import fetch_image, fetch_video
import re
from .valley_eagle.model.language_model.valley_qwen2 import ValleyQwen2ForCausalLM
from .valley_eagle.util.mm_utils import process_anyres_image
from .valley_eagle import conversation as conversation_lib
from .valley_eagle.util.data_util import dynamic_preprocess, preprocess
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
GANDALF_TOKEN_INDEX = -300
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"
DEFAULT_GANDALF_TOKEN = "<gandalf>"
DEFAULT_EOC_TOKEN = "<eoc>"
COT_PROMPT = "\nPlease think step by step."
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def preprocess_multimodal(
conversations,
img_num,
data_args,
) -> Dict:
for sentence in conversations:
if data_args.model_class in ["valley-product", "valley-gandalf", "tinyvalley", "valley-product-mistral"]:
if DEFAULT_VIDEO_TOKEN in sentence["value"]:
if data_args.use_special_start_end_token:
video_replace_token = \
(DEFAULT_VI_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_VI_END_TOKEN) * img_num
else:
video_replace_token = DEFAULT_IMAGE_TOKEN * img_num
sentence["value"] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN, '').strip()
sentence["value"] = video_replace_token + '\n' + sentence["value"]
else:
segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["value"])
if data_args.use_special_start_end_token:
sentence["value"] = \
(DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN).join(segs[: img_num + 1])\
+ "".join(segs[img_num + 1:])
else:
sentence["value"] = DEFAULT_IMAGE_TOKEN.join(segs[: img_num + 1]) + "".join(
segs[img_num + 1:]
)
elif data_args.model_class in ["valley-video", "valley-video-mistral"]:
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
sentence["value"] = sentence["value"].replace(DEFAULT_VIDEO_TOKEN, "").strip()
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
else:
raise Exception(f"unknown model class : {data_args.model_class}")
return conversations
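# Illustrative sketch (hypothetical conversation, valley-product style): with
# img_num=2 and use_special_start_end_token=True, a turn whose value is
# "<video>\nWhat happens?" is rewritten to
# "<vi_start><image><vi_end><vi_start><image><vi_end>\nWhat happens?", i.e. the
# single <video> placeholder is expanded into one <image> slot per sampled frame.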
def tokenizer_image_token(
prompt,
tokenizer,
image_token_index=IMAGE_TOKEN_INDEX,
gandalf_token_index=GANDALF_TOKEN_INDEX,
return_tensors=None,
):
def split_with_token(string, token):
result = string.split(token)
for i in range(len(result) - 1):
result.insert(i * 2 + 1, token)
return result
prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
prompt_chunks = sum([split_with_token(chunk, DEFAULT_GANDALF_TOKEN) for chunk in prompt_chunks], [])
input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer,'bos_token',None) else ([], 0)
token2index = {DEFAULT_IMAGE_TOKEN: image_token_index, DEFAULT_GANDALF_TOKEN: gandalf_token_index}
for chunk in prompt_chunks:
if chunk in token2index:
input_ids.append(token2index[chunk])
else:
chunk_ids = tokenizer(chunk).input_ids
# For Qwen2-7B, bos token exists but does not appear in the beginning
if chunk_ids[0] != getattr(tokenizer, 'bos_token_id', None):
offset = 0
input_ids.extend(chunk_ids[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
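# Illustrative sketch: for a prompt like "<image>\nWho is in the picture? <gandalf>",
# the chunks between the special placeholders are tokenized normally, while "<image>"
# and "<gandalf>" are mapped to IMAGE_TOKEN_INDEX (-200) and GANDALF_TOKEN_INDEX (-300).
# These negative ids act as out-of-vocabulary sentinels that the multimodal forward
# pass is expected to replace with the corresponding image / gandalf features.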
BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00' + \
b'\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00' + \
b'\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00' + \
b'\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
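# The bytes above appear to encode a tiny (3x3) black PNG; preprocess_images() below
# falls back to it as a placeholder frame when a sample carries no image at all.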
class ValleyEagleChat(BaseModel):
def __init__(self,
model_path='liuhaotian/llava_v1.5_7b',
**kwargs):
torch_dtype = torch.float16
padding_side = 'left'
use_fast = True
trust_remote_code = True
output_logits = False
conversation_tag = 'qwen2'
max_new_tokens: int = 384
seed = 42
black_img = BLACK_IMG_ENV
disable_torch_init()
set_seed(seed)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.output_logits = output_logits
self.conversation_tag = conversation_tag
conversation_lib.default_conversation = conversation_lib.conv_templates[self.conversation_tag]
# Load model
logging.info(f"Start loading valley model from {model_path}")
self.model_path = model_path
self.model = ValleyQwen2ForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
self.model = self.model.to(self.device).half()
# should check this code
self.model.config.min_tile_num = 1
self.model.config.max_tile_num = 9
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=use_fast,
trust_remote_code=trust_remote_code
)
self.tokenizer.padding_side = padding_side
logging.info("Load model success!")
self.black_img = black_img
self.max_new_tokens = max_new_tokens
# Load image preprocessor
from transformers import SiglipImageProcessor
self.qwen2vl_processor = None
self.image_processor = SiglipImageProcessor.from_pretrained(self.model.config.mm_vision_tower)
self.image_processor.crop_size = self.image_processor.size["height"]
# self.vision_tower.load_model() # vision_tower is an instance
kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1, use_cache=True) # noqa E501
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def expand2square(self,pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def preprocess_images(
self, image_binary_list
):
images = [Image.open(binary).convert("RGB") for binary in image_binary_list]
image_sizes_list = [img.size for img in images]
video_pad = []
for img in images:
if self.model.config.anyres:
image = process_anyres_image(img, self.image_processor, self.model.config.grid_pinpoints)
else:
image = self.image_processor(img, return_tensors="pt")["pixel_values"][0]
video_pad.append(image)
video_pad = (
[self.black_img] if len(video_pad) == 0 else video_pad
        )  # the black placeholder image is not processed by anyres
# import pdb; pdb.set_trace()
if not self.model.config.anyres:
video = torch.stack(video_pad, dim=0)
else:
video = [torch.stack(img, dim=0) for img in video_pad]
return video, image_sizes_list
def generate_inner(self, message, dataset=None):
if self.qwen2vl_processor is None:
if dataset == 'OCRBench':
self.qwen2vl_processor = AutoProcessor.from_pretrained(
self.model.config.eagle_vision_tower,
max_pixels=1280 * 28 * 28,
min_pixels=10 * 10 * 28 * 28
)
else:
self.qwen2vl_processor = AutoProcessor.from_pretrained(
self.model.config.eagle_vision_tower,
max_pixels=1280 * 28 * 28
)
messages = []
text, images = '', []
for item in message:
if item['type'] == 'text':
text += item['value']
elif item['type'] == 'image':
text += ' <image> '
images.append(item['value'])
if dataset in ["MMMU_DEV_VAL", "MMStar", "OCRBench", "MMVet"]:
messages.append({"from": 'human', "value": text + COT_PROMPT})
else:
messages.append({"from": 'human', "value": text})
messages_qwen = []
image_list = []
for image_file in images:
image = fetch_image({"image": image_file})
image_list.append(image)
messages_qwen.append({'role': 'user', "content": [{"type": "text", "text": text}]})
messages_qwen.append({"role": "assistant", "content": [{"type": "text", "text": ""}]})
text = self.qwen2vl_processor.apply_chat_template(
messages_qwen[:-1],
tokenize=False,
add_generation_prompt=True
)
text_segs = re.split("<image>", text)
text = "<|vision_start|><|image_pad|><|vision_end|>".join(text_segs[: len(image_list) + 1]) + \
"".join(text_segs[len(image_list) + 1:])
sources = self.qwen2vl_processor(text=[text], images=image_list, padding=True, return_tensors="pt")
mask_len = len(self.qwen2vl_processor(
text=[re.sub(r"assistant\\\n[\s\S]*", "assistant\n", text)],
images=image_list,
padding=True,
return_tensors="pt"
)["input_ids"][0])
sources["input_ids"] = sources["input_ids"][0]
sources["labels"] = torch.cat([torch.tensor([-100] * mask_len), sources["input_ids"][mask_len:]], dim=0)
data_dict_qwen2vl = sources
video_images_tensor, image_sizes_list = self.preprocess_images(images)
img_length = len(video_images_tensor)
source = preprocess_multimodal(messages, img_length, self.model.config)
data_dict = preprocess(
source,
self.tokenizer,
has_image=True,
only_mask_system=False,
inference=True,
)
input_ids = data_dict['input_ids']
input_ids = input_ids.unsqueeze(0).to(self.device)
if img_length:
images = [item.to(self.device).half() for item in video_images_tensor]
with torch.inference_mode():
output_ids = self.model.generate(
input_ids=input_ids,
images=[images],
image_sizes=[image_sizes_list],
pixel_values=data_dict_qwen2vl['pixel_values'].to(self.device),
image_grid_thw=data_dict_qwen2vl['image_grid_thw'].to(self.device),
pixel_values_videos=None,
video_grid_thw=None,
do_sample=False,
max_new_tokens=2048,
repetition_penalty=1.0,
pad_token_id=self.tokenizer.pad_token_id,
return_dict_in_generate=True, output_scores=True)
input_token_len = input_ids.shape[1]
generation_text = self.tokenizer.batch_decode(output_ids.sequences[:, input_token_len:])[0]
generation_text = generation_text.replace("<|im_end|>", "")
return generation_text
from .video_llava import VideoLLaVA, VideoLLaVA_HF
from .videochat2 import VideoChat2_HD
from .chat_uni_vi import Chatunivi
from .video_chatgpt import VideoChatGPT
from .llama_vid import LLaMAVID
from .pllava import PLLaVA
__all__ = ['VideoLLaVA', 'VideoLLaVA_HF', 'Chatunivi', 'VideoChatGPT', 'LLaMAVID', 'VideoChat2_HD', 'PLLaVA']
import torch
import warnings
import copy as cp
import numpy as np
import sys
import os
import logging
from ..base import BaseModel
from ...smp import isimg, listinstr
from ...dataset import DATASET_TYPE
from decord import VideoReader, cpu
from PIL import Image
def _get_rawvideo_dec(
video_path,
image_processor,
max_frames=64,
image_resolution=224,
video_framerate=1,
s=None,
e=None,
):
# speed up video decode via decord.
video_mask = np.zeros(max_frames, dtype=np.int64)
max_video_length = 0
# T x 3 x H x W
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
if s is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = start_time if start_time >= 0.0 else 0.0
end_time = end_time if end_time >= 0.0 else 0.0
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
print(video_path)
raise FileNotFoundError
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# T x 3 x H x W
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
if len(all_pos) > max_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
patch_images = torch.stack(
[image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
)
slice_len = patch_images.shape[0]
max_video_length = max_video_length if max_video_length > slice_len else slice_len
if slice_len < 1:
pass
else:
video[:slice_len, ...] = patch_images
return patch_images, slice_len
else:
print('video path: {} error.'.format(video_path))
video_mask[:max_video_length] = [1] * max_video_length
return torch.from_numpy(video), video_mask
class Chatunivi(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
# sample 1 fps (maximum 64 frames) from the video
def __init__(self, model_path='Chat-UniVi/Chat-UniVi', **kwargs):
assert model_path is not None
try:
from ChatUniVi.model.builder import load_pretrained_model
except Exception as err:
logging.critical('Please install Chat-UniVi from https://github.com/PKU-YuanGroup/Chat-UniVi.git.')
raise err
model_name = 'ChatUniVi'
tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name)
self.tokenizer = tokenizer
self.model = model
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
image_processor = vision_tower.image_processor
self.processor = image_processor
self.context_len = context_len
self.kwargs = kwargs
self.fps = 1
self.resolution = 224
if 'v1.5' in model_path:
self.resolution = 336
def get_model_output(self, model, video_processor, tokenizer, video, qs):
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.constants import (
DEFAULT_IMAGE_PATCH_TOKEN,
DEFAULT_IMAGE_TOKEN,
IMAGE_TOKEN_INDEX,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
MAX_IMAGE_LENGTH,
)
from ChatUniVi.mm_utils import (
tokenizer_image_token,
KeywordsStoppingCriteria,
)
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
mm_use_im_patch_token = getattr(model.config, 'mm_use_im_patch_token', True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
if model.config.config['use_cluster']:
for n, m in model.named_modules():
m = m.to(dtype=torch.bfloat16)
video_frames, slice_len = _get_rawvideo_dec(
video, video_processor, max_frames=MAX_IMAGE_LENGTH,
image_resolution=self.resolution, video_framerate=self.fps
)
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs
if type(qs) is dict and 'user' in qs:
qs['user'] = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs['user']
else:
qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs
conv = conv_templates['v1'].copy()
if type(qs) is dict and 'system' in qs:
conv.system = qs['system']
if type(qs) is dict and 'user' in qs:
conv.append_message(conv.roles[0], qs['user'])
else:
conv.append_message(conv.roles[0], qs)
if type(qs) is dict and 'assistant' in qs:
conv.append_message(conv.roles[1], qs['assistant'])
else:
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt().strip('</s>')
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=video_frames.half().cuda(),
do_sample=True,
temperature=0.2,
top_p=None,
num_beams=1,
output_scores=True,
return_dict_in_generate=True,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria])
output_ids = output_ids.sequences
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
return outputs
def generate_inner(self, message, dataset=None):
if listinstr(['MLVU', 'MVBench'], dataset):
question, video = self.message_to_promptvideo_withrole(message, dataset)
else:
question, video = self.message_to_promptvideo(message)
response = self.get_model_output(self.model, self.processor, self.tokenizer, video, question)
return response
import torch
import warnings
import copy as cp
import numpy as np
import sys
import os
import logging
from ..base import BaseModel
from ...smp import isimg, listinstr, load, dump, download_file
from ...dataset import DATASET_TYPE
from decord import VideoReader, cpu
from huggingface_hub import snapshot_download
def load_video(video_path, setting_fps):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
fps = round(vr.get_avg_fps())
frame_idx = [i for i in range(0, total_frame_num, int(fps / setting_fps))]
spare_frames = vr.get_batch(frame_idx).asnumpy()
return spare_frames
def change_file(file_path, mm_vision_tower):
org_data = load(file_path)
org_data['image_processor'] = './vlmeval/vlm/video_llm/configs/llama_vid/processor/clip-patch14-224'
org_data['mm_vision_tower'] = mm_vision_tower
dump(org_data, file_path)
class LLaMAVID(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
# sample 1 fps from the video
def __init__(self, model_path='YanweiLi/llama-vid-7b-full-224-video-fps-1', **kwargs):
assert model_path is not None
try:
from llamavid.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
except Exception as err:
logging.critical('Please install LLaMA-VID from https://github.com/dvlab-research/LLaMA-VID.')
raise err
model_base = None
model_name = get_model_name_from_path(model_path)
eva_vit_g_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth'
true_model_path = snapshot_download(model_path)
eva_vit_path = os.path.join(true_model_path, 'eva_vit_g.pth')
if not os.path.exists(eva_vit_path):
download_file(eva_vit_g_url, eva_vit_path)
config_path = os.path.join(true_model_path, 'config.json')
change_file(config_path, eva_vit_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
true_model_path, model_base, model_name, None, device_map='cpu', device='cpu'
)
model.cuda()
self.tokenizer = tokenizer
self.model = model
self.processor = image_processor
self.context_len = context_len
self.kwargs = kwargs
self.fps = 1
def get_model_output(self, model, video_processor, tokenizer, video, qs):
from llamavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llamavid.constants import DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llamavid.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
if type(qs) is dict:
original_qs = cp.deepcopy(qs['user'])
else:
original_qs = cp.deepcopy(qs)
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
elif type(qs) is dict and 'user' in qs:
qs['user'] = DEFAULT_IMAGE_TOKEN + '\n' + qs['user']
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = 'vicuna_v1'
conv = conv_templates[conv_mode].copy()
if type(qs) is dict and 'system' in qs:
conv.system = qs['system']
if type(qs) is dict and 'user' in qs:
conv.append_message(conv.roles[0], qs['user'])
else:
conv.append_message(conv.roles[0], qs)
if type(qs) is dict and 'assistant' in qs:
conv.append_message(conv.roles[1], qs['assistant'])
else:
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt().strip('</s>')
# Check if the video exists
if os.path.exists(video):
video = load_video(video, self.fps)
video = video_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda()
video = [video]
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
cur_prompt = original_qs
with torch.inference_mode():
model.update_prompt([[cur_prompt]])
output_ids = model.generate(
input_ids,
images=video,
do_sample=True,
temperature=0.2,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria],
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
return outputs
def generate_inner(self, message, dataset=None):
if listinstr(['MLVU', 'MVBench'], dataset):
question, video = self.message_to_promptvideo_withrole(message, dataset)
else:
question, video = self.message_to_promptvideo(message)
response = self.get_model_output(self.model, self.processor, self.tokenizer, video, question)
return response
import torch
import warnings
import copy as cp
import numpy as np
import sys
from PIL import Image
import torchvision
import logging
from ..base import BaseModel
from ...smp import isimg, listinstr, get_rank_and_world_size
from ...dataset import DATASET_TYPE
from huggingface_hub import snapshot_download
class PLLaVA(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
def __init__(self, model_path='ermu2001/pllava-13b', dir_root=None, **kwargs):
sys.path.append(dir_root)
try:
from tasks.eval.model_utils import load_pllava
except Exception as err:
logging.critical(
'Please first install requirements and set the root path to use PLLaVA. \
Follow the instructions at https://github.com/magic-research/PLLaVA.'
)
raise err
rank, world_size = get_rank_and_world_size()
self.nframe = 16
self.use_lora = True
self.lora_alpha = 4
self.pooling_shape = (16, 12, 12)
self.RESOLUTION = 672
self.model_path = model_path
        # note that larger models (30B+) can use up a large amount of memory and may even bring down nodes
weight_dir = snapshot_download(model_path)
self.model, self.processor = load_pllava(
model_path, num_frames=self.nframe, use_lora=self.use_lora,
weight_dir=weight_dir, lora_alpha=self.lora_alpha, pooling_shape=self.pooling_shape
)
# position embedding
self.model = self.model.to(torch.device(rank))
self.model = self.model.eval()
def load_video(self, video_path, num_segments=8, resolution=336):
from decord import VideoReader, cpu
transforms = torchvision.transforms.Resize(size=resolution)
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
frame_indices = self.get_index(num_frames, num_segments)
images_group = list()
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy())
images_group.append(transforms(img))
return images_group
def get_index(self, num_frames, num_segments):
seg_size = float(num_frames - 1) / num_segments
start = int(seg_size / 2)
offsets = np.array([
start + int(np.round(seg_size * idx)) for idx in range(num_segments)
])
return offsets
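    # Worked example (illustrative numbers): for num_frames=81 and num_segments=8,
    # seg_size = 80 / 8 = 10.0 and start = 5, so the sampled frame indices are
    # [5, 15, 25, 35, 45, 55, 65, 75] -- one frame from the middle of each segment.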
def generate_inner(self, message, dataset=None):
from tasks.eval.model_utils import pllava_answer
from tasks.eval.eval_utils import conv_templates
question, video = self.message_to_promptvideo(message)
img_list = self.load_video(video, num_segments=self.nframe, resolution=self.RESOLUTION)
if self.model_path == 'ermu2001/pllava-34b': # using slightly different conversation mode for 34b model
if listinstr(['Video-MCQ'], DATASET_TYPE(dataset)): # MCQ dataset
conv_mode = 'eval_mvbench_llavanext'
else: # VQA dataset
conv_mode = 'eval_videoqa_llavanext'
else:
if listinstr(['Video-MCQ'], DATASET_TYPE(dataset)): # MCQ dataset
conv_mode = 'eval_mvbench'
else: # VQA dataset
conv_mode = 'eval_videoqabench'
conv = conv_templates[conv_mode].copy()
if dataset in ['MVBench', 'MVBench_MP4']:
conv.user_query(message[1]['value'], message[0]['value'], message[-2]['value'], is_mm=True)
conv.assistant_response(message[-1]['value'])
else:
conv.user_query(question, is_mm=True)
llm_response, conv = pllava_answer(
conv=conv, model=self.model, processor=self.processor,
do_sample=False, img_list=img_list, max_new_tokens=512, print_res=False
)
if dataset in ['MVBench', 'MVBench_MP4']:
llm_response = '(' + ''.join(llm_response.split(message[-1]['value'])[1:])
return llm_response
import torch
import os
import warnings
import copy as cp
import numpy as np
import sys
import logging
from ..base import BaseModel
from ...smp import isimg, listinstr
from ...dataset import DATASET_TYPE
from huggingface_hub import snapshot_download
class VideoChatGPT(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
# sample a video in 100 frames
def __init__(self, model_path='MBZUAI/Video-ChatGPT-7B', dir_root=None, **kwargs):
assert model_path is not None
sys.path.append(dir_root)
try:
from video_chatgpt.eval.model_utils import initialize_model
except Exception as err:
logging.critical(
'Please first install requirements and set the root path to use Video-ChatGPT. \
Follow the instructions at https://github.com/mbzuai-oryx/Video-ChatGPT.'
)
raise err
base_model_path = snapshot_download('mmaaz60/LLaVA-7B-Lightening-v1-1')
projection_path = snapshot_download(model_path)
projection_name = 'video_chatgpt-7B.bin'
projection_path = os.path.join(projection_path, projection_name)
model, vision_tower, tokenizer, image_processor, video_token_len = initialize_model(
base_model_path, projection_path
)
self.tokenizer = tokenizer
self.model = model
self.processor = image_processor
self.context_len = video_token_len
self.kwargs = kwargs
self.vision_tower = vision_tower
def get_model_output(self, model, video_processor, tokenizer, video, qs):
from video_chatgpt.eval.model_utils import load_video
from video_chatgpt.inference import video_chatgpt_infer
conv_mode = 'video-chatgpt_v1'
video_frames = load_video(video)
# Run inference on the video and questions
output = video_chatgpt_infer(
video_frames,
qs,
conv_mode,
model,
self.vision_tower,
tokenizer,
video_processor,
self.context_len,
)
return output
def generate_inner(self, message, dataset=None):
question, video = self.message_to_promptvideo(message)
response = self.get_model_output(self.model, self.processor, self.tokenizer, video, question)
return response
import torch
import warnings
import copy as cp
import numpy as np
import sys
import logging
from ..base import BaseModel
from ...smp import isimg, listinstr
from ...dataset import DATASET_TYPE
def read_video_pyav(container, indices):
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format='rgb24') for x in frames])
class VideoLLaVA_HF(BaseModel):
INSTALL_REQ = False
INTERLEAVE = False
VIDEO_LLM = True
# sample a video in 8 frames
def __init__(self, model_path='LanguageBind/Video-LLaVA-7B-hf', **kwargs):
try:
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
except Exception as err:
            logging.critical('Please install the latest version of transformers. \
                You can install it by `pip install transformers==4.42.0` \
                or `pip install --upgrade git+https://github.com/huggingface/transformers.git`.')
raise err
assert model_path is not None
self.model_path = model_path
self.model = VideoLlavaForConditionalGeneration.from_pretrained(model_path)
self.model.eval().cuda()
self.processor = VideoLlavaProcessor.from_pretrained(model_path)
self.kwargs = kwargs
torch.cuda.empty_cache()
def generate_inner(self, message, dataset=None):
import av
question, video = self.message_to_promptvideo(message)
container = av.open(video)
# sample uniformly 8 frames from the video
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / self.nframe).astype(int)
clip = read_video_pyav(container, indices)
prompt = f'USER: <video>\n{question} ASSISTANT:'
inputs = self.processor(text=prompt, videos=clip, return_tensors='pt').to(self.model.device)
        # Generation args -- deprecated defaults, overridden by self.kwargs below
generation_args = {
'max_new_tokens': 1024,
'temperature': 0.0,
'do_sample': False,
}
generation_args.update(self.kwargs)
generate_ids = self.model.generate(**inputs, **generation_args)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = self.processor.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
return response
class VideoLLaVA(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
# sample a video in 8 frames
def __init__(self, model_path='LanguageBind/Video-LLaVA-7B', **kwargs):
assert model_path is not None
try:
from videollava.conversation import conv_templates, SeparatorStyle
from videollava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from videollava.constants import DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN
from videollava.mm_utils import get_model_name_from_path, tokenizer_image_token, KeywordsStoppingCriteria
from videollava.model.builder import load_pretrained_model
from videollava.model.language_model.llava_llama import LlavaLlamaForCausalLM
from videollava.train.train import smart_tokenizer_and_embedding_resize
except Exception as err:
logging.critical('Please install Video-LLaVA from https://github.com/FangXinyu-0913/Video-LLaVA.')
raise err
model_base = None
model_name = model_path.split('/')[-1]
tokenizer, model, processor, context_len = load_pretrained_model(model_path, model_base, model_name)
self.tokenizer = tokenizer
self.model = model
self.processor = processor
self.context_len = context_len
self.kwargs = kwargs
self.nframes = 8
def get_model_output(self, model, video_processor, tokenizer, video, qs):
from videollava.conversation import conv_templates, SeparatorStyle
from videollava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from videollava.constants import DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN
from videollava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
if type(qs) is dict and 'user' in qs:
qs['user'] = ''.join([DEFAULT_IMAGE_TOKEN] * self.nframes) + '\n' + qs['user']
else:
qs = ''.join([DEFAULT_IMAGE_TOKEN] * self.nframes) + '\n' + qs
conv_mode = 'llava_v1'
device = torch.device('cuda')
conv = conv_templates[conv_mode].copy()
if type(qs) is dict and 'system' in qs:
conv.system = qs['system']
if type(qs) is dict and 'user' in qs:
conv.append_message(conv.roles[0], qs['user'])
else:
conv.append_message(conv.roles[0], qs)
if type(qs) is dict and 'assistant' in qs:
conv.append_message(conv.roles[1], qs['assistant'])
else:
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt().strip('</s>')
video_tensor = video_processor.preprocess(video, return_tensors='pt')['pixel_values'][0].half().to(device)
input_ids = tokenizer_image_token(prompt, tokenizer,
IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=[video_tensor],
do_sample=False,
temperature=0.0,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=[stopping_criteria])
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
return outputs
def generate_inner(self, message, dataset=None):
if self.nframes != 8:
            raise Exception(f'Video-LLaVA only supports generation with 8 frames, but nframes is currently set to {self.nframes}')  # noqa
if listinstr(['MLVU', 'MVBench'], dataset):
question, video = self.message_to_promptvideo_withrole(message, dataset)
else:
question, video = self.message_to_promptvideo(message)
response = self.get_model_output(self.model, self.processor['video'], self.tokenizer, video, question)
return response
import torch
import warnings
import copy as cp
import numpy as np
import sys
import os.path as osp
import os
import requests
import shutil
import huggingface_hub
import logging
from transformers import StoppingCriteria, StoppingCriteriaList
from huggingface_hub import snapshot_download
from PIL import Image
from torchvision.transforms import PILToTensor
from torchvision import transforms
from ..base import BaseModel
from ...smp import *
from ...dataset import DATASET_TYPE
def get_prompt(conv):
ret = conv.system + conv.sep
for role, message in conv.messages:
if message:
ret += role + ' ' + message + ' ' + conv.sep
else:
ret += role
return ret
def get_prompt2(conv):
ret = conv.system + conv.sep
count = 0
for role, message in conv.messages:
count += 1
if count == len(conv.messages):
ret += role + ' ' + message
else:
if message:
ret += role + ' ' + message + ' ' + conv.sep
else:
ret += role
return ret
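# Note: get_prompt() always appends conv.sep after a non-empty message, whereas
# get_prompt2() omits the trailing separator for the final message, so that a partial
# answer prompt (e.g. "Best option:(") can be continued directly by the model.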
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop):])).item():
return True
return False
class VideoChat2_HD(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False
VIDEO_LLM = True
def __init__(self, model_path='OpenGVLab/VideoChat2_HD_stage4_Mistral_7B',
root='./Ask-Anything', config_file='./configs/videochat2_hd.json',
**kwargs):
from peft import get_peft_model, LoraConfig, TaskType
self.config_file = config_file
self.root = root
self.model_path = model_path
if root is None:
raise ValueError('Please set `root` to Ask-Anything directory, \
which is cloned from here: https://github.com/OpenGVLab/Ask-Anything')
sys.path.append(osp.join(root, 'video_chat2'))
try:
from utils.config import Config
from utils.easydict import EasyDict
from models import VideoChat2_it_hd_mistral
from dataset.hd_utils import HD_transform_padding, HD_transform_no_padding
except Exception as err:
logging.critical(
'Please first install VideoChat2 and set the root path to use VideoChat2, '
'which is cloned from here: https://github.com/OpenGVLab/Ask-Anything '
)
raise err
cfg = Config.from_file(self.config_file)
def download_file(url, pth):
destination_folder = pth
            # make sure the destination folder exists
if not os.path.exists(destination_folder):
os.makedirs(destination_folder)
            # get the file name
filename = os.path.basename(url)
destination_path = os.path.join(destination_folder, filename)
if os.path.exists(destination_path):
print(f'File downloaded! No repeat download needed. Saved in {destination_path}')
return
            # download the file
response = requests.get(url, stream=True)
if response.status_code == 200:
with open(destination_path, 'wb') as file:
response.raw.decode_content = True
shutil.copyfileobj(response.raw, file)
print(f'File downloaded and saved to {destination_path}')
else:
print(f'Download failed, status code: {response.status_code}')
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
huggingface_hub.login(hf_token)
videochat2_model_path = snapshot_download(repo_id=cfg.model.videochat2_model_path, repo_type='model')
cfg.model.videochat2_model_path = osp.join(videochat2_model_path, 'videochat2_mistral_7b_stage2.pth')
mistral_model_path = snapshot_download(repo_id=cfg.model.mistral_model_path, repo_type='model')
cfg.model.mistral_model_path = mistral_model_path
vit_blip_model_path = snapshot_download(repo_id=cfg.model.vit_blip_model_path, repo_type='model')
cfg.model.vit_blip_model_path = osp.join(vit_blip_model_path, 'umt_l16_qformer.pth')
model = VideoChat2_it_hd_mistral(config=cfg.model)
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=16, lora_alpha=32, lora_dropout=0.,
target_modules=[
'q_proj', 'k_proj', 'v_proj', 'o_proj',
'gate_proj', 'up_proj', 'down_proj', 'lm_head'
]
)
model.mistral_model = get_peft_model(model.mistral_model, peft_config)
stage4_model_path = snapshot_download(repo_id=model_path, repo_type='model')
state_dict = torch.load(osp.join(stage4_model_path, 'videochat2_hd_mistral_7b_stage4.pth'), 'cuda')
if 'model' in state_dict.keys():
model.load_state_dict(state_dict['model'], strict=False)
else:
model.load_state_dict(state_dict, strict=False)
model = model.to(torch.device('cuda'))
model = model.eval()
self.model = model
# position embedding
self.nframe = 16
self.resolution = 224
self.hd_num = 6
new_pos_emb = self.get_sinusoid_encoding_table(
n_position=(self.resolution // 16) ** 2 * self.nframe,
cur_frame=self.nframe
)
self.model.vision_encoder.encoder.pos_embed = new_pos_emb
self.hd_transform = HD_transform_no_padding
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
self.transform = transforms.Compose([
transforms.Lambda(lambda x: x.float().div(255.0)),
transforms.Normalize(mean, std)
])
def get_sinusoid_encoding_table(self, n_position=784, d_hid=1024,
cur_frame=8, ckpt_num_frame=4,
pre_n_position=784):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
# generate checkpoint position embedding
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
print(f'n_position: {n_position}')
print(f'pre_n_position: {pre_n_position}')
if n_position != pre_n_position:
T = ckpt_num_frame # checkpoint frame
P = 14 # checkpoint size
C = d_hid
new_P = int((n_position // cur_frame) ** 0.5) # testing size
if new_P != 14:
print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
print('Interpolate the position embedding')
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
sinusoid_table = torch.nn.functional.interpolate(
sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
# BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
if cur_frame != ckpt_num_frame:
print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
print('Interpolate the position embedding')
T = ckpt_num_frame # checkpoint frame
new_T = cur_frame # testing frame
# interpolate
P = int((n_position // cur_frame) ** 0.5) # testing size
C = d_hid
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
return sinusoid_table
def get_index(self, bound, fps, max_frame, first_idx=0):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / self.nframe
frame_indices = np.array([
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
for idx in range(self.nframe)
])
return frame_indices
def read_video(self, video_path, bound=None):
from decord import VideoReader, cpu
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
frames = vr.get_batch(frame_indices)
frames = frames.permute(0, 3, 1, 2)
frames = self.hd_transform(frames.float(), image_size=self.resolution, hd_num=self.hd_num)
torch_imgs = self.transform(frames)
return torch_imgs
def ask(self, text, conv):
conv.messages.append([conv.roles[0], text])
def get_context_emb(self, conv, model, img_list, answer_prompt=None, print_res=False):
if answer_prompt:
prompt = get_prompt2(conv)
else:
prompt = get_prompt(conv)
if print_res:
print(prompt)
if '<VideoHere>' in prompt:
prompt_segs = prompt.split('<VideoHere>')
else:
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, 'Unmatched numbers of image placeholders and images.'
with torch.no_grad():
seg_tokens = [
model.mistral_tokenizer(
seg, return_tensors='pt', add_special_tokens=i == 0).to('cuda').input_ids
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [model.mistral_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
# seg_embs = [model.mistral_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def answer(self, conv, model, img_list, do_sample=True, max_new_tokens=500, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
stop_words_ids = [
torch.tensor([2]).to('cuda'),
torch.tensor([29871, 2]).to('cuda')] # '</s>' can be encoded in two different ways.
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
conv.messages.append([conv.roles[1], answer_prompt])
embs = self.get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
with torch.no_grad():
outputs = model.mistral_model.generate(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=stopping_criteria,
num_beams=num_beams,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
temperature=temperature,
)
output_token = outputs[0]
        if output_token[0] == 0:  # the model may output an unknown token <unk> at the beginning; remove it
output_token = output_token[1:]
        if output_token[0] == 1:  # the output sometimes starts with a <s> start token; remove it
output_token = output_token[1:]
output_text = model.mistral_tokenizer.decode(output_token, add_special_tokens=False)
output_text = output_text.split('</s>')[0] # remove the stop sign </s>
# output_text = output_text.split('[/INST]')[-1].strip()
conv.messages[-1][1] = output_text + '</s>'
return output_text, output_token.cpu().numpy()
def infer_data(
self, data_sample, system=' ',
question_prompt='', # add in the end of question
answer_prompt=None, # add in the begining of answer
system_q=False, # whether add question in the system prompt for QFormer
print_res=True,
system_llm=False
):
assert system_q is False, 'do not support system_q now'
video = data_sample['video']
T_, C, H, W = video.shape
video = video.reshape(1, T_, C, H, W).to('cuda')
video_list = []
with torch.no_grad():
if system_q:
raise NotImplementedError
else:
video_emb, _, _ = self.model.encode_img(video, system)
video_list.append(video_emb[0])
question = data_sample['question']
from utils.easydict import EasyDict
chat = EasyDict({
'system': system,
'roles': ('[INST]', '[/INST]'),
'messages': [],
'sep': ''
})
if data_sample['subtitle'] != '':
subtitle = f"This video's subtitles are listed below: {data_sample['subtitle']}"
chat.messages.append([chat.roles[0], f'{subtitle}\n<Video><VideoHere></Video> [/INST]'])
else:
chat.messages.append([chat.roles[0], '<Video><VideoHere></Video> [/INST]'])
if system_llm:
prompt = system + question + question_prompt
else:
prompt = question + question_prompt
self.ask(prompt, chat)
llm_message = self.answer(
conv=chat, model=self.model, do_sample=False,
img_list=video_list, max_new_tokens=100,
answer_prompt=answer_prompt, print_res=print_res
)[0]
return llm_message.strip()
def qa_template(self, data):
question = data.split('Answer:')[0].split('\n')[0] + '\n'
question += 'Options:\n'
choices = data.split('Answer:')[0].split('\n')[1:]
choices = [item for item in choices if item != ''] # remove blank space
for idx, c in enumerate(choices):
cur_choice, cur_text = c[0], c[3:]
question += f'({cur_choice}) {cur_text}\n'
question = question.rstrip()
return question
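    # Illustrative sketch (hypothetical MCQ string): an input such as
    #   "What is shown?\nA. a cat\nB. a dog\nAnswer:"
    # is reformatted by qa_template() into
    #   "What is shown?\nOptions:\n(A) a cat\n(B) a dog"
    # before being passed to infer_data().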
def split_subtitle(self, data):
if 'This video\'s subtitles are listed below' in data:
            # locate the start and end of the subtitle block
start_marker = 'This video\'s subtitles are listed below:'
end_marker = 'Select the best answer to the following multiple-choice question based on the video.'
start_index = data.find(start_marker) + len(start_marker)
end_index = data.find(end_marker)
            # extract the subtitle part
subtitle = data[start_index:end_index].strip()
return subtitle
else:
return ''
def generate_inner(self, message, dataset=None):
if dataset == 'Video-MME':
_, video = self.message_to_promptvideo(message)
torch_imgs = self.read_video(video)
subtitle = self.split_subtitle(message[-2]['value'])
question = self.qa_template(message[-1]['value'])
example = {
'subtitle': subtitle,
'video': torch_imgs,
'question': question
}
pred_option = self.infer_data(
example,
' ',
question_prompt='\nOnly give the best option.',
answer_prompt='Best option:(',
system_q=False,
print_res=False,
system_llm=True
)
return_message = '(' + pred_option.split('\n')[0]
return return_message
elif listinstr(['MLVU', 'MVBench', 'TempCompass'], dataset):
question, video = self.message_to_promptvideo_withrole(message, dataset)
torch_imgs = self.read_video(video)
example = {
'subtitle': '',
'video': torch_imgs,
'question': question['user']
}
if 'assistant' not in question:
question['assistant'] = None
if question['system'] == '':
question['system'] = ' '
pred_option = self.infer_data(
example,
question['system'],
answer_prompt=question['assistant'],
system_q=False,
print_res=False,
system_llm=True
)
return_message = '(' + pred_option.split('\n')[0]
return return_message
else:
question, video = self.message_to_promptvideo(message)
torch_imgs = self.read_video(video)
example = {
'subtitle': '',
'video': torch_imgs,
'question': f'Question:{question}\nAnswer:'
}
pred_result = self.infer_data(
example,
' ',
system_q=False,
print_res=False,
system_llm=False
)
return pred_result
import torch
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy
class VILA(BaseModel):
INSTALL_REQ = True
INTERLEAVE = True
def __init__(self,
model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b',
**kwargs):
try:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # noqa E501
from llava.conversation import conv_templates, SeparatorStyle
except Exception as err:
logging.critical('Please install VILA before using VILA')
logging.critical('Please install VILA from https://github.com/NVlabs/VILA')
logging.critical('Please install VLMEvalKit after installing VILA')
logging.critical('VILA is supported only with transformers==4.36.2')
raise err
warnings.warn('Please install the latest version of VILA from GitHub before you evaluate the VILA model.')
assert osp.exists(model_path) or len(model_path.split('/')) == 2
model_name = get_model_name_from_path(model_path)
try:
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=model_name,
device='cpu',
device_map='cpu'
)
except Exception as err:
logging.critical('Error loading VILA model: ')
raise err
self.model = self.model.cuda()
if '3b' in model_path:
self.conv_mode = 'vicuna_v1'
if '8b' in model_path:
self.conv_mode = 'llama_3'
elif '13b' in model_path:
self.conv_mode = 'vicuna_v1'
elif '40b' in model_path:
self.conv_mode = 'hermes-2'
kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1, use_cache=True) # noqa E501
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Using the following kwargs for generation config: {self.kwargs}')
self.conv_templates = conv_templates
self.process_images = process_images
self.tokenizer_image_token = tokenizer_image_token
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
self.SeparatorStyle = SeparatorStyle
self.IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
self.KeywordsStoppingCriteria = KeywordsStoppingCriteria
def use_custom_prompt(self, dataset):
assert dataset is not None
# TODO see if custom prompt needed
return False
def generate_inner(self, message, dataset=None):
content, images = '', []
for msg in message:
if msg['type'] == 'text':
content += msg['value']
elif msg['type'] == 'image':
image = Image.open(msg['value']).convert('RGB')
images.append(image)
content += (self.DEFAULT_IMAGE_TOKEN + '\n')
image_tensor = self.process_images(
images, self.image_processor,
self.model.config).to(self.model.device, dtype=torch.float16)
# Support interleave text and image
conv = self.conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], content)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, self.IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = self.KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)
output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
return output