Commit c501623c authored by chenych

add vlmo

parent 4538607b
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import vlmo.modules.multiway_transformer
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
from vlmo.modules import heads, objectives, vlmo_utils
from pytorch_lightning.utilities.distributed import rank_zero_info
from scipy import interpolate
from timm.models import create_model
def convert_to_textpt_ckpt(state_dict, module):
new_state_dict = {}
# Merge the relative_position_bias_table from all layers into one tensor,
# so a single op can gather the relative position bias (for speed)
relative_position_bias_tables = {}
for key in state_dict:
value = state_dict[key]
if "relative_position_bias_table" in key:
# transformer.blocks.0.attn.relative_position_bias_table
layer_idx = int(key.split(".attn.")[0].split('.')[-1])
relative_position_bias_tables[layer_idx] = value
continue
if "mlp" in key:
key_imag = "transformer." + key.replace("mlp", "mlp_imag")
new_state_dict[key_imag] = value
elif "norm2" in key:
key_imag = "transformer." + key.replace("norm2", "norm2_imag")
new_state_dict[key_imag] = value
else:
new_key = "transformer." + key
new_state_dict[new_key] = value
if len(relative_position_bias_tables) > 0:
tensor_list = []
for layer_idx in sorted(relative_position_bias_tables.keys()):
tensor_list.append(relative_position_bias_tables[layer_idx])
relative_position_bias_table = torch.cat(tensor_list, dim=1)
num_distance, _ = relative_position_bias_table.shape
all_relative_position_bias_table = module.relative_position_bias_table.data.clone()
all_relative_position_bias_table[:num_distance, :] = relative_position_bias_table
new_state_dict["relative_position_bias_table"] = all_relative_position_bias_table
return new_state_dict
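# Rough usage sketch for the conversion above (illustrative only: the checkpoint path and
# the `vlmo_module` name are placeholders, not values from this commit). The function
# prefixes every key with "transformer.", redirects "mlp"/"norm2" weights to the
# "mlp_imag"/"norm2_imag" branches, and packs the per-layer relative_position_bias_table
# tensors into the module's single shared table.
#
#   ckpt = torch.load("vision_pretrained.ckpt", map_location="cpu")   # hypothetical file
#   new_sd = convert_to_textpt_ckpt(ckpt["model"], vlmo_module)
#   vlmo_module.load_state_dict(new_sd, strict=False)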
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = visual_encoder.patch_embed.num_patches
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
if orig_size!=new_size:
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
return new_pos_embed
else:
return pos_embed_checkpoint
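# Shape sketch for the interpolation above (the numbers are just an example): a 224px/16px
# checkpoint has a 14x14 patch grid, i.e. a pos_embed of shape (1, 1 + 14*14, C); resizing
# the model to a 24x24 grid returns (1, 1 + 24*24, C), where the extra (class) tokens are
# copied as-is and only the patch tokens are bicubically resized. Typical call, as done in
# load_pretrained_weight below:
#
#   state_dict["transformer.pos_embed"] = interpolate_pos_embed(
#       state_dict["transformer.pos_embed"], model.transformer)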
def convert_deepspeed_ckpt(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("module."):
new_key = key[len("module."):]
value = state_dict[key]
new_state_dict[new_key] = value
else:
new_state_dict[key] = state_dict[key]
return new_state_dict
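# Example of the renaming performed above: DeepSpeed wraps the model and prefixes every
# parameter with "module.", so e.g.
#   "module.transformer.pos_embed" -> "transformer.pos_embed"
# while keys without the prefix are passed through unchanged.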
class VLMo(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
# backbone & patch projection
self.img_size = config["image_size"]
self.transformer = create_model(
config["model_arch"],
img_size=self.img_size,
pretrained=False,
drop_rate=0,
drop_path_rate=config["drop_path_rate"],
attn_drop_rate=0,
drop_block_rate=None,
config=self.hparams.config,
)
self.patch_size = self.transformer.patch_size
self.vlffn_start_layer_index = self.transformer.vlffn_start_layer_index
self.num_layers = len(self.transformer.blocks)
self.num_features = self.transformer.num_features
self.build_relative_position_embed(config)
# language embedding
bert_config = BertConfig(
vocab_size=config["vocab_size"],
hidden_size=self.num_features,
max_position_embeddings=config["max_text_len"],
hidden_dropout_prob=config["drop_path_rate"],
position_embedding_type="rel_pos" if self.transformer.need_relative_position_embed else "absolute",
)
self.text_embeddings = BertEmbeddings(bert_config)
self.text_embeddings.apply(objectives.init_weights)
self.token_type_embeddings = nn.Embedding(2, self.num_features)
self.token_type_embeddings.apply(objectives.init_weights)
# task layers
self.pooler = heads.Pooler(self.num_features)
self.pooler.apply(objectives.init_weights)
## language modeling
if config["loss_names"]["mlm"] > 0 or config["loss_names"]["textmlm"] > 0:
self.mlm_score = heads.MLMHead(bert_config)
self.mlm_score.apply(objectives.init_weights)
## image-text matching (global hard negative)
if config["loss_names"]["itm"] > 0:
self.itm_score = heads.ITMHead(self.num_features)
self.itm_score.apply(objectives.init_weights)
## contrastive loss (or sampling for global hard negative)
if config["loss_names"]["itc"] > 0:
self.itc_text_proj = heads.ITCHead(self.num_features)
self.itc_image_proj = heads.ITCHead(self.num_features)
self.itc_text_proj.apply(objectives.init_weights)
self.itc_image_proj.apply(objectives.init_weights)
self.itc_vl_text_proj = heads.ITCHead(self.num_features)
self.itc_vl_image_proj = heads.ITCHead(self.num_features)
self.itc_vl_text_proj.apply(objectives.init_weights)
self.itc_vl_image_proj.apply(objectives.init_weights)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
## retrieval task ft
if config["loss_names"]["irtr"] > 0:
self.itc_text_proj = heads.ITCHead(self.num_features)
self.itc_image_proj = heads.ITCHead(self.num_features)
self.itc_text_proj.apply(objectives.init_weights)
self.itc_image_proj.apply(objectives.init_weights)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.load_pretrained_weight()
# ===================== Downstream ===================== #
## VQAv2
if self.hparams.config["loss_names"]["vqa"] > 0:
vs = self.hparams.config["vqav2_label_size"]
self.vqa_classifier = nn.Sequential(
nn.Linear(self.num_features, self.num_features * 2),
nn.LayerNorm(self.num_features * 2),
nn.GELU(),
nn.Linear(self.num_features * 2, vs),
)
self.vqa_classifier.apply(objectives.init_weights)
## NLVR2 (Visual reasoning)
if self.hparams.config["loss_names"]["nlvr2"] > 0:
self.nlvr2_classifier = nn.Sequential(
nn.Linear(self.num_features * 2, self.num_features * 2),
nn.LayerNorm(self.num_features * 2),
nn.GELU(),
nn.Linear(self.num_features * 2, 2),
)
self.nlvr2_classifier.apply(objectives.init_weights)
emb_data = self.token_type_embeddings.weight.data
self.token_type_embeddings = nn.Embedding(3, self.num_features)
self.token_type_embeddings.apply(objectives.init_weights)
self.token_type_embeddings.weight.data[0, :] = emb_data[0, :]
self.token_type_embeddings.weight.data[1, :] = emb_data[1, :]
self.token_type_embeddings.weight.data[2, :] = emb_data[1, :]
vlmo_utils.set_metrics(self)
self.current_tasks = list()
# ===================== load downstream (test_only) ======================
if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
state_dict = None
for state_dict_key in ("state_dict", "module", "model"):
if state_dict_key in ckpt:
rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
state_dict = ckpt[state_dict_key]
break
if state_dict_key == "module":
state_dict = convert_deepspeed_ckpt(state_dict)
if state_dict is None:
rank_zero_info("Read state dict from ckpt. ")
state_dict = ckpt
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
rank_zero_info("missing_keys: {}".format(missing_keys))
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
def load_pretrained_weight(self):
if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
config = self.hparams.config
ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
state_dict = None
for state_dict_key in ("state_dict", "module", "model"):
if state_dict_key in ckpt:
rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
state_dict = ckpt[state_dict_key]
break
if state_dict_key == "module":
state_dict = convert_deepspeed_ckpt(state_dict)
if state_dict is None:
rank_zero_info("Read state dict from ckpt. ")
state_dict = ckpt
for key in state_dict:
var = state_dict[key]
rank_zero_info("%s = %s" % (key, str(var.size())))
rank_zero_info(config["loss_names"])
if config["loss_names"]["textmlm"] > 0:
rank_zero_info("convert to textpt")
state_dict = convert_to_textpt_ckpt(state_dict, self)
max_text_len = config["max_text_len"]
if "text_embeddings.position_embeddings.weight" in state_dict and state_dict["text_embeddings.position_embeddings.weight"].size(0) != max_text_len:
state_dict["text_embeddings.position_embeddings.weight"].data = state_dict["text_embeddings.position_embeddings.weight"].data[:max_text_len, :]
state_dict["text_embeddings.position_ids"].data = state_dict["text_embeddings.position_ids"].data[:, :max_text_len]
rank_zero_info("text position_embeddings size: {}".format(state_dict["text_embeddings.position_embeddings.weight"].size()))
for check_key in ("relative_position_index", "text_relative_position_index", "text_imag_relative_position_index"):
if check_key in state_dict:
state_dict.pop(check_key)
if "transformer.pos_embed" in state_dict:
pos_embed_reshaped = interpolate_pos_embed(state_dict['transformer.pos_embed'], self.transformer)
state_dict['transformer.pos_embed'] = pos_embed_reshaped
if "relative_position_bias_table" in state_dict:
rel_pos_bias = state_dict["relative_position_bias_table"]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = self.relative_position_bias_table.size()
dst_patch_shape = self.transformer.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
if src_size != dst_size:
state_dict.pop("relative_position_index")
state_dict.pop("text_relative_position_index")
state_dict.pop("text_imag_relative_position_index")
rank_zero_info("Position interpolate from %dx%d to %dx%d" % (
src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
rank_zero_info("Original positions = %s" % str(x))
rank_zero_info("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict["relative_position_bias_table"] = new_rel_pos_bias
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
rank_zero_info("missing_keys: {}".format(missing_keys))
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
def get_rel_pos_bias(self, relative_position_index):
if self.relative_position_embed:
relative_position_bias = F.embedding(relative_position_index.long().to(self.relative_position_bias_table.device),
self.relative_position_bias_table)
all_relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, x, y
relative_position_bias_list = torch.chunk(all_relative_position_bias, self.num_layers, dim=0)
return relative_position_bias_list
else:
return [None] * self.num_layers
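# get_rel_pos_bias returns one bias tensor per transformer block: the shared table has
# num_heads * num_layers columns, F.embedding gathers its rows with the given index
# matrix, and torch.chunk splits the head dimension back into num_layers slices of shape
# (num_heads, seq_len, seq_len); when relative position embeddings are disabled it simply
# returns a list of Nones.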
def build_relative_position_embed(self, config):
if not self.transformer.need_relative_position_embed:
self.relative_position_embed = False
self.text_imag_relative_position_index = None
self.text_relative_position_index = None
self.relative_position_index = None
return
self.relative_position_embed = True
window_size = (int(self.img_size / self.patch_size), int(self.img_size / self.patch_size)) #(14, 14)
rank_zero_info("window_size: {}".format(window_size))
num_heads = self.transformer.num_heads
max_text_len_of_initckpt = config["max_text_len_of_initckpt"] #196
max_text_len = config["max_text_len"] #40
max_imag_len = window_size[0] * window_size[1] + 1 #197
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.text_num_relative_distance = 2 * max_text_len_of_initckpt
self.all_num_relative_distance = self.num_relative_distance + self.text_num_relative_distance + 2
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.all_num_relative_distance, num_heads * self.num_layers))
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.relative_position_index = relative_position_index
text_position_ids = torch.arange(max_text_len-1)
text_rel_pos_mat = text_position_ids.unsqueeze(-2) - text_position_ids.unsqueeze(-1)
min_distance = int(2-max_text_len_of_initckpt) #-194
# rank_zero_info("min_distance: {}".format(min_distance))
text_rel_pos_mat = text_rel_pos_mat - min_distance
text_rel_pos_mat += (self.num_relative_distance + 2)
text_relative_position_index = \
torch.zeros(size=(max_text_len, ) * 2, dtype=relative_coords.dtype)
text_relative_position_index[1:, 1:] = text_rel_pos_mat
text_relative_position_index[0, 0:] = self.all_num_relative_distance - 3
text_relative_position_index[0:, 0] = self.all_num_relative_distance - 2
text_relative_position_index[0, 0] = self.all_num_relative_distance - 1
self.text_relative_position_index = text_relative_position_index
text2imag_relative_position_index = torch.ones(max_text_len, max_imag_len) * (self.num_relative_distance)
imag2text_relative_position_index = torch.ones(max_imag_len, max_text_len) * (self.num_relative_distance + 1)
text_row_relative_position_index = torch.cat((text_relative_position_index, text2imag_relative_position_index), 1)
imag_row_relative_position_index = torch.cat((imag2text_relative_position_index, relative_position_index), 1)
text_imag_relative_position_index = torch.cat((text_row_relative_position_index, imag_row_relative_position_index), 0)
self.text_imag_relative_position_index = text_imag_relative_position_index
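# Rough layout of the shared relative-position vocabulary built above (sizes follow the
# in-line comments: a 14x14 patch window and max_text_len_of_initckpt = 196):
#   indices [0, num_relative_distance)               image-image distances (plus 3 cls slots)
#   indices num_relative_distance and +1             text->image and image->text buckets
#   indices [num_relative_distance + 2, all_num_relative_distance)
#                                                    text-text distances, with the last 3
#                                                    slots reserved for the text cls entries
# text_imag_relative_position_index stitches the text and image blocks into one
# (max_text_len + max_imag_len)-square index matrix used by infer() for the "vl" pass.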
def infer(
self,
batch,
mask_text=False,
mask_image=False,
image_token_type_idx=1,
image_embeds=None,
image_masks=None,
):
if f"image_{image_token_type_idx - 1}" in batch:
imgkey = f"image_{image_token_type_idx - 1}"
else:
imgkey = "image"
do_mlm = "_mlm" if mask_text else ""
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
text_masks = batch[f"text_masks"]
text_embeds = self.text_embeddings(text_ids)
img = batch[imgkey][0]
image_embeds, image_masks = self.transformer.visual_embed(img)
image_masks = image_masks.long().to(device=img.get_device())
text_embeds, image_embeds = (
text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks)),
image_embeds
+ self.token_type_embeddings(
torch.full_like(image_masks, image_token_type_idx)
),
)
co_embeds = torch.cat([text_embeds, image_embeds], dim=1)
co_masks = torch.cat([text_masks, image_masks], dim=1)
x = co_embeds
relative_position_bias_list = self.get_rel_pos_bias(self.text_imag_relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="vl", relative_position_bias=relative_position_bias_list[i])
x = self.transformer.norm(x)
text_feats, image_feats = (
x[:, : text_embeds.shape[1]],
x[:, text_embeds.shape[1] :],
)
cls_feats = self.pooler(x)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": cls_feats,
"raw_cls_feats": x[:, 0],
"image": img,
"text_labels": text_labels,
"text_ids": text_ids,
"text_masks": text_masks,
}
return ret
def infer_text(
self,
batch,
mask_text=False,
):
do_mlm = "_mlm" if mask_text else ""
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
text_masks = batch[f"text_masks"]
text_embeds = self.text_embeddings(text_ids)
text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks))
co_embeds = text_embeds
co_masks = text_masks
x = co_embeds
all_hidden_states = []
relative_position_bias_list = self.get_rel_pos_bias(self.text_relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="text", relative_position_bias=relative_position_bias_list[i])
all_hidden_states.append(x)
vlffn_hiddens = all_hidden_states[self.vlffn_start_layer_index-1]
for vlffn_index in range(self.vlffn_start_layer_index, self.num_layers):
vlffn_hiddens = self.transformer.blocks[vlffn_index](vlffn_hiddens, mask=co_masks, modality_type="vl", relative_position_bias=relative_position_bias_list[vlffn_index])
lffn_hiddens = all_hidden_states[-1]
lffn_hiddens = self.transformer.norm(lffn_hiddens)
text_feats, image_feats = (
lffn_hiddens,
None,
)
cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
vlffn_hiddens = self.transformer.norm(vlffn_hiddens)
cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": cls_feats,
"cls_vlffn_feats": cls_vlffn_feats,
"raw_cls_feats": x[:, 0],
"image_masks": None,
"text_labels": text_labels,
"text_ids": text_ids,
"text_masks": text_masks,
}
return ret
def infer_text_ft(
self,
batch,
mask_text=False,
):
do_mlm = "_mlm" if mask_text else ""
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
text_masks = batch[f"text_masks"]
text_embeds = self.text_embeddings(text_ids)
text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks))
co_embeds = text_embeds
co_masks = text_masks
x = co_embeds
all_hidden_states = []
relative_position_bias_list = self.get_rel_pos_bias(self.text_relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="text", relative_position_bias=relative_position_bias_list[i])
all_hidden_states.append(x)
lffn_hiddens = all_hidden_states[-1]
lffn_hiddens = self.transformer.norm(lffn_hiddens)
text_feats, image_feats = (
lffn_hiddens,
None,
)
cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": cls_feats,
"cls_vlffn_feats": None,
"raw_cls_feats": x[:, 0],
"image_masks": None,
"text_labels": text_labels,
"text_ids": text_ids,
"text_masks": text_masks,
}
return ret
def infer_text_mlm(
self,
batch,
mask_text=False,
):
do_mlm = "_mlm" if mask_text else ""
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
text_masks = batch[f"text_masks"]
text_embeds = self.text_embeddings(text_ids)
text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks))
co_embeds = text_embeds
co_masks = text_masks
x = co_embeds
all_hidden_states = []
relative_position_bias_list = self.get_rel_pos_bias(self.text_relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="text", relative_position_bias=relative_position_bias_list[i])
all_hidden_states.append(x)
lffn_hiddens = all_hidden_states[-1]
lffn_hiddens = self.transformer.norm(lffn_hiddens)
text_feats, image_feats = (
lffn_hiddens,
None,
)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": None,
"cls_vlffn_feats": None,
"raw_cls_feats": x[:, 0],
"image_masks": None,
"text_labels": text_labels,
"text_ids": text_ids,
"text_masks": text_masks,
}
return ret
def infer_image(
self,
batch,
mask_image=False,
image_token_type_idx=1,
image_embeds=None,
image_masks=None,
):
if f"image_{image_token_type_idx - 1}" in batch:
imgkey = f"image_{image_token_type_idx - 1}"
else:
imgkey = "image"
img = batch[imgkey][0]
image_embeds, image_masks = self.transformer.visual_embed(img)
image_masks = image_masks.long().to(device=img.get_device())
image_embeds = image_embeds + self.token_type_embeddings(
torch.full_like(image_masks, image_token_type_idx)
)
co_embeds = image_embeds
co_masks = image_masks
x = co_embeds
all_hidden_states = []
relative_position_bias_list = self.get_rel_pos_bias(self.relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="image", relative_position_bias=relative_position_bias_list[i])
all_hidden_states.append(x)
vlffn_hiddens = all_hidden_states[self.vlffn_start_layer_index-1]
for vlffn_index in range(self.vlffn_start_layer_index, self.num_layers):
vlffn_hiddens = self.transformer.blocks[vlffn_index](vlffn_hiddens, mask=co_masks, modality_type="vl", relative_position_bias=relative_position_bias_list[vlffn_index])
vffn_hiddens = all_hidden_states[-1]
vffn_hiddens = self.transformer.norm(vffn_hiddens)
text_feats, image_feats = (
None,
vffn_hiddens,
)
cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
vlffn_hiddens = self.transformer.norm(vlffn_hiddens)
cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": cls_feats,
"cls_vlffn_feats": cls_vlffn_feats,
"raw_cls_feats": x[:, 0],
"image_masks": image_masks,
"text_labels": None,
"text_ids": None,
"text_masks": None,
}
return ret
def infer_image_ft(
self,
batch,
mask_image=False,
image_token_type_idx=1,
image_embeds=None,
image_masks=None,
):
if f"image_{image_token_type_idx - 1}" in batch:
imgkey = f"image_{image_token_type_idx - 1}"
else:
imgkey = "image"
img = batch[imgkey][0]
image_embeds, image_masks = self.transformer.visual_embed(img)
image_masks = image_masks.long().to(device=img.get_device())
image_embeds = image_embeds + self.token_type_embeddings(
torch.full_like(image_masks, image_token_type_idx)
)
co_embeds = image_embeds
co_masks = image_masks
x = co_embeds
all_hidden_states = []
relative_position_bias_list = self.get_rel_pos_bias(self.relative_position_index)
for i, blk in enumerate(self.transformer.blocks):
x = blk(x, mask=co_masks, modality_type="image", relative_position_bias=relative_position_bias_list[i])
all_hidden_states.append(x)
vffn_hiddens = all_hidden_states[-1]
vffn_hiddens = self.transformer.norm(vffn_hiddens)
text_feats, image_feats = (
None,
vffn_hiddens,
)
cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
ret = {
"text_feats": text_feats,
"image_feats": image_feats,
"cls_feats": cls_feats,
"cls_vlffn_feats": None,
"raw_cls_feats": x[:, 0],
"image_masks": image_masks,
"text_labels": None,
"text_ids": None,
"text_masks": None,
}
return ret
def forward(self, batch):
ret = dict()
if len(self.current_tasks) == 0:
ret.update(self.infer(batch))
return ret
# Masked Language Modeling
if "mlm" in self.current_tasks:
ret.update(objectives.compute_mlm(self, batch))
# Textonly Masked Language Modeling
if "textmlm" in self.current_tasks:
ret.update(objectives.compute_textonly_mlm(self, batch))
# Contrastive loss for pretraining
if "itc" in self.current_tasks:
ret.update(objectives.compute_itc(self, batch))
# Contrastive loss for finetuning
if "irtr" in self.current_tasks:
ret.update(objectives.compute_irtr(self, batch))
# Image Text Matching with global hard negative, must use with itc
if "itm" in self.current_tasks:
ret.update(objectives.compute_itm_hardneg(self, batch, ret["itc_i2t_logits"], ret["itc_t2i_logits"]))
# Visual Question Answering
if "vqa" in self.current_tasks:
ret.update(objectives.compute_vqa(self, batch))
# Natural Language for Visual Reasoning 2
if "nlvr2" in self.current_tasks:
ret.update(objectives.compute_nlvr2(self, batch))
return ret
def training_step(self, batch, batch_idx):
vlmo_utils.set_task(self)
output = self(batch)
total_loss = sum([v for k, v in output.items() if "loss" in k])
return total_loss
def training_epoch_end(self, outs):
vlmo_utils.epoch_wrapup(self)
def validation_step(self, batch, batch_idx):
vlmo_utils.set_task(self)
output = self(batch)
def validation_epoch_end(self, outs):
vlmo_utils.epoch_wrapup(self)
def test_step(self, batch, batch_idx):
vlmo_utils.set_task(self)
output = self(batch)
ret = dict()
if self.hparams.config["loss_names"]["vqa"] > 0:
ret.update(objectives.vqa_test_step(self, batch, output))
return ret
def test_epoch_end(self, outs):
model_name = self.hparams.config["load_path"].split("/")[-1][:-5]
if self.hparams.config["loss_names"]["vqa"] > 0:
objectives.vqa_test_wrapup(outs, model_name, self.hparams.config["log_dir"])
vlmo_utils.epoch_wrapup(self)
def configure_optimizers(self):
return vlmo_utils.set_schedule(self)
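# Rough usage sketch for the LightningModule above (hedged: `config` stands for the full
# hyper-parameter dict expected by __init__, e.g. image_size, model_arch, loss_names; the
# Trainer and datamodule setup is only indicated, not reproduced):
#
#   model = VLMo(config)
#   trainer = pl.Trainer(...)            # configured by the training entry point
#   trainer.fit(model, datamodule=dm)    # dm: the project's LightningDataModule
#
# The import block that follows starts a separate utility module of this commit; it
# provides the set_metrics / set_task / epoch_wrapup / set_schedule helpers used above.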
import torch
import random
import json
from transformers.optimization import AdamW
from transformers import (
get_polynomial_decay_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
from vlmo.modules.dist_utils import all_gather
from vlmo.modules.objectives import compute_irtr_recall, compute_irtr_recall_with_rerank
from vlmo.gadgets.my_metrics import Accuracy, VQAScore, Scalar
from pytorch_lightning.utilities.distributed import rank_zero_info
def set_metrics(pl_module):
for split in ["train", "val"]:
for k, v in pl_module.hparams.config["loss_names"].items():
if v < 1:
continue
if k == "vqa":
setattr(pl_module, f"{split}_vqa_score", VQAScore())
setattr(pl_module, f"{split}_{k}_loss", Scalar())
elif k == "nlvr2":
if split == "train":
setattr(pl_module, f"train_{k}_accuracy", Accuracy())
setattr(pl_module, f"train_{k}_loss", Scalar())
else:
setattr(pl_module, f"dev_{k}_accuracy", Accuracy())
setattr(pl_module, f"dev_{k}_loss", Scalar())
setattr(pl_module, f"test_{k}_accuracy", Accuracy())
setattr(pl_module, f"test_{k}_loss", Scalar())
elif k == "irtr":
setattr(pl_module, f"{split}_{k}_i2t_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_t2i_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_loss", Scalar())
setattr(pl_module, f"{split}_{k}_logit_scale", Scalar())
elif k == "itm":
setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_loss", Scalar())
elif k == "itc":
setattr(pl_module, f"{split}_{k}_i2t_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_t2i_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_loss", Scalar())
setattr(pl_module, f"{split}_{k}_logit_scale", Scalar())
setattr(pl_module, f"{split}_{k}_vl_i2t_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_vl_t2i_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_vl_logit_scale", Scalar())
else:
setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
setattr(pl_module, f"{split}_{k}_loss", Scalar())
def epoch_wrapup(pl_module):
phase = "train" if pl_module.training else "val"
the_metric = 0
if pl_module.hparams.config["get_recall_metric"] and not pl_module.training:
(val_ir_r1, val_ir_r5, val_ir_r10, val_tr_r1, val_tr_r5, val_tr_r10) = compute_irtr_recall(pl_module, split="val")
val_avg = (val_ir_r1.item() + val_ir_r5.item() + val_ir_r10.item() + val_tr_r1.item() + val_tr_r5.item() + val_tr_r10.item()) / 6.0
pl_module.logger.experiment.add_scalar(
"recalls/val_avg", val_avg, pl_module.global_step
)
(ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) = compute_irtr_recall(pl_module, split="test")
test_avg = (ir_r1.item() + ir_r5.item() + ir_r10.item() + tr_r1.item() + tr_r5.item() + tr_r10.item()) / 6.0
pl_module.logger.experiment.add_scalar(
"recalls/test_avg", test_avg, pl_module.global_step
)
print("val_avg:{}, test_avg:{}".format(val_avg, test_avg))
print("test ir_r1:{}, ir_r5:{}, ir_r10:{}, tr_r1:{}, tr_r5:{}, tr_r10:{}".format(ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10))
pl_module.logger.experiment.add_scalar(
"recalls/ir_r1", ir_r1, pl_module.global_step
)
pl_module.logger.experiment.add_scalar(
"recalls/ir_r5", ir_r5, pl_module.global_step
)
pl_module.logger.experiment.add_scalar(
"recalls/ir_r10", ir_r10, pl_module.global_step
)
pl_module.logger.experiment.add_scalar(
"recalls/tr_r1", tr_r1, pl_module.global_step
)
pl_module.logger.experiment.add_scalar(
"recalls/tr_r5", tr_r5, pl_module.global_step
)
pl_module.logger.experiment.add_scalar(
"recalls/tr_r10", tr_r10, pl_module.global_step
)
the_metric += val_avg
for loss_name, v in pl_module.hparams.config["loss_names"].items():
if v < 1:
continue
value = 0
if loss_name == "vqa":
value = getattr(pl_module, f"{phase}_{loss_name}_score").compute()
pl_module.log(f"{loss_name}/{phase}/score_epoch", value)
getattr(pl_module, f"{phase}_{loss_name}_score").reset()
pl_module.log(
f"{loss_name}/{phase}/loss_epoch",
getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
elif loss_name == "nlvr2":
if phase == "train":
value = getattr(pl_module, f"train_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/train/accuracy_epoch", value)
getattr(pl_module, f"train_{loss_name}_accuracy").reset()
pl_module.log(
f"{loss_name}/train/loss_epoch",
getattr(pl_module, f"train_{loss_name}_loss").compute(),
)
getattr(pl_module, f"train_{loss_name}_loss").reset()
else:
value_dev = getattr(pl_module, f"dev_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/dev/accuracy_epoch", value_dev)
getattr(pl_module, f"dev_{loss_name}_accuracy").reset()
pl_module.log(
f"{loss_name}/dev/loss_epoch",
getattr(pl_module, f"dev_{loss_name}_loss").compute(),
)
getattr(pl_module, f"dev_{loss_name}_loss").reset()
value_test = getattr(pl_module, f"test_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/test/accuracy_epoch", value_test)
getattr(pl_module, f"test_{loss_name}_accuracy").reset()
pl_module.log(
f"{loss_name}/test/loss_epoch",
getattr(pl_module, f"test_{loss_name}_loss").compute(),
)
getattr(pl_module, f"test_{loss_name}_loss").reset()
value = value_dev
elif loss_name == "irtr":
value_i2t = getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/i2t_accuracy_epoch", value_i2t)
getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").reset()
value_t2i = getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/t2i_accuracy_epoch", value_t2i)
getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").reset()
value = value_i2t + value_t2i
pl_module.log(
f"{loss_name}/{phase}/loss_epoch",
getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
elif loss_name == "itm":
value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
pl_module.log(
f"{loss_name}/{phase}/loss_epoch",
getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
elif loss_name == "itc":
value_i2t = getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/i2t_accuracy_epoch", value_i2t)
getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").reset()
value_t2i = getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/t2i_accuracy_epoch", value_t2i)
getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").reset()
pl_module.log(
f"{loss_name}/{phase}/loss_epoch",
getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
value_vl_i2t = getattr(pl_module, f"{phase}_{loss_name}_vl_i2t_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/vl_i2t_accuracy_epoch", value_vl_i2t)
getattr(pl_module, f"{phase}_{loss_name}_vl_i2t_accuracy").reset()
value_vl_t2i = getattr(pl_module, f"{phase}_{loss_name}_vl_t2i_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/vl_t2i_accuracy_epoch", value_vl_t2i)
getattr(pl_module, f"{phase}_{loss_name}_vl_t2i_accuracy").reset()
value = value_i2t + value_t2i
else:
value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
pl_module.log(
f"{loss_name}/{phase}/loss_epoch",
getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
)
getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
the_metric += value
pl_module.log(f"{phase}/the_metric", the_metric)
def check_non_acc_grad(pl_module):
if pl_module.token_type_embeddings.weight.grad is None:
return True
else:
grad = pl_module.token_type_embeddings.weight.grad
return (grad.sum() == 0).item()
def set_task(pl_module):
pl_module.current_tasks = [
k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1
]
return
def set_schedule(pl_module):
lr = pl_module.hparams.config["learning_rate"]
wd = pl_module.hparams.config["weight_decay"]
no_decay = [
"bias",
"LayerNorm.bias",
"LayerNorm.weight",
"norm.bias",
"norm.weight",
"norm1.bias",
"norm1.weight",
"norm2.bias",
"norm2.weight",
]
head_names = ["vqa_classifier", "nlvr2_classifier"]
lr_mult = pl_module.hparams.config["lr_mult"]
end_lr = pl_module.hparams.config["end_lr"]
decay_power = pl_module.hparams.config["decay_power"]
optim_type = pl_module.hparams.config["optim_type"]
names = [n for n, p in pl_module.named_parameters()]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in pl_module.named_parameters()
if not any(nd in n for nd in no_decay)
and not any(bb in n for bb in head_names)
],
"weight_decay": wd,
"lr": lr,
},
{
"params": [
p
for n, p in pl_module.named_parameters()
if any(nd in n for nd in no_decay)
and not any(bb in n for bb in head_names)
],
"weight_decay": 0.0,
"lr": lr,
},
{
"params": [
p
for n, p in pl_module.named_parameters()
if not any(nd in n for nd in no_decay)
and any(bb in n for bb in head_names)
],
"weight_decay": wd,
"lr": lr * lr_mult,
},
{
"params": [
p
for n, p in pl_module.named_parameters()
if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names)
],
"weight_decay": 0.0,
"lr": lr * lr_mult,
},
]
if optim_type == "adamw":
optimizer = AdamW(
optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)
)
elif optim_type == "adam":
optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
elif optim_type == "sgd":
optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9)
if pl_module.trainer.max_steps is None or pl_module.trainer.max_steps==-1:
max_steps = (
len(pl_module.trainer.datamodule.train_dataloader())
* pl_module.trainer.max_epochs
// pl_module.trainer.accumulate_grad_batches
)
else:
max_steps = pl_module.trainer.max_steps
warmup_steps = pl_module.hparams.config["warmup_steps"]
if isinstance(pl_module.hparams.config["warmup_steps"], float):
warmup_steps = int(max_steps * warmup_steps)
rank_zero_info("Warmup_steps:{} \t Max_steps:{}".format(warmup_steps, max_steps))
if decay_power == "cosine":
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=max_steps,
)
else:
scheduler = get_polynomial_decay_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=max_steps,
lr_end=end_lr,
power=decay_power,
)
sched = {"scheduler": scheduler, "interval": "step"}
return (
[optimizer],
[sched],
)
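# set_schedule builds four parameter groups for the chosen optimizer (adamw / adam / sgd):
# {decay, no-decay} x {backbone, head}. Bias and LayerNorm/norm parameters get
# weight_decay = 0, and the task heads listed in head_names ("vqa_classifier",
# "nlvr2_classifier") run at lr * lr_mult. The LR scheduler (cosine or polynomial decay
# with warmup) is stepped every optimizer step via {"interval": "step"}.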
from .pixelbert import (
pixelbert_transform,
pixelbert_transform_randaug,
)
from .square_transform import (
square_transform,
square_transform_randaug,
)
_transforms = {
"pixelbert": pixelbert_transform,
"pixelbert_randaug": pixelbert_transform_randaug,
"square_transform": square_transform,
"square_transform_randaug": square_transform_randaug,
}
def keys_to_transforms(keys: list, size=224):
return [_transforms[key](size=size) for key in keys]
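# Usage sketch (the keys must be names registered in _transforms above; the variable
# names are illustrative):
#
#   train_tf = keys_to_transforms(["square_transform_randaug"], size=224)[0]
#   val_tf = keys_to_transforms(["square_transform"], size=224)[0]
#   tensor = train_tf(pil_image)   # any RGB PIL.Image -> normalized torch tensor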
from .utils import (
inception_normalize,
MinMaxResize,
)
from torchvision import transforms
from .randaug import RandAugment
def pixelbert_transform(size=800):
longer = int((1333 / 800) * size)
return transforms.Compose(
[
MinMaxResize(shorter=size, longer=longer),
transforms.ToTensor(),
inception_normalize,
]
)
def pixelbert_transform_randaug(size=800):
longer = int((1333 / 800) * size)
trs = transforms.Compose(
[
MinMaxResize(shorter=size, longer=longer),
transforms.ToTensor(),
inception_normalize,
]
)
trs.transforms.insert(0, RandAugment(2, 9))
return trs
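# Both pixelbert transforms share the same pipeline: MinMaxResize (imported from .utils,
# defined later in this commit) scales the shorter side to `size`, caps the longer side at
# size * 1333 / 800 and snaps both to multiples of 32, followed by ToTensor and the
# [-1, 1] inception normalization; the randaug variant only prepends RandAugment(2, 9).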
# code in this file is adapted from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
import random
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch
from PIL import Image
def ShearX(img, v): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[0]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[1]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def Rotate(img, v): # [-30, 30]
assert -30 <= v <= 30
if random.random() > 0.5:
v = -v
return img.rotate(v)
def AutoContrast(img, _):
return PIL.ImageOps.autocontrast(img)
def Invert(img, _):
return PIL.ImageOps.invert(img)
def Equalize(img, _):
return PIL.ImageOps.equalize(img)
def Flip(img, _): # not from the paper
return PIL.ImageOps.mirror(img)
def Solarize(img, v): # [0, 256]
assert 0 <= v <= 256
return PIL.ImageOps.solarize(img, v)
def SolarizeAdd(img, addition=0, threshold=128):
img_np = np.array(img).astype(int)  # np.int was removed in newer NumPy versions
img_np = img_np + addition
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v): # [4, 8]
v = int(v)
v = max(1, v)
return PIL.ImageOps.posterize(img, v)
def Contrast(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Color(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Color(img).enhance(v)
def Brightness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Sharpness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
assert 0.0 <= v <= 0.2
if v <= 0.0:
return img
v = v * img.size[0]
return CutoutAbs(img, v)
def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
# assert 0 <= v <= 20
if v < 0:
return img
w, h = img.size
x0 = np.random.uniform(w)
y0 = np.random.uniform(h)
x0 = int(max(0, x0 - v / 2.0))
y0 = int(max(0, y0 - v / 2.0))
x1 = min(w, x0 + v)
y1 = min(h, y0 + v)
xy = (x0, y0, x1, y1)
color = (125, 123, 114)
# color = (0, 0, 0)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
def SamplePairing(imgs): # [0, 0.4]
def f(img1, v):
i = np.random.choice(len(imgs))
img2 = PIL.Image.fromarray(imgs[i])
return PIL.Image.blend(img1, img2, v)
return f
def Identity(img, v):
return img
def augment_list():  # 14 operations and their ranges
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
# l = [
# (Identity, 0., 1.0),
# (ShearX, 0., 0.3), # 0
# (ShearY, 0., 0.3), # 1
# (TranslateX, 0., 0.33), # 2
# (TranslateY, 0., 0.33), # 3
# (Rotate, 0, 30), # 4
# (AutoContrast, 0, 1), # 5
# (Invert, 0, 1), # 6
# (Equalize, 0, 1), # 7
# (Solarize, 0, 110), # 8
# (Posterize, 4, 8), # 9
# # (Contrast, 0.1, 1.9), # 10
# (Color, 0.1, 1.9), # 11
# (Brightness, 0.1, 1.9), # 12
# (Sharpness, 0.1, 1.9), # 13
# # (Cutout, 0, 0.2), # 14
# # (SamplePairing(imgs), 0, 0.4), # 15
# ]
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
l = [
(AutoContrast, 0, 1),
(Equalize, 0, 1),
# (Invert, 0, 1),
(Rotate, 0, 30),
(Posterize, 0, 4),
(Solarize, 0, 256),
(SolarizeAdd, 0, 110),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Brightness, 0.1, 1.9),
(Sharpness, 0.1, 1.9),
(ShearX, 0.0, 0.3),
(ShearY, 0.0, 0.3),
# (CutoutAbs, 0, 40),
(TranslateXabs, 0.0, 100),
(TranslateYabs, 0.0, 100),
]
return l
class Lighting(object):
"""Lighting noise(AlexNet - style PCA - based noise)"""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = torch.Tensor(eigval)
self.eigvec = torch.Tensor(eigvec)
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = (
self.eigvec.type_as(img)
.clone()
.mul(alpha.view(1, 3).expand(3, 3))
.mul(self.eigval.view(1, 3).expand(3, 3))
.sum(1)
.squeeze()
)
return img.add(rgb.view(3, 1, 1).expand_as(img))
class CutoutDefault(object):
"""
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
class RandAugment:
def __init__(self, n, m):
self.n = n
self.m = m # [0, 30]
self.augment_list = augment_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
val = (float(self.m) / 30) * float(maxval - minval) + minval
img = op(img, val)
return img
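# Usage sketch for the PIL-based RandAugment above (the magnitude m is mapped linearly
# from [0, 30] onto each op's (minval, maxval) range, and n ops are sampled per image):
#
#   aug = RandAugment(2, 9)     # the setting used by pixelbert_transform_randaug above
#   out = aug(pil_image)        # pil_image: an RGB PIL.Image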
import cv2
import numpy as np
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalizeHist
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0: return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.solarize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Brightness
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences between this result and PIL's are all on the 4 boundaries; the center
areas are the same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
if __name__ == '__main__':
a = RandomAugment()
img = np.random.randn(32, 32, 3)
a(img)
# code in this file is adapted from the ALBEF repo (https://github.com/salesforce/ALBEF)
from .utils import (
inception_normalize,
)
from torchvision import transforms
from .randaugment import RandomAugment
from PIL import Image
def square_transform(size=224):
return transforms.Compose(
[
transforms.Resize((size, size), interpolation=Image.BICUBIC),
transforms.ToTensor(),
inception_normalize,
]
)
def square_transform_randaug(size=224):
return transforms.Compose(
[
transforms.RandomResizedCrop(size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
inception_normalize,
]
)
from torchvision import transforms
from PIL import Image
class MinMaxResize:
def __init__(self, shorter=800, longer=1333):
self.min = shorter
self.max = longer
def __call__(self, x):
w, h = x.size
scale = self.min / min(w, h)
if h < w:
newh, neww = self.min, scale * w
else:
newh, neww = scale * h, self.min
if max(newh, neww) > self.max:
scale = self.max / max(newh, neww)
newh = newh * scale
neww = neww * scale
newh, neww = int(newh + 0.5), int(neww + 0.5)
newh, neww = newh // 32 * 32, neww // 32 * 32
return x.resize((neww, newh), resample=Image.BICUBIC)
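# MinMaxResize example (illustrative numbers): a 640x480 (WxH) input with shorter=800,
# longer=1333 is scaled so the short side reaches 800 (about 1067x800), then both sides
# are rounded down to multiples of 32, so the image is resized bicubically to 1056x800.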
class UnNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
# The normalize code -> t.sub_(m).div_(s)
return tensor
# This is the simple maximum-entropy normalization performed in the Inception paper
inception_normalize = transforms.Compose(
[transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
)
# ViT uses simple non-biased inception normalization
# https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132
inception_unnormalize = transforms.Compose(
[UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
)
import re
contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
manual_map = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
articles = ["a", "an", "the"]
# Strip periods that are not part of a number (negative lookbehind/lookahead on digits).
period_strip = re.compile(r"(?<!\d)(\.)(?!\d)")
# Match commas inside numbers, e.g. "1,000".
comma_strip = re.compile(r"(\d)(\,)(\d)")
punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def normalize_word(token):
    _token = token
    # Drop punctuation that is attached to a word or sits inside a number;
    # replace stand-alone punctuation with a space (as in the VQA evaluation code).
    for p in punct:
        if (p + " " in token or " " + p in token) or (
            re.search(comma_strip, token) is not None
        ):
            _token = _token.replace(p, "")
        else:
            _token = _token.replace(p, " ")
    # Remove periods that are not part of a number.
    token = period_strip.sub("", _token)
    # Map number words to digits and drop articles.
    _token = []
    temp = token.lower().split()
    for word in temp:
        word = manual_map.get(word, word)
        if word not in articles:
            _token.append(word)
    # Expand missing-apostrophe forms to their contractions.
    for i, word in enumerate(_token):
        if word in contractions:
            _token[i] = contractions[word]
    token = " ".join(_token)
    token = token.replace(",", "")
    return token
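# --- New file: COCO captions (Karpathy split) -> Arrow converter ---
# Reads {root}/karpathy/dataset_coco.json plus the train2014/val2014 images and
# writes one coco_caption_karpathy_{split}.arrow file per split. Example call
# (paths are illustrative): make_arrow("/path/to/coco", "/path/to/arrow")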
import json
import os
import pandas as pd
import pyarrow as pa
import random
from tqdm import tqdm
from glob import glob
from collections import defaultdict
def path2rest(path, iid2captions, iid2split):
name = path.split("/")[-1]
with open(path, "rb") as fp:
binary = fp.read()
captions = iid2captions[name]
split = iid2split[name]
return [binary, captions, name, split]
def make_arrow(root, dataset_root):
with open(f"{root}/karpathy/dataset_coco.json", "r") as fp:
captions = json.load(fp)
captions = captions["images"]
iid2captions = defaultdict(list)
iid2split = dict()
for cap in tqdm(captions):
filename = cap["filename"]
iid2split[filename] = cap["split"]
for c in cap["sentences"]:
iid2captions[filename].append(c["raw"])
paths = list(glob(f"{root}/train2014/*.jpg")) + list(glob(f"{root}/val2014/*.jpg"))
random.shuffle(paths)
caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
if len(paths) == len(caption_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(caption_paths), len(iid2captions),
)
bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)]
for split in ["train", "val", "restval", "test"]:
batches = [b for b in bs if b[-1] == split]
dataframe = pd.DataFrame(
batches, columns=["image", "caption", "image_id", "split"],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(
f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb"
) as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
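# --- New file: Conceptual Captions -> Arrow converter ---
# Reads {root}/{split}_annot.json plus images_{split}/*/* and writes the data
# as conceptual_caption_{split}_{sub}.arrow shards of at most 100k images each.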
import json
import pandas as pd
import pyarrow as pa
import gc
import random
import os
from tqdm import tqdm
from glob import glob
def path2rest(path, iid2captions):
split, _, name = path.split("/")[-3:]
split = split.split("_")[-1]
iid = name
with open(path, "rb") as fp:
binary = fp.read()
captions = iid2captions[iid]
return [
binary,
captions,
iid,
split,
]
def make_arrow(root, dataset_root):
for split in ["val", "train"]:
with open(f"{root}/{split}_annot.json", "r") as fp:
captions = json.load(fp)
iid2captions = dict()
for cap in tqdm(captions):
iid = cap[0].split("/")[-1]
iid2captions[iid] = [cap[1]]
paths = list(glob(f"{root}/images_{split}/*/*"))
random.shuffle(paths)
caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
if len(paths) == len(caption_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(caption_paths), len(iid2captions),
)
        # Shard into chunks of 100k images so each Arrow file (and the in-memory
        # DataFrame used to build it) stays a manageable size.
        sub_len = len(caption_paths) // 100000
        subs = list(range(sub_len + 1))
for sub in subs:
sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000]
bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)]
dataframe = pd.DataFrame(
bs, columns=["image", "caption", "image_id", "split"],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(
f"{dataset_root}/conceptual_caption_{split}_{sub}.arrow", "wb"
) as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
del dataframe
del table
del bs
gc.collect()
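# --- New file: Flickr30k captions (Karpathy split) -> Arrow converter ---
# Same structure as the COCO converter above, but reads dataset_flickr30k.json
# and the flickr30k-images folder, writing f30k_caption_karpathy_{split}.arrow.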
import json
import pandas as pd
import pyarrow as pa
import random
import os
from tqdm import tqdm
from glob import glob
from collections import defaultdict
def path2rest(path, iid2captions, iid2split):
name = path.split("/")[-1]
with open(path, "rb") as fp:
binary = fp.read()
captions = iid2captions[name]
split = iid2split[name]
return [binary, captions, name, split]
def make_arrow(root, dataset_root):
with open(f"{root}/karpathy/dataset_flickr30k.json", "r") as fp:
captions = json.load(fp)
captions = captions["images"]
iid2captions = defaultdict(list)
iid2split = dict()
for cap in tqdm(captions):
filename = cap["filename"]
iid2split[filename] = cap["split"]
for c in cap["sentences"]:
iid2captions[filename].append(c["raw"])
paths = list(glob(f"{root}/flickr30k-images/*.jpg"))
random.shuffle(paths)
caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
if len(paths) == len(caption_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(caption_paths), len(iid2captions),
)
bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)]
for split in ["train", "val", "test"]:
batches = [b for b in bs if b[-1] == split]
dataframe = pd.DataFrame(
batches, columns=["image", "caption", "image_id", "split"],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(
f"{dataset_root}/f30k_caption_karpathy_{split}.arrow", "wb"
) as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
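# --- New file: NLVR2 -> Arrow converter ---
# Each row stores an image pair (image_0/image_1) with its sentences and labels;
# the train/dev/test1 splits plus their balanced/unbalanced variants are written
# as nlvr2_{split}.arrow.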
import json
import pandas as pd
import pyarrow as pa
import os
from tqdm import tqdm
from collections import defaultdict
def process(root, iden, row):
texts = [r["sentence"] for r in row]
labels = [r["label"] for r in row]
split = iden.split("-")[0]
if iden.startswith("train"):
directory = row[0]["directory"]
path = f"{root}/images/train/{directory}/{iden}"
else:
path = f"{root}/{split}/{iden}"
with open(f"{path}-img0.png", "rb") as fp:
img0 = fp.read()
with open(f"{path}-img1.png", "rb") as fp:
img1 = fp.read()
return [img0, img1, texts, labels, iden]
def make_arrow(root, dataset_root):
train_data = list(
map(json.loads, open(f"{root}/nlvr2/data/train.json").readlines())
)
test1_data = list(
map(json.loads, open(f"{root}/nlvr2/data/test1.json").readlines())
)
dev_data = list(map(json.loads, open(f"{root}/nlvr2/data/dev.json").readlines()))
balanced_test1_data = list(
map(
json.loads,
open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(),
)
)
balanced_dev_data = list(
map(
json.loads,
open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(),
)
)
unbalanced_test1_data = list(
map(
json.loads,
open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(),
)
)
unbalanced_dev_data = list(
map(
json.loads,
open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(),
)
)
splits = [
"train",
"dev",
"test1",
"balanced_dev",
"balanced_test1",
"unbalanced_dev",
"unbalanced_test1",
]
datas = [
train_data,
dev_data,
test1_data,
balanced_dev_data,
balanced_test1_data,
unbalanced_dev_data,
unbalanced_test1_data,
]
annotations = dict()
    for split, data in zip(splits, datas):
        _annot = defaultdict(list)
        for row in tqdm(data):
            # Group rows that share the same image pair: the identifier minus its
            # trailing sentence index names the pair (see the image paths above).
            _annot["-".join(row["identifier"].split("-")[:-1])].append(row)
        annotations[split] = _annot
for split in splits:
bs = [
process(root, iden, row) for iden, row in tqdm(annotations[split].items())
]
dataframe = pd.DataFrame(
bs, columns=["image_0", "image_1", "questions", "answers", "identifier"],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(f"{dataset_root}/nlvr2_{split}.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
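# --- New file: SBU captions -> Arrow converter ---
# Same sharded layout as the Conceptual Captions converter above, writing
# sbu_{sub}.arrow files of at most 100k images each.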
import json
import pandas as pd
import pyarrow as pa
import gc
import random
import os
from tqdm import tqdm
from glob import glob
def path2rest(path, iid2captions):
split, _, name = path.split("/")[-3:]
split = split.split("_")[-1]
iid = name
with open(path, "rb") as fp:
binary = fp.read()
captions = iid2captions[iid]
return [
binary,
captions,
iid,
split,
]
def make_arrow(root, dataset_root):
with open(f"{root}/annot.json", "r") as fp:
captions = json.load(fp)
iid2captions = dict()
for cap in tqdm(captions):
iid = cap[0].split("/")[-1]
iid2captions[iid] = [cap[1]]
paths = list(glob(f"{root}/images_train/*/*"))
random.shuffle(paths)
caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]
if len(paths) == len(caption_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(caption_paths), len(iid2captions),
)
sub_len = int(len(caption_paths) // 100000)
subs = list(range(sub_len + 1))
for sub in subs:
sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000]
bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)]
dataframe = pd.DataFrame(bs, columns=["image", "caption", "image_id", "split"],)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(f"{dataset_root}/sbu_{sub}.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
del dataframe
del table
del bs
gc.collect()
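# --- New file: Visual Genome region descriptions -> Arrow converter ---
# Stores every region caption together with its bounding box (x, y, width,
# height) and image, writing a single vg.arrow file.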
import json
import pandas as pd
import pyarrow as pa
import random
import os
from tqdm import tqdm
from glob import glob
from collections import defaultdict
def path2rest(path, iid2captions):
name = path.split("/")[-1]
iid = int(name[:-4])
with open(path, "rb") as fp:
binary = fp.read()
cdicts = iid2captions[iid]
captions = [c["phrase"] for c in cdicts]
widths = [c["width"] for c in cdicts]
heights = [c["height"] for c in cdicts]
xs = [c["x"] for c in cdicts]
ys = [c["y"] for c in cdicts]
return [
binary,
captions,
widths,
heights,
xs,
ys,
str(iid),
]
def make_arrow(root, dataset_root):
with open(f"{root}/annotations/region_descriptions.json", "r") as fp:
captions = json.load(fp)
iid2captions = defaultdict(list)
for cap in tqdm(captions):
cap = cap["regions"]
for c in cap:
iid2captions[c["image_id"]].append(c)
paths = list(glob(f"{root}/images/VG_100K/*.jpg")) + list(
glob(f"{root}/images/VG_100K_2/*.jpg")
)
random.shuffle(paths)
caption_paths = [
path for path in paths if int(path.split("/")[-1][:-4]) in iid2captions
]
if len(paths) == len(caption_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(caption_paths), len(iid2captions),
)
bs = [path2rest(path, iid2captions) for path in tqdm(caption_paths)]
dataframe = pd.DataFrame(
bs, columns=["image", "caption", "width", "height", "x", "y", "image_id"],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(f"{dataset_root}/vg.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
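# --- New file: VQAv2 -> Arrow converter ---
# Builds the answer vocabulary from the train/val annotations, attaches soft
# answer labels/scores to every question, and writes one vqav2_{split}.arrow
# file per split plus the trainable_val / rest_val split of the validation set.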
import json
import pandas as pd
import pyarrow as pa
import random
import os
from tqdm import tqdm
from glob import glob
from collections import defaultdict, Counter
from .glossary import normalize_word
def get_score(occurrences):
    # Soft VQA target: an answer chosen by `occurrences` annotators gets a score
    # in steps of 0.3, capped at 1.0 -- a step-wise variant of the VQA soft
    # accuracy min(#matching answers / 3, 1).
    if occurrences == 0:
        return 0.0
    elif occurrences == 1:
        return 0.3
    elif occurrences == 2:
        return 0.6
    elif occurrences == 3:
        return 0.9
    else:
        return 1.0
def path2rest(path, split, annotations, label2ans):
iid = int(path.split("/")[-1].split("_")[-1][:-4])
with open(path, "rb") as fp:
binary = fp.read()
_annot = annotations[split][iid]
_annot = list(_annot.items())
qids, qas = [a[0] for a in _annot], [a[1] for a in _annot]
questions = [qa[0] for qa in qas]
    # Test splits have no ground-truth answers, so the answer fields stay empty.
    answers = [qa[1] for qa in qas] if "test" not in split else []
    answer_labels = [a["labels"] for a in answers] if "test" not in split else []
    answer_scores = [a["scores"] for a in answers] if "test" not in split else []
    answers = (
        [[label2ans[l] for l in al] for al in answer_labels]
        if "test" not in split
        else []
    )
return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split]
def make_arrow(root, dataset_root):
with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp:
questions_train2014 = json.load(fp)["questions"]
with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp:
questions_val2014 = json.load(fp)["questions"]
with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp:
questions_test2015 = json.load(fp)["questions"]
with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp:
questions_test_dev2015 = json.load(fp)["questions"]
with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp:
annotations_train2014 = json.load(fp)["annotations"]
with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp:
annotations_val2014 = json.load(fp)["annotations"]
annotations = dict()
for split, questions in zip(
["train", "val", "test", "test-dev"],
[
questions_train2014,
questions_val2014,
questions_test2015,
questions_test_dev2015,
],
):
_annot = defaultdict(dict)
for q in tqdm(questions):
_annot[q["image_id"]][q["question_id"]] = [q["question"]]
annotations[split] = _annot
all_major_answers = list()
for split, annots in zip(
["train", "val"], [annotations_train2014, annotations_val2014],
):
_annot = annotations[split]
for q in tqdm(annots):
all_major_answers.append(q["multiple_choice_answer"])
    # Answer vocabulary: keep normalized answers that occur at least 9 times in
    # train+val; ans2label maps answer -> class index, label2ans is the inverse.
    all_major_answers = [normalize_word(word) for word in tqdm(all_major_answers)]
    counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
    ans2label = {k: i for i, k in enumerate(counter.keys())}
    label2ans = list(counter.keys())
for split, annots in zip(
["train", "val"], [annotations_train2014, annotations_val2014],
):
_annot = annotations[split]
for q in tqdm(annots):
answers = q["answers"]
answer_count = {}
for answer in answers:
answer_ = answer["answer"]
answer_count[answer_] = answer_count.get(answer_, 0) + 1
labels = []
scores = []
for answer in answer_count:
if answer not in ans2label:
continue
labels.append(ans2label[answer])
score = get_score(answer_count[answer])
scores.append(score)
_annot[q["image_id"]][q["question_id"]].append(
{"labels": labels, "scores": scores,}
)
for split in ["train", "val"]:
filtered_annot = dict()
for ik, iv in annotations[split].items():
new_q = dict()
for qk, qv in iv.items():
if len(qv[1]["labels"]) != 0:
new_q[qk] = qv
if len(new_q) != 0:
filtered_annot[ik] = new_q
annotations[split] = filtered_annot
for split in [
"train",
"val",
"test",
"test-dev",
]:
annot = annotations[split]
split_name = {
"train": "train2014",
"val": "val2014",
"test": "test2015",
"test-dev": "test2015",
}[split]
paths = list(glob(f"{root}/{split_name}/*.jpg"))
random.shuffle(paths)
annot_paths = [
path
for path in paths
if int(path.split("/")[-1].split("_")[-1][:-4]) in annot
]
if len(paths) == len(annot_paths):
print("all images have caption annotations")
else:
print("not all images have caption annotations")
print(
len(paths), len(annot_paths), len(annot),
)
bs = [
path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths)
]
dataframe = pd.DataFrame(
bs,
columns=[
"image",
"questions",
"answers",
"answer_labels",
"answer_scores",
"image_id",
"question_id",
"split",
],
)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
    # Split the freshly written val set: all but the last 1000 rows become
    # "trainable_val" (which can be folded into training), while the held-out
    # 1000 rows become "rest_val" for validation.
    table = pa.ipc.RecordBatchFileReader(
        pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r")
    ).read_all()
    pdtable = table.to_pandas()
    df1 = pdtable[:-1000]
    df2 = pdtable[-1000:]
df1 = pa.Table.from_pandas(df1)
df2 = pa.Table.from_pandas(df2)
with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, df1.schema) as writer:
writer.write_table(df1)
with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, df2.schema) as writer:
writer.write_table(df2)
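# --- New file: text-only corpus ("wikibk") -> Arrow converter ---
# Each line of wikibk.{index}.txt becomes one caption-only row with the
# placeholder image "None"; "wikibk" presumably refers to a Wikipedia +
# BookCorpus text dump used for text-only pretraining.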
import json
import pandas as pd
import pyarrow as pa
import gc
import random
import os
from tqdm import tqdm
from glob import glob
def path2rest(line):
    # Text-only row: placeholder image, single-sentence caption list, source tag,
    # and split name, matching the (image, caption, source, split) columns below.
    return [
        "None",
        [line],
        "wikibk",
        "train",
    ]
def make_arrow(root, dataset_root):
for index in range(0, 50):
file_path = f"{root}/wikibk.{index}.txt"
all_sents = []
with open(file_path, "r", encoding="utf-8") as fp:
for line in fp:
all_sents.append(line.strip())
print(file_path)
print("Number of sentences: {}".format(len(all_sents)))
bs = [path2rest(line) for line in tqdm(all_sents)]
dataframe = pd.DataFrame(bs, columns=["image", "caption", "source", "split"],)
table = pa.Table.from_pandas(dataframe)
os.makedirs(dataset_root, exist_ok=True)
with pa.OSFile(f"{dataset_root}/wikibk_train_{index}.arrow", "wb") as sink:
with pa.RecordBatchFileWriter(sink, table.schema) as writer:
writer.write_table(table)
del dataframe
del table
del bs
gc.collect()
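# A minimal sketch of reading one of the generated Arrow files back (the path
# below is illustrative); this reuses the same pyarrow pattern as the VQAv2
# converter above:
#
#   import pyarrow as pa
#   table = pa.ipc.RecordBatchFileReader(
#       pa.memory_map("/path/to/arrow/coco_caption_karpathy_train.arrow", "r")
#   ).read_all()
#   df = table.to_pandas()
#   print(df.columns)  # e.g. image, caption, image_id, split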