# Commit 112bf76b by chenzk: v1.0
import math
import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class DTCBlock(nn.Module):
def __init__(
self, input_dim, output_dim, kernel_size, stride, causal_conv, dilation, dropout_rate
):
super(DTCBlock, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
if causal_conv:
self.padding = 0
self.lorder = (kernel_size - 1) * self.dilation
self.left_padding = nn.ConstantPad1d((self.lorder, 0), 0.0)
else:
assert (kernel_size - 1) % 2 == 0
self.padding = ((kernel_size - 1) // 2) * self.dilation
self.lorder = 0
self.causal_conv = causal_conv
self.depthwise_conv = nn.Conv1d(
self.input_dim,
self.input_dim,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
groups=self.input_dim,
)
# pointwise (1x1) convolutions: padding is fixed to 0 so the time length is preserved
self.point_conv_1 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, 0)
self.point_conv_2 = nn.Conv1d(self.input_dim, self.input_dim, 1, 1, 0)
self.bn_1 = nn.BatchNorm1d(self.input_dim)
self.bn_2 = nn.BatchNorm1d(self.input_dim)
self.bn_3 = nn.BatchNorm1d(self.input_dim)
self.dropout = nn.Dropout(p=dropout_rate)
# streaming cache layout: (1, input_dim, lorder), stored flattened in the shared buffer
self.lorder = (kernel_size - 1) * self.dilation - (self.stride - 1)
self.buffer_size = 1 * self.input_dim * self.lorder
@torch.jit.unused
def forward(self, x):
x_in = x
x_data = x_in.transpose(1, 2)
if self.causal_conv:
x_data_pad = self.left_padding(x_data)
else:
x_data_pad = x_data
x_depth = self.depthwise_conv(x_data_pad)
x_bn_1 = self.bn_1(x_depth)
x_point_1 = self.point_conv_1(x_bn_1)
x_bn_2 = self.bn_2(x_point_1)
x_relu_2 = torch.relu(x_bn_2)
x_point_2 = self.point_conv_2(x_relu_2)
x_bn_3 = self.bn_3(x_point_2)
x_bn_3 = x_bn_3.transpose(1, 2)
if self.stride == 1:
x_relu_3 = torch.relu(x_bn_3 + x_in)
else:
x_relu_3 = torch.relu(x_bn_3)
x_drop = self.dropout(x_relu_3)
return x_drop
@torch.jit.export
def infer(self, x, buffer, buffer_index, buffer_out):
# type: (Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor]]
x_in = x
x = x_in.transpose(1, 2)
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
[1, self.input_dim, self.lorder]
)
x = torch.cat([cnn_buffer, x], dim=2)
buffer_out.append(x[:, :, -self.lorder :].reshape(-1))
buffer_index = buffer_index + self.buffer_size
x = self.depthwise_conv(x)
x = self.bn_1(x)
x = self.point_conv_1(x)
x = self.bn_2(x)
x = torch.relu(x)
x = self.point_conv_2(x)
x = self.bn_3(x)
x = x.transpose(1, 2)
if self.stride == 1:
x = torch.relu(x + x_in)
else:
x = torch.relu(x)
return x, buffer, buffer_index, buffer_out
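# --- Hedged usage sketch (not part of the original commit) ---
# Sanity-check that the streaming infer() path of DTCBlock matches forward() for a causal
# block with stride 1 when the input is fed chunk by chunk and the flattened left-context
# buffer is carried between calls. The function name and all sizes below are illustrative.
def _demo_dtc_block_streaming():
    torch.manual_seed(0)
    block = DTCBlock(
        input_dim=8, output_dim=8, kernel_size=3, stride=1,
        causal_conv=True, dilation=1, dropout_rate=0.0,
    ).eval()
    x = torch.randn(1, 20, 8)  # (batch, time, feature)
    with torch.no_grad():
        y_full = block(x)
        buffer = torch.zeros(block.buffer_size)  # left context, initially silence
        chunks = []
        for t in range(0, x.size(1), 5):
            buffer_out = []
            y, _, _, buffer_out = block.infer(x[:, t : t + 5, :], buffer, 0, buffer_out)
            buffer = torch.cat(buffer_out)  # last `lorder` frames become the next left context
            chunks.append(y)
    print(torch.allclose(y_full, torch.cat(chunks, dim=1), atol=1e-5))  # expected: True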
import torch
import torch.nn as nn
import torch.nn.functional as F
class FsmnLayer(nn.Module):
def __init__(
self,
input_dim,
out_dim,
hidden_dim,
left_frame=1,
right_frame=1,
left_dilation=1,
right_dilation=1,
):
super(FsmnLayer, self).__init__()
self.input_dim = input_dim
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.left_frame = left_frame
self.right_frame = right_frame
self.left_dilation = left_dilation
self.right_dilation = right_dilation
self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
if left_frame > 0:
self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
self.conv_left = nn.Conv1d(
hidden_dim,
hidden_dim,
kernel_size=left_frame + 1,
dilation=left_dilation,
bias=False,
groups=hidden_dim,
)
if right_frame > 0:
self.pad_right = nn.ConstantPad1d([-right_dilation, right_dilation * right_frame], 0.0)
self.conv_right = nn.Conv1d(
hidden_dim,
hidden_dim,
kernel_size=right_frame,
dilation=right_dilation,
bias=False,
groups=hidden_dim,
)
self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)
# streaming cache layout: (1, hidden_dim, left_frame * left_dilation + right_frame * right_dilation), stored flattened
self.cache_size = left_frame * left_dilation + right_frame * right_dilation
self.buffer_size = self.hidden_dim * self.cache_size
self.p_in_raw_cache_size = self.right_frame * self.right_dilation
self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_cache_size
self.hidden_cache_size = self.right_frame * self.right_dilation
self.hidden_buffer_size = self.hidden_dim * self.hidden_cache_size
@torch.jit.unused
def forward(self, x, hidden=None):
x_data = x.transpose(1, 2)
p_in = self.conv_in(x_data)
if self.left_frame > 0:
p_left = self.pad_left(p_in)
p_left = self.conv_left(p_left)
else:
p_left = 0
if self.right_frame > 0:
p_right = self.pad_right(p_in)
p_right = self.conv_right(p_right)
else:
p_right = 0
p_out = p_in + p_right + p_left
if hidden is not None:
p_out = hidden + p_out
out = F.relu(self.conv_out(p_out))
out = out.transpose(1, 2)
return out, p_out
@torch.jit.export
def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
# type: (Tensor, Tensor, Tensor, List[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, List[Tensor], Tensor]
p_in_raw = self.conv_in(x)
cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
[1, self.hidden_dim, self.cache_size]
)
p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
# buffer[buffer_index: buffer_index + self.buffer_size] = p_in[:, :, -self.cache_size:].reshape(-1)
buffer_out.append(p_in[:, :, -self.cache_size :].reshape(-1))
buffer_index = buffer_index + self.buffer_size
if self.left_frame > 0:
if self.right_frame > 0:
p_left = p_in[:, :, : -self.right_frame * self.right_dilation]
else:
p_left = p_in[:, :]
p_left_out = self.conv_left(p_left)
else:
p_left_out = torch.tensor([0])
if self.right_frame > 0:
p_right = p_in[:, :, self.left_frame * self.left_dilation + 1 :]
p_right_out = self.conv_right(p_right)
else:
p_right_out = torch.tensor([0])
if self.right_frame > 0:
p_in_raw_cnn_buffer = buffer[
buffer_index : buffer_index + self.p_in_raw_buffer_size
].reshape([1, self.hidden_dim, self.p_in_raw_cache_size])
p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
# buffer[buffer_index: buffer_index + self.p_in_raw_buffer_size] = p_in_raw[:, :, -self.p_in_raw_cache_size:].reshape(-1)
buffer_out.append(p_in_raw[:, :, -self.p_in_raw_cache_size :].reshape(-1))
buffer_index = buffer_index + self.p_in_raw_buffer_size
p_in_raw = p_in_raw[:, :, : -self.p_in_raw_cache_size]
p_out = p_in_raw + p_left_out + p_right_out
if hidden is not None:
if self.right_frame > 0:
hidden_cnn_buffer = buffer[
buffer_index : buffer_index + self.hidden_buffer_size
].reshape([1, self.hidden_dim, self.hidden_cache_size])
hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
# buffer[buffer_index: buffer_index + self.hidden_buffer_size] = hidden[:, :, -self.hidden_cache_size:].reshape(-1)
buffer_out.append(hidden[:, :, -self.hidden_cache_size :].reshape(-1))
buffer_index = buffer_index + self.hidden_buffer_size
hidden = hidden[:, :, : -self.hidden_cache_size]
p_out = hidden + p_out
out = F.relu(self.conv_out(p_out))
return out, buffer, buffer_index, buffer_out, p_out
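# --- Hedged usage sketch (not part of the original commit) ---
# Shape check for FsmnLayer.forward(): the layer keeps the time dimension, and the second
# return value (the memory-block activations) can be passed as `hidden` to a following
# FsmnLayer to form the usual FSMN skip connection. All sizes below are illustrative.
def _demo_fsmn_layer():
    torch.manual_seed(0)
    layer = FsmnLayer(input_dim=16, out_dim=16, hidden_dim=32, left_frame=3, right_frame=2)
    x = torch.randn(2, 50, 16)          # (batch, time, feature)
    out, p_hidden = layer(x)
    print(out.shape, p_hidden.shape)    # torch.Size([2, 50, 16]) torch.Size([2, 32, 50])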
import argparse
import importlib
import json
import os
from distutils.util import strtobool as dist_strtobool
import torch
import yaml
IGNORE_ID = -1
def assign_args_from_yaml(args, yaml_path, prefix_key=None):
with open(yaml_path) as f:
ydict = yaml.load(f, Loader=yaml.FullLoader)
if prefix_key is not None:
ydict = ydict[prefix_key]
for k, v in ydict.items():
k_args = k.replace("-", "_")
if hasattr(args, k_args):
setattr(args, k_args, ydict[k])
return args
def get_model_conf(model_path):
model_conf = os.path.dirname(model_path) + "/model.json"
with open(model_conf, "rb") as f:
print("reading a config file from " + model_conf)
confs = json.load(f)
# for asr, tts, mt
idim, odim, args = confs
return argparse.Namespace(**args)
def strtobool(x):
return bool(dist_strtobool(x))
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'espnet.transform.add_deltas:AddDeltas'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError(
"import_path should be one of {} or "
'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
"{}".format(set(alias), import_path)
)
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
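# --- Hedged usage sketch (not part of the original commit) ---
# dynamic_import() resolves a "module:attribute" string to the attribute itself;
# the standard-library class below is used only for illustration.
def _demo_dynamic_import():
    import collections
    cls = dynamic_import("collections:OrderedDict")
    print(cls is collections.OrderedDict)  # expected: True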
def set_deterministic_pytorch(args):
# seed setting
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
def pad_list(xs, pad_value):
n_batch = len(xs)
max_len = max(x.size(0) for x in xs)
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
for i in range(n_batch):
pad[i, : xs[i].size(0)] = xs[i]
return pad
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
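# --- Hedged usage sketch (not part of the original commit) ---
# pad_list() right-pads variable-length feature tensors to a common length, and
# make_pad_mask() marks the padded positions with True. Toy shapes below.
def _demo_padding_utils():
    xs = [torch.ones(3, 2), torch.ones(5, 2)]
    padded = pad_list(xs, pad_value=0.0)
    mask = make_pad_mask(torch.tensor([3, 5]))
    print(padded.shape)  # torch.Size([2, 5, 2])
    print(mask)
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])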
def subsequent_chunk_mask(
size: int,
ck_size: int,
num_l_cks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_l_cks < 0:
start = 0
else:
start = max((i // ck_size - num_l_cks) * ck_size, 0)
ending = min((i // ck_size + 1) * ck_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(
xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
):
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_l_cks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_l_cks = num_decoding_left_chunks
else:
chunk_size = torch.randint(1, max_len, (1,)).item()
num_l_cks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
ck_masks = subsequent_chunk_mask(
xs.size(1), chunk_size, num_l_cks, xs.device
) # (L, L)
ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
ck_masks = masks & ck_masks # (B, L, L)
elif static_chunk_size > 0:
num_l_cks = num_decoding_left_chunks
ck_masks = subsequent_chunk_mask(
xs.size(1), static_chunk_size, num_l_cks, xs.device
) # (L, L)
ck_masks = ck_masks.unsqueeze(0) # (1, L, L)
ck_masks = masks & ck_masks # (B, L, L)
else:
ck_masks = masks
return ck_masks
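# --- Hedged usage sketch (not part of the original commit) ---
# With ck_size=2 and num_l_cks=1, every frame may attend to its own chunk plus at most
# one chunk to its left, which is the mask shape used for chunk-wise streaming attention.
def _demo_chunk_mask():
    print(subsequent_chunk_mask(size=6, ck_size=2, num_l_cks=1).int())
    # tensor([[1, 1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0, 0],
    #         [1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 0, 0],
    #         [0, 0, 1, 1, 1, 1],
    #         [0, 0, 1, 1, 1, 1]], dtype=torch.int32)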
import math
import re
from functools import partial
from torch import nn
from timm.layers.norm_act import LayerNormAct2d
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
from torchvision.ops.misc import SqueezeExcitation as SELayer
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
class Minigpt(nn.Module):
def __init__(self, config=None):
super(Minigpt, self).__init__()
# the linear layer maps 4 concatenated tokens (inc * 4 features) to the LLM hidden size (ouc)
inc, ouc = config.mm_hidden_size, config.hidden_size
self.linear = nn.Linear(inc * 4, ouc)
def forward(self, x):
# x is the input tensor with shape [b, num_tokens, c]
b, num_tokens, c = x.shape
# Check if num_tokens is divisible by 4
if num_tokens % 4 != 0:
raise ValueError("num_tokens must be divisible by 4")
# Reshape x to [b, num_tokens/4, c*4]
x = x.view(b, num_tokens // 4, c * 4)
# Apply the linear transformation
x = self.linear(x)
return x
class Vanilla(nn.Module):
def __init__(self, config=None):
super(Vanilla, self).__init__()
# the linear layer maps 4 concatenated tokens (inc * 4 features) to the LLM hidden size (ouc)
inc, ouc = config.mm_hidden_size, config.hidden_size
self.linear = nn.Linear(inc * 4, ouc)
def forward(self, x):
b, num_tokens, c = x.shape
# Check if num_tokens is divisible by 4
if num_tokens % 4 != 0:
raise ValueError("num_tokens must be divisible by 4")
# First, reshape to [b, num_tokens//4, 4, c]
x = x.view(b, num_tokens // 4, 4, c)
# Then, permute to interleave the tokens
x = x.permute(0, 1, 3, 2).contiguous()
# Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
x = x.view(b, num_tokens // 4, c * 4)
# Apply the linear transformation
x = self.linear(x)
return x
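# --- Hedged usage sketch (not part of the original commit) ---
# Vanilla merges every group of 4 visual tokens into one LLM token. The config object
# below is a hypothetical stand-in exposing only the two fields the projector reads.
def _demo_vanilla_projector():
    import torch
    from types import SimpleNamespace
    cfg = SimpleNamespace(mm_hidden_size=8, hidden_size=16)  # toy sizes
    proj = Vanilla(cfg)
    x = torch.randn(2, 12, 8)   # num_tokens (12) must be divisible by 4
    print(proj(x).shape)        # torch.Size([2, 3, 16])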
class LDPBlock(nn.Module):
# Lightweight Downsample Projector Block
def __init__(self, config=None):
super().__init__()
inc, ouc = config.mm_hidden_size, config.hidden_size
layer_norm = partial(LayerNormAct2d, act_layer=None)
se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
self.mlp = nn.Sequential(nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc))
self.mb_block = nn.Sequential(
nn.Identity(),
InvertedResidual(
InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer
),
InvertedResidual(
InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer
),
)
def forward(self, x):
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
x = self.mlp(x)
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.mb_block(x)
x = x.flatten(2).permute(0, 2, 1)
return x
class LDPNetProjector(nn.Module):
def __init__(self, config=None):
super().__init__()
self.model = LDPBlock(config)
def forward(self, x):
return self.model(x)
class SPP(nn.Module):
def __init__(self, config=None, projector_type="v1"):
super().__init__()
self.projector_type = projector_type
inc, ouc = config.mm_hidden_size, config.hidden_size
self.linear_0 = nn.Linear(inc, inc)
self.linear_1 = nn.Linear(inc, ouc)
self.pooling = nn.AvgPool2d(kernel_size=2)
self.linear_2 = nn.Linear(ouc, ouc)
def forward(self, x):
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
if "v1" in self.projector_type:
x = self.linear_1(x)
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.pooling(x)
x = x.flatten(2).permute(0, 2, 1)
x = self.linear_2(x)
elif "v2" in self.projector_type:
x = self.linear_1(x)
x = self.linear_2(x)
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.pooling(x)
x = x.flatten(2).permute(0, 2, 1)
elif "v3" in self.projector_type:
x = self.linear_0(x)
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.pooling(x)
x = x.flatten(2).permute(0, 2, 1)
x = self.linear_1(x)
x = self.linear_2(x)
return x
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
elif projector_type.startswith("mlp"):
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
elif projector_type.startswith("spp"):
return SPP(config, projector_type)
elif projector_type == "ldp":
return LDPNetProjector(config)
elif projector_type == "vanilla":
return Vanilla(config)
elif projector_type == "minigpt":
return Minigpt(config)
elif projector_type == "identity":
return IdentityMap()
raise ValueError(f"Unknown projector type: {projector_type}")
import math
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from vita.constants import AUDIO_TOKEN_INDEX, IGNORE_INDEX, IMAGE_TOKEN_INDEX
from .multimodal_encoder.builder import build_audio_encoder, build_vision_tower
from .multimodal_projector.builder import build_vision_projector
class VITAMetaModel:
def __init__(self, config):
super(VITAMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(
config, delay_load=not getattr(config, "continuous_training", False)
)
if getattr(config, "continuous_training", False):
config.continuous_training = False
self.mm_projector = build_vision_projector(config)
if hasattr(config, "mm_audio_encoder"):
self.audio_encoder = build_audio_encoder(config)
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def get_audio_encoder(self):
audio_encoder = getattr(self, "audio_encoder", None)
return audio_encoder
def initialize_vision_modules(self, model_args):
vision_tower = model_args.vision_tower
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
self.config.mm_vision_tower = vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
self.vision_tower = vision_tower
else:
vision_tower = self.vision_tower
vision_tower.load_model()
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, "mm_projector_type")
self.config.mm_hidden_size = vision_tower.hidden_size
if getattr(self, "mm_projector", None) is None:
self.mm_projector = build_vision_projector(self.config)
else:
# In case it is frozen by LoRA
for p in self.mm_projector.parameters():
p.requires_grad = True
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
def get_w(weights, keyword):
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
def initialize_audio_modules(self, model_args):
audio_encoder = model_args.audio_encoder
setattr(self.config, "mm_audio_encoder", audio_encoder)
if self.get_audio_encoder() is None:
# audio_encoder = build_audio_encoder(model_args)
audio_encoder = build_audio_encoder(self.config)
self.audio_encoder = audio_encoder
# from safetensors.torch import load_file
# import os
# audio_weights = {}
# import pdb; pdb.set_trace()
# for file_name in os.listdir(model_args.model_name_or_path):
# if file_name.endswith('safetensors'):
# audio_weights.update(
# {k[20:]: v for k, v in load_file(os.path.join(model_args.model_name_or_path, file_name)).items() if
# k.startswith('model.audio_encoder.')})
# import pdb; pdb.set_trace()
# self.audio_encoder.load_state_dict(audio_weights, strict=True)
checkpoint = torch.load(model_args.audio_encoder + "/final.pt", map_location="cpu")
model_dict = self.audio_encoder.state_dict()
for key in model_dict.keys():
if key in checkpoint.keys():
if model_dict[key].shape == checkpoint[key].shape:
model_dict[key] = checkpoint[key]
else:
print(
"Key {} has different shape, {} VS {}".format(
key, model_dict[key].shape, checkpoint[key].shape
)
)
else:
print("Key {} has not in resume model".format(key))
# import pdb; pdb.set_trace()
self.audio_encoder.load_state_dict(model_dict)
class VITAMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def get_audio_encoder(self):
return self.get_model().get_audio_encoder()
def pool_feats(self, x):
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = F.adaptive_avg_pool2d(x, (12, 12))
num_tokens = x.shape[2] * x.shape[3] # Recalculate the number of tokens after pooling
x = x.reshape(b, c, num_tokens).permute(0, 2, 1)
return x
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def encode_images_frameCat(self, images):
image_features = self.get_model().get_vision_tower()(images)
assert len(image_features) % 5 == 0
concatenated_features = []
for i in range(0, len(image_features), 5):
tensors_to_concat = [image_features[j] for j in range(i, i + 5)]
concatenated_tensor = torch.cat(tensors_to_concat, dim=-1)
concatenated_features.append(concatenated_tensor)
concatenated_features = torch.stack(concatenated_features)
image_features = concatenated_features
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, audios
):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
if (
past_key_values is not None
and vision_tower is not None
and images is not None
and input_ids.shape[1] == 1
):
target_shape = past_key_values[-1][-1].shape[-2] + 1
attention_mask = torch.cat(
(
attention_mask,
torch.ones(
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
dtype=attention_mask.dtype,
device=attention_mask.device,
),
),
dim=1,
)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
else:
image_features = self.encode_images(images).to(self.device)
audio_encoder = self.get_audio_encoder()
# audio_features = audio_encoder(audios['audios'], audios['lengths'])
if audios is not None:
audio_features = audio_encoder(audios["audios"], audios["lengths"])
else:
audio_features = None
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- TODO: double check
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
cur_audio_idx = 0
# assert sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids]) <= image_features.shape[0]
# assert sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids]) <= audio_features['inputs_embeds'].shape[0]
assert (
sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids])
+ sum([(IMAGE_TOKEN_INDEX not in cur) for cur in input_ids])
== image_features.shape[0]
)
assert (
sum([(cur == AUDIO_TOKEN_INDEX).sum() for cur in input_ids])
+ sum([(AUDIO_TOKEN_INDEX not in cur) for cur in input_ids])
== audio_features["inputs_embeds"].shape[0]
)
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
num_audio_frames = (cur_input_ids == AUDIO_TOKEN_INDEX).sum()
if num_images == 0 and num_audio_frames == 0:
cur_image_features = image_features[cur_image_idx]
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat(
[cur_input_embeds_1, cur_image_features[0:0], cur_audio_features[0:0]], dim=0
)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
cur_audio_idx += 1
continue
image_audio_token_indices = (
[-1]
+ torch.where(
(cur_input_ids == IMAGE_TOKEN_INDEX) | (cur_input_ids == AUDIO_TOKEN_INDEX)
)[0].tolist()
+ [cur_input_ids.shape[0]]
)
cur_input_ids_noim_noau = []
cur_labels = labels[batch_idx]
cur_labels_noim_noau = []
for i in range(len(image_audio_token_indices) - 1):
cur_input_ids_noim_noau.append(
cur_input_ids[
image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]
]
)
cur_labels_noim_noau.append(
cur_labels[image_audio_token_indices[i] + 1 : image_audio_token_indices[i + 1]]
)
split_sizes = [x.shape[0] for x in cur_labels_noim_noau]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim_noau))
cur_input_embeds_no_im_no_au = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + num_audio_frames + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im_no_au[i])
cur_new_labels.append(cur_labels_noim_noau[i])
if i < num_images + num_audio_frames:
if cur_input_ids[image_audio_token_indices[i + 1]] == IMAGE_TOKEN_INDEX:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
elif cur_input_ids[image_audio_token_indices[i + 1]] == AUDIO_TOKEN_INDEX:
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
cur_audio_idx += 1
cur_new_input_embeds.append(cur_audio_features)
cur_new_labels.append(
torch.full(
(cur_audio_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
else:
raise ValueError
if num_images != 0 and num_audio_frames == 0:
cur_audio_features = audio_features["inputs_embeds"][cur_audio_idx]
cur_audio_idx += 1
cur_new_input_embeds.append(cur_audio_features[0:0])
elif num_images == 0 and num_audio_frames != 0:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features[0:0])
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
assert cur_image_idx == image_features.shape[0]
assert cur_audio_idx == audio_features["inputs_embeds"].shape[0]
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
)
position_ids = torch.zeros(
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
import logging
import os
import pathlib
import random
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
import transformers
from transformers import set_seed
from vita import conversation as conversation_lib
from vita.model import *
from vita.train.vita_trainer import VITATrainer
# from vita.util.data_utils_video_audio import make_supervised_data_module, DataArguments
# from vita.util.data_utils_video_audio_neg_patch import make_supervised_data_module, DataArguments
from vita.util.data_utils_video_audio_neg_frameCat import DataArguments, make_supervised_data_module
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed(seed)
set_random_seed(42)
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default=None)
model_type: Optional[str] = field(default=None)
version: Optional[str] = field(default=None)
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
audio_encoder: Optional[str] = field(default=None)
freeze_audio_encoder: bool = field(default=True)
freeze_audio_encoder_adapter: bool = field(default=True)
unfreeze_vision_tower: bool = field(default=False)
use_s2: bool = field(default=False)
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default="mlp2x_gelu")
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."},
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
)
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
mm_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(
f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
)
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.util.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {
k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)
}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
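# --- Hedged usage sketch (not part of the original commit) ---
# find_all_linear_names() collects the leaf Linear module names to use as LoRA targets,
# skipping multimodal modules and removing lm_head. The toy model below is purely illustrative.
def _demo_find_all_linear_names():
    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = torch.nn.Linear(8, 8)
            self.lm_head = torch.nn.Linear(8, 8)
            self.mm_projector = torch.nn.Sequential(torch.nn.Linear(8, 8))
    print(sorted(find_all_linear_names(_ToyModel())))  # ['q_proj']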
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
# Only save Adapter
keys_to_match = ["mm_projector"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(
trainer.model.named_parameters(), keys_to_match
)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin")
)
else:
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
return
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def train():
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(
dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_skip_modules=["mm_projector"],
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
),
)
)
assert model_args.vision_tower is not None
if model_args.model_type in {"mixtral-8x7b"}:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=True,
)
if tokenizer.unk_token is not None and tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
if model_args.model_type == "llama3-8b":
tokenizer.pad_token = tokenizer.eos_token
if model_args.model_type == "mixtral-8x7b":
torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16
model = VITAMixtralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
**bnb_model_from_pretrained_args,
)
else:
raise ValueError(f"Unknown Model Type {model_args.model_type}")
model.config.use_cache = False
if model_args.freeze_backbone:
model.model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype = (
torch.float32
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["default"]
model.config.freeze_audio_encoder = model_args.freeze_audio_encoder
model.config.freeze_audio_encoder_adapter = model_args.freeze_audio_encoder_adapter
model.get_model().initialize_vision_modules(model_args=model_args)
model.get_model().initialize_audio_modules(model_args=model_args)
vision_tower = model.get_vision_tower()
vision_tower.to(
dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device
)
audio_encoder = model.get_audio_encoder()
audio_encoder.to(
dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device
)
data_args.image_processor = vision_tower.image_processor
data_args.audio_processor = audio_encoder.audio_processor
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
model.config.tune_mm_mlp_adapter = (
training_args.tune_mm_mlp_adapter
) = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
if training_args.bits in [4, 8]:
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
model.config.mm_projector_lr = training_args.mm_projector_lr
model.config.use_s2 = model_args.use_s2
model.config.unfreeze_vision_tower = (
training_args.unfreeze_vision_tower
) = model_args.unfreeze_vision_tower
if training_args.unfreeze_vision_tower:
for p in model.get_model().vision_tower.parameters():
p.requires_grad = True
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = VITATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
if training_args.local_rank == 0 or training_args.local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(
non_lora_state_dict,
os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
)
else:
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
import os
from typing import Any, Dict, List, Optional, Union
import torch
from torch import nn
from torch.utils.data import Sampler
from transformers import Trainer
from transformers.trainer import (
ALL_LAYERNORM_LAYERS,
get_parameter_names,
has_length,
is_sagemaker_mp_enabled,
logger,
)
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, "no ignore status")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {
k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)
}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `num_chunks` chunks of roughly equal total length.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
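# --- Hedged usage sketch (not part of the original commit) ---
# split_to_even_chunks() greedily assigns each index to the chunk with the smallest
# accumulated length, so long and short samples are balanced across workers.
def _demo_split_to_even_chunks():
    indices = [0, 1, 2, 3, 4, 5]
    lengths = [10, 9, 8, 1, 1, 1]
    print(split_to_even_chunks(indices, lengths, num_chunks=2))  # [[0, 3, 4], [1, 2, 5]]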
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [
mm_indices[i]
for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)
]
lang_shuffle = [
lang_indices[i]
for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)
]
megabatch_size = world_size * batch_size
mm_megabatches = [
mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)
]
lang_megabatches = [
lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)
]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [
indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)
]
megabatches = [
sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches
]
megabatches = [
split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches
]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(
self.lengths, self.batch_size, self.world_size, generator=self.generator
)
else:
indices = get_length_grouped_indices(
self.lengths, self.batch_size, self.world_size, generator=self.generator
)
return iter(indices)
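# --- Hedged usage sketch (not part of the original commit) ---
# LengthGroupedSampler shuffles samples but sorts them by length inside each
# world_size * batch_size megabatch and then splits the megabatch so every rank
# receives a similar total length. Lengths and seed below are illustrative.
def _demo_length_grouped_sampler():
    lengths = [12, 300, 8, 280, 10, 310, 9, 290]
    sampler = LengthGroupedSampler(
        batch_size=2,
        world_size=2,
        lengths=lengths,
        generator=torch.Generator().manual_seed(0),
    )
    print(list(sampler))  # a permutation of range(8); consecutive pairs balance long and short samples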
class VITATrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
self.args.train_batch_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
if self.args.mm_projector_lr is not None:
projector_parameters = [
name
for name, _ in opt_model.named_parameters()
if "mm_projector" in name or "vision_tower" in name
]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n not in projector_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n not in projector_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped / 2 ** 20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
# print('model.model.audio_encoder.adpter.project.weight')
# print(model.model.audio_encoder.adpter.project.weight)
# print('model.model.audio_encoder.adpter.project.weight.requires_grad')
# print(model.model.audio_encoder.adpter.project.weight.requires_grad)
if getattr(self.args, "tune_mm_mlp_adapter", False):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(
self.model.named_parameters(), keys_to_match
)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
else:
super(VITATrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super(VITATrainer, self)._save(output_dir, state_dict)
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
tr_loss_step = super().training_step(model, inputs)
return tr_loss_step
# try:
# #import pdb; pdb.set_trace()
# tr_loss_step = super().training_step(model, inputs)
# return tr_loss_step
# except Exception as e:
# print('------------------------------------------------len of images------------------------------------------------')
# print(len(inputs['images']))
# print('------------------------------------------------input_ids------------------------------------------------')
# print(inputs['input_ids'].tolist())
# print(e)
import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict
from vita.constants import (
DEFAULT_AUDIO_TOKEN,
DEFAULT_IMAGE_TOKEN,
DEFAULT_VIDEO_TOKEN,
IGNORE_INDEX,
MAX_IMAGE_LENGTH,
MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token
@dataclass
class DataArguments:
lazy_preprocess: bool = False
is_multimodal: bool = True
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = field(default=None)
dataset_use: str = field(default="temp")
def preprocess_multimodal(
sources: Sequence[str], data_args: DataArguments, image_token_num=1, audio_lens: int = 0
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
sentence["value"] = (
sentence["value"]
.replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
.strip()
)
sentence["value"] = (
sentence["value"]
.replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
.strip()
)
if sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
)
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
sentence["value"] = (
sentence["value"]
.replace(
DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
)
.strip()
)
replace_token, vid_replace_token, audio_replace_token = (
DEFAULT_IMAGE_TOKEN,
DEFAULT_IMAGE_TOKEN * image_token_num,
DEFAULT_AUDIO_TOKEN,
) # * audio_lens
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n")
sentence["value"] = sentence["value"].replace(
DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n"
)
sentence["value"] = sentence["value"].replace(
DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token
)
sentence["value"] = sentence["value"].replace("\n\n", "\n")
return sources
def preprocess_mixtral_zh(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image and not has_audio:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif not has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralZh
# Mask targets
sep = conv.sep + "\n" + conv.roles[1] + ":"
sep2_2 = "\n" + conv.roles[0] + ":"
sep2 = conv.sep2 + sep2_2
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2)
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:]
cur_len = 1
end_token_cnt = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
if i > 0:
rou = sep2_2 + rou
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image and not has_audio:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
elif has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
elif not has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
end_token_cnt += 1
cur_len += round_len
cur_len = cur_len - 1
target[cur_len:] = IGNORE_INDEX
if tokenizer.pad_token_id == tokenizer.eos_token_id:
cur_len -= end_token_cnt
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
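"""Plain-style preprocessing: each source is a single (image prompt, answer) pair; the
prompt side is collapsed to one image token and masked out of the labels."""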
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = (
source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
)
conversations.append(conversation)
# tokenize conversations
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
) -> Dict:
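"""Dispatch to the preprocessing routine matching the active conversation template
(PLAIN or the mixtral_zh version)."""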
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version == "mixtral_zh":
return preprocess_mixtral_zh(sources, tokenizer, has_image=has_image, has_audio=has_audio)
def _get_rawvideo_dec(
video_path,
image_processor,
max_frames=32,
min_frames=4,
image_resolution=384,
video_framerate=1,
s=None,
e=None,
image_aspect_ratio="pad",
):
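"""Decode a video with decord, sample frames at `video_framerate` (clamped to the
[min_frames, max_frames] range), optionally pad frames to squares, and return the list
of preprocessed frame tensors together with the number of sampled frames."""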
# speed up video decode via decord.
video_mask = np.zeros(max_frames, dtype=np.int64)
max_video_length = 0
# T x 3 x H x W
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
if s is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = start_time if start_time >= 0.0 else 0.0
end_time = end_time if end_time >= 0.0 else 0.0
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
print(video_path)
raise FileNotFoundError
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# T x 3 x H x W
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
if len(all_pos) > max_frames:
sample_pos = [
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
]
elif len(all_pos) < min_frames:
sample_pos = [
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
if image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
patch_images = [
expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
for i in patch_images
]
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
else:
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
# patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
slice_len = len(patch_images)
return patch_images, slice_len
else:
print("video path: {} error.".format(video_path))
raise ValueError("no frames decoded from video: {}".format(video_path))
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
dataset_list = DataConfig[str(data_args.dataset_use)]
print(dataset_list)
self.max_length = MAX_IMAGE_LENGTH
list_data_dict = []
self.folder_dict = {}
for i in dataset_list:
list_data_dict += json.load(open(i["chat_path"], "r"))
image_folder = [folder for folder in i if folder != "chat_path"]
for folder in image_folder:
if folder not in self.folder_dict:
self.folder_dict[folder] = i[folder]
for key in FolderDict.keys():
if key not in self.folder_dict:
self.folder_dict[key] = FolderDict[key]
random.shuffle(list_data_dict)
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
# @property
# def lengths(self):
# length_list = []
# for sample in self.list_data_dict:
# img_tokens = 128 if 'image' in sample else 0
# length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
# return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
cur_len = cur_len if ("image" in sample or "video" in sample) else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
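# Branch on the modalities present in the sample (image / image+audio / video /
# video+audio / audio-only / text-only), load and preprocess the raw media, then
# tokenize the conversation with the matching preprocess routine.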
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if "image" in sources[0] and "audio" not in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
processor = self.data_args.image_processor
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id):
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]), self.data_args
)
data_dict = preprocess(sources, self.tokenizer, has_image=True)
elif "image" in sources[0] and "audio" in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id): # multi-image sample
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
if type(audio_file) is list:
# if type(set_id) is list:
# audio_folder = self.folder_dict[set_id[0]+'_audio']
# else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
try:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
except Exception:
print(f"File {os.path.join(audio_folder, 'audio', file)} not OK!!!!!")
raise
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
audio_lens=audio_for_llm_lens,
)
data_dict = preprocess(sources, self.tokenizer, has_image=True, has_audio=True)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "video" in sources[0] and "audio" not in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
video_folder = self.folder_dict[set_id]
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
)
data_dict = preprocess(sources, self.tokenizer, has_image=True, has_audio=False)
elif "video" in sources[0] and "audio" in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
video_folder = self.folder_dict[set_id]
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
)
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
audio_lens=audio_for_llm_lens,
)
data_dict = preprocess(sources, self.tokenizer, has_image=True, has_audio=True)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "audio" in sources[0]:
audio_file = self.list_data_dict[i]["audio"]
audio_folder = AudioFolder
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=0,
audio_lens=audio_for_llm_lens,
)
data_dict = preprocess(sources, self.tokenizer, has_image=False, has_audio=True)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(sources, self.tokenizer, has_image=False)
if isinstance(i, int):
if "audio" in self.list_data_dict[i]:
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0],
audio_lengths=data_dict["audio_lengths"],
audio_lengths_for_llm=data_dict["audio_lengths_for_llm"],
)
else:
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# image exist in the data
if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
if "audio" in self.list_data_dict[i]:
data_dict["audio"] = audio
elif self.data_args.is_multimodal:
data_dict["audio"] = torch.zeros(400, 80)
data_dict["audio_lengths"] = 400
data_dict["audio_lengths_for_llm"] = 60
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
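# When pad_token_id == eos_token_id, genuine EOS tokens are temporarily remapped to the
# sentinel -300 so that padding positions (filled with pad_token_id below) can be told
# apart when building the attention mask; the sentinel is restored afterwards.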
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == self.tokenizer.eos_token_id] = -300
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
labels = labels[:, : self.tokenizer.model_max_length]
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == -300] = self.tokenizer.eos_token_id
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
new_images = []
for image in images:
if type(image) is list:
for i in image:
new_images.append(i)
else:
new_images.append(image)
images = new_images
if all(x is not None and x.shape == images[0].shape for x in images):
batch["images"] = torch.stack(images)
else:
batch["images"] = images
batch["audios"] = {}
if "audio" in instances[0]:
audios = [instance["audio"] for instance in instances]
audio_lengths = [instance["audio_lengths"] for instance in instances]
audio_lengths_for_llm = [instance["audio_lengths_for_llm"] for instance in instances]
new_audios = []
new_audio_lengths = []
new_audio_lengths_for_llm = []
for i, audio in enumerate(audios):
length = audio_lengths[i]
length_for_llm = audio_lengths_for_llm[i]
if type(audio) is list:
for j, a in enumerate(audio):
new_audios.append(a)
new_audio_lengths.append(length[j])
new_audio_lengths_for_llm.append(length_for_llm[j])
else:
new_audios.append(audio)
new_audio_lengths.append(length)
new_audio_lengths_for_llm.append(length_for_llm)
audios = new_audios
audios = pad_sequence(audios, batch_first=True, padding_value=0)
batch["audios"]["audios"] = audios
batch["audios"]["lengths"] = torch.tensor(new_audio_lengths)
batch["audios"]["lengths_for_llm"] = torch.tensor(new_audio_lengths_for_llm)
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
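# Minimal usage sketch (hypothetical names; assumes a configured tokenizer, model and
# TrainingArguments are already available):
# data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
# trainer = transformers.Trainer(model=model, args=training_args, **data_module)
# trainer.train()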
import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
DEFAULT_AUDIO_TOKEN,
DEFAULT_DATA_RATIO,
DEFAULT_IMAGE_TOKEN,
DEFAULT_VIDEO_TOKEN,
IGNORE_INDEX,
MAX_IMAGE_LENGTH,
MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token
@dataclass
class DataArguments:
lazy_preprocess: bool = False
is_multimodal: bool = True
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = field(default=None)
dataset_use: str = field(default="temp")
min_dynamic_patch: int = 2
max_dynamic_patch: int = 12
use_thumbnail: bool = True
def preprocess_multimodal(
sources: Sequence[str],
data_args: DataArguments,
image_token_num=1,
patch_num=[1],
audio_lens: int = 0,
inserted_id=None,
) -> Dict:
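"""Normalize the image/video/audio placeholder tokens in every sentence (move leading
image tokens to the front, cap image tokens at MAX_IMAGE_LENGTH, expand video tokens to
`image_token_num` image tokens and image tokens according to `patch_num`), and prepend a
<1>/<2>/<3> state tag to gpt turns, chosen from `inserted_id` and whether the previous
human turn contains "<audio>"."""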
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
k_img_ph = 0
for source in sources:
if inserted_id is not None:
assert source[inserted_id]["from"] == "gpt"
for i, sentence in enumerate(source):
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
sentence["value"] = (
sentence["value"]
.replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
.strip()
)
sentence["value"] = (
sentence["value"]
.replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
.strip()
)
if sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
)
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
sentence["value"] = (
sentence["value"]
.replace(
DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
)
.strip()
)
replace_token, vid_replace_token, audio_replace_token = (
DEFAULT_IMAGE_TOKEN,
DEFAULT_IMAGE_TOKEN * image_token_num,
DEFAULT_AUDIO_TOKEN,
) # * audio_lens
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
replace_token = DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph]
k_img_ph += 1
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n")
sentence["value"] = sentence["value"].replace(
DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n"
)
sentence["value"] = sentence["value"].replace(
DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token
)
sentence["value"] = sentence["value"].replace("\n\n", "\n")
if i == inserted_id:
assert sentence["from"] == "gpt"
sentence["value"] = "<2>" + sentence["value"]
elif sentence["from"] == "gpt":
if "<audio>" in source[i - 1]["value"]:
sentence["value"] = "<1>" + sentence["value"]
else:
sentence["value"] = "<3>" + sentence["value"]
# print(patch_num)
# print(sum(patch_num))
# print(sources)
# import pdb; pdb.set_trace()
return sources
def preprocess_mixtral_zh(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
) -> Dict:
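"""MixtralZh preprocessing with an extra `end_tag` switch: when False, the last four
characters (the trailing end tag) of the first rendered conversation are dropped before
tokenization; target masking is otherwise identical."""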
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if not end_tag:
conversations[0] = conversations[0][:-4]
if has_image and not has_audio:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif not has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# print(f'end_tag: {end_tag}')
# print(conversations)
# print(input_ids)
# import pdb; pdb.set_trace()
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralZh
# Mask targets
sep = conv.sep + "\n" + conv.roles[1] + ":"
sep2_2 = "\n" + conv.roles[0] + ":"
sep2 = conv.sep2 + sep2_2
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2)
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:]
cur_len = 1
end_token_cnt = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
if i > 0:
rou = sep2_2 + rou
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image and not has_audio:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
elif has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
elif not has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
end_token_cnt += 1
cur_len += round_len
cur_len = cur_len - 1
target[cur_len:] = IGNORE_INDEX
if tokenizer.pad_token_id == tokenizer.eos_token_id:
cur_len -= end_token_cnt
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_mixtral_two(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
modality: str = "lang",
) -> Dict:
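"""MixtralTwo preprocessing: prompts are rendered with conv.get_prompt(modality) so the
template can depend on the sample modality ("lang", "image" or "video"); tokenization
and instruction masking mirror the MixtralZh variant."""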
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt(modality))
# print(conversations)
# import pdb; pdb.set_trace()
# Tokenize conversations
if not end_tag:
conversations[0] = conversations[0][:-4]
if has_image and not has_audio:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif not has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# print(f'end_tag: {end_tag}')
# print(conversations)
# print(input_ids)
# import pdb; pdb.set_trace()
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralTwo
# Mask targets
sep = conv.sep + "\n" + conv.roles[1] + ":"
sep2_2 = "\n" + conv.roles[0] + ":"
sep2 = conv.sep2 + sep2_2
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2)
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:]
cur_len = 1
end_token_cnt = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
if i > 0:
rou = sep2_2 + rou
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image and not has_audio:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
elif has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
elif not has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
end_token_cnt += 1
cur_len += round_len
cur_len = cur_len - 1
target[cur_len:] = IGNORE_INDEX
if tokenizer.pad_token_id == tokenizer.eos_token_id:
cur_len -= end_token_cnt
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = (
source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
)
conversations.append(conversation)
# tokenize conversations
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
modality: str = "lang",
) -> Dict:
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version == "mixtral_zh":
return preprocess_mixtral_zh(
sources, tokenizer, has_image=has_image, has_audio=has_audio, end_tag=end_tag
)
elif conversation_lib.default_conversation.version == "mixtral_two":
return preprocess_mixtral_two(
sources,
tokenizer,
has_image=has_image,
has_audio=has_audio,
end_tag=end_tag,
modality=modality,
)
def _get_rawvideo_dec(
video_path,
image_processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=384,
video_framerate=1,
s=None,
e=None,
image_aspect_ratio="pad",
):
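"""Decode a video with decord and sample a frame count that is a multiple of 4; each
group of four frames is additionally composited into a 2x2 mosaic, and frames are stored
in groups of five (the mosaic followed by its four source frames). Returns the
preprocessed frame tensors and the number of such groups (len(patch_images) // 5)."""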
# speed up video decode via decord.
video_mask = np.zeros(max_frames, dtype=np.int64)
max_video_length = 0
# T x 3 x H x W
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
if s is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = start_time if start_time >= 0.0 else 0.0
end_time = end_time if end_time >= 0.0 else 0.0
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
print(video_path)
raise FileNotFoundError
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# T x 3 x H x W
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
num_frame = math.ceil(len(all_pos) / 4) * 4 # rounded up to the nearest multiple of 4
if num_frame > max_frames:
num_frame = math.floor(max_frames / 4) * 4
assert num_frame <= MAX_IMAGE_LENGTH and num_frame >= MIN_IMAGE_LENGTH
sample_fps = 3
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
sample_pos = [
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=num_frame, dtype=int)
]
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
assert len(patch_images) % 4 == 0
new_patch_images = []
for i in range(0, len(patch_images), 4):
img1, img2, img3, img4 = patch_images[i : i + 4]
width, height = img1.size
new_image = Image.new(
patch_images[0].mode,
(2 * width, 2 * height),
tuple(int(x * 255) for x in image_processor.image_mean),
)
new_image.paste(img1, (0, 0))
new_image.paste(img2, (width, 0))
new_image.paste(img3, (0, height))
new_image.paste(img4, (width, height))
new_patch_images.append(new_image)
new_patch_images.extend([img1, img2, img3, img4])
patch_images = new_patch_images
# import pdb; pdb.set_trace()
# visualize_images(patch_images[0], patch_images)
if image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
patch_images = [
expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
for i in patch_images
]
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
else:
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
assert len(patch_images) % 5 == 0
slice_len = len(patch_images) // 5
return patch_images, slice_len
else:
print("video path: {} error.".format(video_path))
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
dataset_list = DataConfig[str(data_args.dataset_use)]
print(dataset_list)
self.max_length = MAX_IMAGE_LENGTH
list_data_dict = []
self.folder_dict = {}
for i in dataset_list:
# list_data_dict += json.load(open(i["chat_path"], "r"))
data_ratio = i.get("data_ratio", DEFAULT_DATA_RATIO)
data_i = json.load(open(i["chat_path"], "r"))
len_data_i = len(data_i)
data_i = random.sample(data_i, int(len_data_i * data_ratio))
list_data_dict += data_i
image_folder = [folder for folder in i if folder != "chat_path"]
for folder in image_folder:
if folder not in self.folder_dict:
self.folder_dict[folder] = i[folder]
for key in FolderDict.keys():
if key not in self.folder_dict:
self.folder_dict[key] = FolderDict[key]
random.shuffle(list_data_dict)
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
# @property
# def lengths(self):
# length_list = []
# for sample in self.list_data_dict:
# img_tokens = 128 if 'image' in sample else 0
# length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
# return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
cur_len = cur_len if ("image" in sample or "video" in sample) else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
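# Same modality branching as the dataset above, but single images are additionally split
# into dynamic patches via dynamic_preprocess (tracked in patch_num), and the inserted_id /
# end_tag flags from the sample are forwarded to the preprocessing routines.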
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if "image" in sources[0] and "audio" not in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id):
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
assert inserted_id is None
assert end_tag is True
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
patch_num=patch_num,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources, self.tokenizer, has_image=True, end_tag=end_tag, modality="image"
)
elif "image" in sources[0] and "audio" in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id): # multi-image sample
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
img_mean=processor.image_mean,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
if type(audio_file) is list:
# if type(set_id) is list:
# audio_folder = self.folder_dict[set_id[0]+'_audio']
# else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
try:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
except Exception:
print(f"File {os.path.join(audio_folder, 'audio', file)} not OK!!!!!")
raise
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
patch_num=patch_num,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=True,
end_tag=end_tag,
modality="image",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "video" in sources[0] and "audio" not in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
video_folder = self.folder_dict[set_id]
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
image_aspect_ratio=self.data_args.image_aspect_ratio,
)
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
assert inserted_id is None
assert end_tag is True
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=False,
end_tag=end_tag,
modality="video",
)
elif "video" in sources[0] and "audio" in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError(f"Please use correct key to use processor size!")
video_folder = self.folder_dict[set_id]
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
image_aspect_ratio=self.data_args.image_aspect_ratio,
)
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=True,
end_tag=end_tag,
modality="video",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "audio" in sources[0]:
audio_file = self.list_data_dict[i]["audio"]
audio_folder = AudioFolder
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=0,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=False,
has_audio=True,
end_tag=end_tag,
modality="lang",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(sources, self.tokenizer, has_image=False, modality="lang")
if isinstance(i, int):
if "audio" in self.list_data_dict[i]:
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0],
audio_lengths=data_dict["audio_lengths"],
audio_lengths_for_llm=data_dict["audio_lengths_for_llm"],
)
else:
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# image exist in the data
if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
data_dict["image"] = [torch.zeros(3, crop_size["height"], crop_size["width"])] * 5
if "audio" in self.list_data_dict[i]:
data_dict["audio"] = audio
elif self.data_args.is_multimodal:
data_dict["audio"] = torch.zeros(400, 80)
data_dict["audio_lengths"] = 400
data_dict["audio_lengths_for_llm"] = 60
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == self.tokenizer.eos_token_id] = -300
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
labels = labels[:, : self.tokenizer.model_max_length]
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == -300] = self.tokenizer.eos_token_id
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
new_images = []
for image in images:
if type(image) is list:
for i in image:
new_images.append(i)
else:
new_images.append(image)
images = new_images
if all(x is not None and x.shape == images[0].shape for x in images):
batch["images"] = torch.stack(images)
else:
batch["images"] = images
batch["audios"] = {}
if "audio" in instances[0]:
audios = [instance["audio"] for instance in instances]
audio_lengths = [instance["audio_lengths"] for instance in instances]
audio_lengths_for_llm = [instance["audio_lengths_for_llm"] for instance in instances]
new_audios = []
new_audio_lengths = []
new_audio_lengths_for_llm = []
for i, audio in enumerate(audios):
length = audio_lengths[i]
length_for_llm = audio_lengths_for_llm[i]
if type(audio) is list:
for j, a in enumerate(audio):
new_audios.append(a)
new_audio_lengths.append(length[j])
new_audio_lengths_for_llm.append(length_for_llm[j])
else:
new_audios.append(audio)
new_audio_lengths.append(length)
new_audio_lengths_for_llm.append(length_for_llm)
audios = new_audios
audios = pad_sequence(audios, batch_first=True, padding_value=0)
batch["audios"]["audios"] = audios
batch["audios"]["lengths"] = torch.tensor(new_audio_lengths)
batch["audios"]["lengths_for_llm"] = torch.tensor(new_audio_lengths_for_llm)
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
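"""Pick the (cols, rows) grid from `target_ratios` whose aspect ratio is closest to that
of the input image; ties are resolved in favor of the larger grid when the image area is
large enough to justify it."""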
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def dynamic_preprocess(
image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, img_mean=0
):
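"""Tile `image` into an even (cols, rows) grid of image_size patches: the grid chosen by
find_closest_aspect_ratio is rounded up to even dimensions, the resized image is padded
with `img_mean`, and every 2x2 block contributes one downsampled "big" patch followed by
its four image_size sub-patches, so the output length is always a multiple of 5.
Returns the patch list and [number of 2x2 blocks]."""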
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
# expand target_aspect_ratio to even for each size
new_target_aspect_ratio = [e if e % 2 == 0 else e + 1 for e in target_aspect_ratio]
blocks_big = int(0.5 * new_target_aspect_ratio[0] * 0.5 * new_target_aspect_ratio[1])
# padding to even patch for each size
new_target_width = new_target_aspect_ratio[0] * image_size
new_target_height = new_target_aspect_ratio[1] * image_size
resized_img = expand2even(
resized_img, new_target_width, new_target_height, tuple(int(x * 255) for x in img_mean)
)
assert resized_img.size[0] == new_target_aspect_ratio[0] * image_size
assert resized_img.size[1] == new_target_aspect_ratio[1] * image_size
processed_images = []
image_size_big = image_size * 2
for i in range(blocks_big):
# for each 2x2 block, append the downsampled "big" view first; its four sub-patches are appended by the inner loop below
box = (
(i % (new_target_width // image_size_big)) * image_size_big,
(i // (new_target_width // image_size_big)) * image_size_big,
((i % (new_target_width // image_size_big)) + 1) * image_size_big,
((i // (new_target_width // image_size_big)) + 1) * image_size_big,
)
# split the image
split_img_big = resized_img.crop(box)
split_img = split_img_big.resize((image_size, image_size))
processed_images.append(split_img)
blocks_small = 2 * 2
for i in range(blocks_small):
# append the four image_size sub-patches of the current 2x2 block (order: big, then small)
box = (
(i % (image_size_big // image_size)) * image_size,
(i // (image_size_big // image_size)) * image_size,
((i % (image_size_big // image_size)) + 1) * image_size,
((i // (image_size_big // image_size)) + 1) * image_size,
)
# split the image
split_img = split_img_big.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks_big * 5
assert len(processed_images) % 5 == 0
# import pdb; pdb.set_trace()
# visualize_images(resized_img, processed_images)
return processed_images, [len(processed_images) // 5]
def expand2even(pil_img, new_target_width, new_target_height, background_color):
result = Image.new(pil_img.mode, (new_target_width, new_target_height), background_color)
result.paste(pil_img, (0, 0))
return result
def visualize_images(resized_img, processed_images, output_path="output.png"):
# Create a figure to hold the subplots
fig, axes = plt.subplots(
nrows=(len(processed_images) // 5) + 1,
ncols=5,
figsize=(15, (len(processed_images) // 5) + 1),
)
# Plot the resized_img in the first row
axes[0, 0].imshow(resized_img)
axes[0, 0].set_title("Resized Image")
axes[0, 0].axis("off")
# Hide the remaining subplots in the first row
for j in range(1, 5):
axes[0, j].axis("off")
# Plot the processed_images
for i, img in enumerate(processed_images):
row = (i // 5) + 1
col = i % 5
axes[row, col].imshow(img)
axes[row, col].axis("off")
# Save the figure
plt.tight_layout()
plt.savefig(output_path)
plt.close()
import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
DEFAULT_AUDIO_TOKEN,
DEFAULT_DATA_RATIO,
DEFAULT_IMAGE_TOKEN,
DEFAULT_VIDEO_TOKEN,
IGNORE_INDEX,
MAX_IMAGE_LENGTH,
MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token
@dataclass
class DataArguments:
lazy_preprocess: bool = False
is_multimodal: bool = True
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: Optional[str] = field(default=None)
dataset_use: str = field(default="temp")
min_dynamic_patch: int = 1
max_dynamic_patch: int = 12
use_thumbnail: bool = True
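# NOTE: image_processor and audio_processor are not fields of this dataclass; they are
# expected to be attached to the instance at runtime (presumably by the training entry
# point) before LazySupervisedDataset.__getitem__ below accesses them.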
def preprocess_multimodal(
sources: Sequence[str],
data_args: DataArguments,
image_token_num=1,
patch_num=[1],
audio_lens: int = 0,
inserted_id=None,
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
k_img_ph = 0
for source in sources:
if inserted_id is not None:
assert source[inserted_id]["from"] == "gpt"
for i, sentence in enumerate(source):
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
sentence["value"] = (
sentence["value"]
.replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
.strip()
)
sentence["value"] = (
sentence["value"]
.replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
.strip()
)
if sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
sentence["value"] = (
sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
)
sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
)
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
sentence["value"] = (
sentence["value"]
.replace(
DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
)
.strip()
)
replace_token, vid_replace_token, audio_replace_token = (
DEFAULT_IMAGE_TOKEN,
DEFAULT_IMAGE_TOKEN * image_token_num,
DEFAULT_AUDIO_TOKEN,
) # * audio_lens
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
replace_token = DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph]
k_img_ph += 1
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n")
sentence["value"] = sentence["value"].replace(
DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n"
)
sentence["value"] = sentence["value"].replace(
DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token
)
sentence["value"] = sentence["value"].replace("\n\n", "\n")
if i == inserted_id:
assert sentence["from"] == "gpt"
sentence["value"] = "<2>" + sentence["value"]
elif sentence["from"] == "gpt":
if "<audio>" in source[i - 1]["value"]:
sentence["value"] = "<1>" + sentence["value"]
else:
sentence["value"] = "<3>" + sentence["value"]
# print(patch_num)
# print(sum(patch_num))
# print(sources)
# import pdb; pdb.set_trace()
return sources
def preprocess_mixtral_zh(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if not end_tag:
conversations[0] = conversations[0][:-4]
if has_image and not has_audio:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif not has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# print(f'end_tag: {end_tag}')
# print(conversations)
# print(input_ids)
# import pdb; pdb.set_trace()
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralZh
# Mask targets
sep = conv.sep + "\n" + conv.roles[1] + ":"
sep2_2 = "\n" + conv.roles[0] + ":"
sep2 = conv.sep2 + sep2_2
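# Mask targets round by round: only the assistant replies contribute to the loss, while
# the system prompt and every human turn (up to and including the assistant separator)
# are set to IGNORE_INDEX.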
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2)
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:]
cur_len = 1
end_token_cnt = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
if i > 0:
rou = sep2_2 + rou
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image and not has_audio:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
elif has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
elif not has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
end_token_cnt += 1
cur_len += round_len
cur_len = cur_len - 1
target[cur_len:] = IGNORE_INDEX
if tokenizer.pad_token_id == tokenizer.eos_token_id:
cur_len -= end_token_cnt
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_mixtral_two(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
modality: str = "lang",
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt(modality))
# print(conversations)
# import pdb; pdb.set_trace()
# Tokenize conversations
if not end_tag:
conversations[0] = conversations[0][:-4]
if has_image and not has_audio:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
elif not has_image and has_audio:
input_ids = torch.stack(
[
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
# print(f'end_tag: {end_tag}')
# print(conversations)
# print(input_ids)
# import pdb; pdb.set_trace()
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralTwo
# Mask targets
sep = conv.sep + "\n" + conv.roles[1] + ":"
sep2_2 = "\n" + conv.roles[0] + ":"
sep2 = conv.sep2 + sep2_2
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2)
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:]
cur_len = 1
end_token_cnt = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
if i > 0:
rou = sep2_2 + rou
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image and not has_audio:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
elif has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
elif not has_image and has_audio:
round_len = len(tokenizer_image_audio_token(rou, tokenizer))
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
end_token_cnt += 1
cur_len += round_len
cur_len = cur_len - 1
target[cur_len:] = IGNORE_INDEX
if tokenizer.pad_token_id == tokenizer.eos_token_id:
cur_len -= end_token_cnt
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = (
source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
)
conversations.append(conversation)
# tokenize conversations
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
has_audio: bool = False,
end_tag: bool = True,
modality: str = "lang",
) -> Dict:
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.version == "mixtral_zh":
return preprocess_mixtral_zh(
sources, tokenizer, has_image=has_image, has_audio=has_audio, end_tag=end_tag
)
elif conversation_lib.default_conversation.version == "mixtral_two":
return preprocess_mixtral_two(
sources,
tokenizer,
has_image=has_image,
has_audio=has_audio,
end_tag=end_tag,
modality=modality,
)
def _get_rawvideo_dec(
video_path,
image_processor,
max_frames=32,
min_frames=4,
image_resolution=384,
video_framerate=1,
s=None,
e=None,
image_aspect_ratio="pad",
):
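# Decode roughly video_framerate fps from the [s, e] window, capped at max_frames and
# upsampled to at least min_frames, preprocess each frame with image_processor, and
# return the list of frame tensors together with the number of sampled frames.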
# speed up video decode via decord.
video_mask = np.zeros(max_frames, dtype=np.int64)
max_video_length = 0
# T x 3 x H x W
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
if s is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = start_time if start_time >= 0.0 else 0.0
end_time = end_time if end_time >= 0.0 else 0.0
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
print(video_path)
raise FileNotFoundError(video_path)
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# T x 3 x H x W
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
if len(all_pos) > max_frames:
sample_pos = [
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
]
elif len(all_pos) < min_frames:
sample_pos = [
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
if image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
patch_images = [
expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
for i in patch_images
]
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
else:
patch_images = [
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in patch_images
]
# patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
slice_len = len(patch_images)
return patch_images, slice_len
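# NOTE: the early return above makes the frame-padding logic below unreachable.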
max_video_length = max_video_length if max_video_length > slice_len else slice_len
if slice_len < 1:
pass
else:
while len(patch_images) < max_frames:
patch_images.append(torch.zeros((3, image_resolution, image_resolution)))
# video[:slice_len, ...] = patch_images
else:
print("video path: {} error.".format(video_path))
raise ValueError("no frames decoded from video: {}".format(video_path))
video_mask[:max_video_length] = [1] * max_video_length
return patch_images, video_mask
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
dataset_list = DataConfig[str(data_args.dataset_use)]
print(dataset_list)
self.max_length = MAX_IMAGE_LENGTH
list_data_dict = []
self.folder_dict = {}
for i in dataset_list:
# list_data_dict += json.load(open(i["chat_path"], "r"))
data_ratio = i.get("data_ratio", DEFAULT_DATA_RATIO)
data_i = json.load(open(i["chat_path"], "r"))
len_data_i = len(data_i)
data_i = random.sample(data_i, int(len_data_i * data_ratio))
list_data_dict += data_i
image_folder = [folder for folder in i if folder != "chat_path"]
for folder in image_folder:
if folder not in self.folder_dict:
self.folder_dict[folder] = i[folder]
for key in FolderDict.keys():
if key not in self.folder_dict:
self.folder_dict[key] = FolderDict[key]
random.shuffle(list_data_dict)
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
# @property
# def lengths(self):
# length_list = []
# for sample in self.list_data_dict:
# img_tokens = 128 if 'image' in sample else 0
# length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
# return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
cur_len = cur_len if ("image" in sample or "video" in sample) else -cur_len
length_list.append(cur_len)
return length_list
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if "image" in sources[0] and "audio" not in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError("Image processor size must define 'height' or 'shortest_edge'.")
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id):
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
assert len(image_patches) == sum(patch_num)
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
assert len(image_patches) == sum(patch_num)
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
assert inserted_id is None
assert end_tag is True
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
patch_num=patch_num,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources, self.tokenizer, has_image=True, end_tag=end_tag, modality="image"
)
elif "image" in sources[0] and "audio" in sources[0]:
image_file = self.list_data_dict[i]["image"]
set_id = self.list_data_dict[i].get("set", None)
file = image_file[0] if type(image_file) is list else image_file
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError("Image processor size must define 'height' or 'shortest_edge'.")
if type(image_file) is list:
assert type(set_id) is list
if len(image_file) != len(set_id):  # multi-image sample
assert len(set(set_id)) == 1
image = [
Image.open(
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/"))
).convert("RGB")
for k, file in enumerate(image_file)
]
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = [
expand2square(i, tuple(int(x * 255) for x in processor.image_mean))
for i in image
]
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
assert len(image_patches) == sum(patch_num)
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_patches, patch_num = [], []
for k, img in enumerate(image):
if set_id[k] not in NoPatchSets:
img, p_num = dynamic_preprocess(
img,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
else:
img, p_num = [img], [1]
image_patches += img
patch_num += p_num
assert len(image_patches) == sum(patch_num)
image = image_patches
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image_folder = self.folder_dict[set_id]
image = Image.open(
os.path.join(image_folder, image_file.replace("\\", "/"))
).convert("RGB")
if self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
else:
image, patch_num = dynamic_preprocess(
image,
min_num=self.data_args.min_dynamic_patch,
max_num=self.data_args.max_dynamic_patch,
image_size=image_size,
use_thumbnail=self.data_args.use_thumbnail,
)
image = [
processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
for i in image
]
if type(audio_file) is list:
# if type(set_id) is list:
# audio_folder = self.folder_dict[set_id[0]+'_audio']
# else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
try:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
except Exception:
print(f"File {os.path.join(audio_folder, 'audio', file)} not OK!!!!!")
raise
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
patch_num=patch_num,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=True,
end_tag=end_tag,
modality="image",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "video" in sources[0] and "audio" not in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError("Image processor size must define 'height' or 'shortest_edge'.")
video_folder = self.folder_dict[set_id]
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
image_aspect_ratio=self.data_args.image_aspect_ratio,
)
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
assert inserted_id is None
assert end_tag is True
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=False,
end_tag=end_tag,
modality="video",
)
elif "video" in sources[0] and "audio" in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_id = self.list_data_dict[i]["id"]
set_id = self.list_data_dict[i].get("set", None)
audio_file = self.list_data_dict[i]["audio"]
processor = self.data_args.image_processor
if "height" in processor.size.keys():
image_size = processor.size["height"]
elif "shortest_edge" in processor.size.keys():
image_size = processor.size["shortest_edge"]
else:
raise NotImplementedError("Image processor size must define 'height' or 'shortest_edge'.")
video_folder = self.folder_dict[set_id]
# audio_folder = self.folder_dict[set_id+'_audio']
audio_folder = AudioFolder
image, image_token_num = _get_rawvideo_dec(
os.path.join(video_folder, video_file),
processor,
max_frames=MAX_IMAGE_LENGTH,
min_frames=MIN_IMAGE_LENGTH,
image_resolution=image_size,
image_aspect_ratio=self.data_args.image_aspect_ratio,
)
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=image_token_num,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=True,
has_audio=True,
end_tag=end_tag,
modality="video",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
elif "audio" in sources[0]:
audio_file = self.list_data_dict[i]["audio"]
audio_folder = AudioFolder
if type(audio_file) is list:
assert len(audio_file) > 0, "audio_file must not be an empty list"
audio = []
audio_for_llm_lens = []
audio_length = []
for file in audio_file:
a, a_llm = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", file)
)
audio.append(a)
audio_for_llm_lens.append(a_llm)
audio_length.append(a.shape[0])
else:
assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process(
os.path.join(audio_folder, "audio", audio_file)
)
audio_length = audio.shape[0]
inserted_id = self.list_data_dict[i].get("inserted_id", None)
end_tag = self.list_data_dict[i].get("end_tag", True)
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args,
image_token_num=0,
audio_lens=audio_for_llm_lens,
inserted_id=inserted_id,
)
data_dict = preprocess(
sources,
self.tokenizer,
has_image=False,
has_audio=True,
end_tag=end_tag,
modality="lang",
)
data_dict["audio_lengths"] = audio_length
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(sources, self.tokenizer, has_image=False, modality="lang")
if isinstance(i, int):
if "audio" in self.list_data_dict[i]:
data_dict = dict(
input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0],
audio_lengths=data_dict["audio_lengths"],
audio_lengths_for_llm=data_dict["audio_lengths_for_llm"],
)
else:
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# image exist in the data
if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
if "audio" in self.list_data_dict[i]:
data_dict["audio"] = audio
elif self.data_args.is_multimodal:
data_dict["audio"] = torch.zeros(400, 80)
data_dict["audio_lengths"] = 400
data_dict["audio_lengths_for_llm"] = 60
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
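# When pad_token_id == eos_token_id, real eos tokens are temporarily replaced with the
# sentinel -300 so they are not mistaken for padding when the attention mask is built;
# the sentinel is swapped back to eos after padding and truncation below.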
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == self.tokenizer.eos_token_id] = -300
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
labels = labels[:, : self.tokenizer.model_max_length]
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
for input_id in input_ids:
input_id[input_id == -300] = self.tokenizer.eos_token_id
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
new_images = []
for image in images:
if type(image) is list:
for i in image:
new_images.append(i)
else:
new_images.append(image)
images = new_images
if all(x is not None and x.shape == images[0].shape for x in images):
batch["images"] = torch.stack(images)
else:
batch["images"] = images
batch["audios"] = {}
if "audio" in instances[0]:
audios = [instance["audio"] for instance in instances]
audio_lengths = [instance["audio_lengths"] for instance in instances]
audio_lengths_for_llm = [instance["audio_lengths_for_llm"] for instance in instances]
new_audios = []
new_audio_lengths = []
new_audio_lengths_for_llm = []
for i, audio in enumerate(audios):
length = audio_lengths[i]
length_for_llm = audio_lengths_for_llm[i]
if type(audio) is list:
for j, a in enumerate(audio):
new_audios.append(a)
new_audio_lengths.append(length[j])
new_audio_lengths_for_llm.append(length_for_llm[j])
else:
new_audios.append(audio)
new_audio_lengths.append(length)
new_audio_lengths_for_llm.append(length_for_llm)
audios = new_audios
audios = pad_sequence(audios, batch_first=True, padding_value=0)
batch["audios"]["audios"] = audios
batch["audios"]["lengths"] = torch.tensor(new_audio_lengths)
batch["audios"]["lengths_for_llm"] = torch.tensor(new_audio_lengths_for_llm)
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
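# NOTE: unlike the big/small variant defined earlier in this commit, this dynamic_preprocess
# simply tiles the resized image into image_size x image_size crops (optionally appending a
# thumbnail) and reports the total crop count as a single-element list.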
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, [len(processed_images)]
import base64
import re
from io import BytesIO
import torch
from PIL import Image
from transformers import StoppingCriteria
from vita.constants import AUDIO_TOKEN_INDEX, IMAGE_TOKEN_INDEX
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
else:
return image_processor(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def tokenizer_image_token(
prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if (
len(prompt_chunks) > 0
and len(prompt_chunks[0]) > 0
and prompt_chunks[0][0] == tokenizer.bos_token_id
):
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
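# Example (illustrative): "USER: <image>\nDescribe the scene" is split on "<image>", each text
# chunk is tokenized, and the chunks are rejoined with IMAGE_TOKEN_INDEX in between, so every
# <image> placeholder becomes exactly one special id in input_ids.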
def tokenizer_image_audio_token(
prompt,
tokenizer,
image_token_index=IMAGE_TOKEN_INDEX,
audio_token_index=AUDIO_TOKEN_INDEX,
return_tensors=None,
):
prompt_chunks = []
for chunk in re.split(r"(<audio>|<image>)", prompt):
if chunk == "<audio>":
prompt_chunks.append([audio_token_index])
elif chunk == "<image>":
prompt_chunks.append([image_token_index])
else:
prompt_chunks.append(tokenizer(chunk).input_ids)
input_ids = []
offset = 0
if (
len(prompt_chunks) > 0
and len(prompt_chunks[0]) > 0
and prompt_chunks[0][0] == tokenizer.bos_token_id
):
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in prompt_chunks:
if x != [image_token_index] and x != [audio_token_index]:
input_ids.extend(x[offset:])
else:
input_ids.extend(x[:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
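# Example (illustrative): "ckpts/vita-7b/checkpoint-1000" -> "vita-7b_checkpoint-1000",
# while "ckpts/vita-7b" -> "vita-7b".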
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def call_for_batch(
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
truncated_output_ids = output_ids[0, -keyword_id.shape[0] :]
if torch.equal(truncated_output_ids, keyword_id):
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
outputs = []
for i in range(output_ids.shape[0]):
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
return all(outputs)
from .core import *
from .utils import *