Commit c501623c authored by chenych

add vlmo

parent 4538607b
import torch
import random
import json

from transformers.optimization import AdamW
from transformers import (
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

from vlmo.modules.dist_utils import all_gather
from vlmo.modules.objectives import compute_irtr_recall, compute_irtr_recall_with_rerank
from vlmo.gadgets.my_metrics import Accuracy, VQAScore, Scalar

from pytorch_lightning.utilities.distributed import rank_zero_info
def set_metrics(pl_module):
    # Register per-split metric and loss trackers for every task enabled in loss_names.
    for split in ["train", "val"]:
        for k, v in pl_module.hparams.config["loss_names"].items():
            if v < 1:
                continue
            if k == "vqa":
                setattr(pl_module, f"{split}_vqa_score", VQAScore())
                setattr(pl_module, f"{split}_{k}_loss", Scalar())
            elif k == "nlvr2":
                if split == "train":
                    setattr(pl_module, f"train_{k}_accuracy", Accuracy())
                    setattr(pl_module, f"train_{k}_loss", Scalar())
                else:
                    setattr(pl_module, f"dev_{k}_accuracy", Accuracy())
                    setattr(pl_module, f"dev_{k}_loss", Scalar())
                    setattr(pl_module, f"test_{k}_accuracy", Accuracy())
                    setattr(pl_module, f"test_{k}_loss", Scalar())
            elif k == "irtr":
                setattr(pl_module, f"{split}_{k}_i2t_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_t2i_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_loss", Scalar())
                setattr(pl_module, f"{split}_{k}_logit_scale", Scalar())
            elif k == "itm":
                setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_loss", Scalar())
            elif k == "itc":
                setattr(pl_module, f"{split}_{k}_i2t_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_t2i_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_loss", Scalar())
                setattr(pl_module, f"{split}_{k}_logit_scale", Scalar())
                setattr(pl_module, f"{split}_{k}_vl_i2t_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_vl_t2i_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_vl_logit_scale", Scalar())
            else:
                setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
                setattr(pl_module, f"{split}_{k}_loss", Scalar())
def epoch_wrapup(pl_module):
    # Compute, log, and reset the epoch-level metric for every enabled task,
    # accumulating a single scalar ("the_metric") used for checkpoint selection.
    phase = "train" if pl_module.training else "val"
    the_metric = 0

    if pl_module.hparams.config["get_recall_metric"] and not pl_module.training:
        (val_ir_r1, val_ir_r5, val_ir_r10, val_tr_r1, val_tr_r5, val_tr_r10) = compute_irtr_recall(pl_module, split="val")
        val_avg = (val_ir_r1.item() + val_ir_r5.item() + val_ir_r10.item() + val_tr_r1.item() + val_tr_r5.item() + val_tr_r10.item()) / 6.0
        pl_module.logger.experiment.add_scalar(
            "recalls/val_avg", val_avg, pl_module.global_step
        )

        (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) = compute_irtr_recall(pl_module, split="test")
        test_avg = (ir_r1.item() + ir_r5.item() + ir_r10.item() + tr_r1.item() + tr_r5.item() + tr_r10.item()) / 6.0
        pl_module.logger.experiment.add_scalar(
            "recalls/test_avg", test_avg, pl_module.global_step
        )

        print("val_avg:{}, test_avg:{}".format(val_avg, test_avg))
        print("test ir_r1:{}, ir_r5:{}, ir_r10:{}, tr_r1:{}, tr_r5:{}, tr_r10:{}".format(ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10))

        pl_module.logger.experiment.add_scalar(
            "recalls/ir_r1", ir_r1, pl_module.global_step
        )
        pl_module.logger.experiment.add_scalar(
            "recalls/ir_r5", ir_r5, pl_module.global_step
        )
        pl_module.logger.experiment.add_scalar(
            "recalls/ir_r10", ir_r10, pl_module.global_step
        )
        pl_module.logger.experiment.add_scalar(
            "recalls/tr_r1", tr_r1, pl_module.global_step
        )
        pl_module.logger.experiment.add_scalar(
            "recalls/tr_r5", tr_r5, pl_module.global_step
        )
        pl_module.logger.experiment.add_scalar(
            "recalls/tr_r10", tr_r10, pl_module.global_step
        )
        the_metric += val_avg

    for loss_name, v in pl_module.hparams.config["loss_names"].items():
        if v < 1:
            continue

        value = 0

        if loss_name == "vqa":
            value = getattr(pl_module, f"{phase}_{loss_name}_score").compute()
            pl_module.log(f"{loss_name}/{phase}/score_epoch", value)
            getattr(pl_module, f"{phase}_{loss_name}_score").reset()
            pl_module.log(
                f"{loss_name}/{phase}/loss_epoch",
                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
            )
            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
        elif loss_name == "nlvr2":
            if phase == "train":
                value = getattr(pl_module, f"train_{loss_name}_accuracy").compute()
                pl_module.log(f"{loss_name}/train/accuracy_epoch", value)
                getattr(pl_module, f"train_{loss_name}_accuracy").reset()
                pl_module.log(
                    f"{loss_name}/train/loss_epoch",
                    getattr(pl_module, f"train_{loss_name}_loss").compute(),
                )
                getattr(pl_module, f"train_{loss_name}_loss").reset()
            else:
                value_dev = getattr(pl_module, f"dev_{loss_name}_accuracy").compute()
                pl_module.log(f"{loss_name}/dev/accuracy_epoch", value_dev)
                getattr(pl_module, f"dev_{loss_name}_accuracy").reset()
                pl_module.log(
                    f"{loss_name}/dev/loss_epoch",
                    getattr(pl_module, f"dev_{loss_name}_loss").compute(),
                )
                getattr(pl_module, f"dev_{loss_name}_loss").reset()

                value_test = getattr(pl_module, f"test_{loss_name}_accuracy").compute()
                pl_module.log(f"{loss_name}/test/accuracy_epoch", value_test)
                getattr(pl_module, f"test_{loss_name}_accuracy").reset()
                pl_module.log(
                    f"{loss_name}/test/loss_epoch",
                    getattr(pl_module, f"test_{loss_name}_loss").compute(),
                )
                getattr(pl_module, f"test_{loss_name}_loss").reset()
                value = value_dev
        elif loss_name == "irtr":
            value_i2t = getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/i2t_accuracy_epoch", value_i2t)
            getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").reset()
            value_t2i = getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/t2i_accuracy_epoch", value_t2i)
            getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").reset()
            value = value_i2t + value_t2i
            pl_module.log(
                f"{loss_name}/{phase}/loss_epoch",
                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
            )
            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
        elif loss_name == "itm":
            value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
            getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
            pl_module.log(
                f"{loss_name}/{phase}/loss_epoch",
                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
            )
            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
        elif loss_name == "itc":
            value_i2t = getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/i2t_accuracy_epoch", value_i2t)
            getattr(pl_module, f"{phase}_{loss_name}_i2t_accuracy").reset()
            value_t2i = getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/t2i_accuracy_epoch", value_t2i)
            getattr(pl_module, f"{phase}_{loss_name}_t2i_accuracy").reset()
            pl_module.log(
                f"{loss_name}/{phase}/loss_epoch",
                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
            )
            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()

            value_vl_i2t = getattr(pl_module, f"{phase}_{loss_name}_vl_i2t_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/vl_i2t_accuracy_epoch", value_vl_i2t)
            getattr(pl_module, f"{phase}_{loss_name}_vl_i2t_accuracy").reset()
            value_vl_t2i = getattr(pl_module, f"{phase}_{loss_name}_vl_t2i_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/vl_t2i_accuracy_epoch", value_vl_t2i)
            getattr(pl_module, f"{phase}_{loss_name}_vl_t2i_accuracy").reset()

            value = value_i2t + value_t2i
        else:
            value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
            pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
            getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
            pl_module.log(
                f"{loss_name}/{phase}/loss_epoch",
                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
            )
            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()

        the_metric += value

    pl_module.log(f"{phase}/the_metric", the_metric)
def check_non_acc_grad(pl_module):
    # True if the token-type embedding received no gradient this step
    # (grad is None or its sum is zero).
    if pl_module.token_type_embeddings.weight.grad is None:
        return True
    else:
        grad = pl_module.token_type_embeddings.weight.grad
        return (grad.sum() == 0).item()


def set_task(pl_module):
    # Record which tasks (loss weight >= 1) are active for the current run.
    pl_module.current_tasks = [
        k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1
    ]
    return
def set_schedule(pl_module):
    # Build parameter groups (with/without weight decay, head vs. backbone learning rate),
    # the optimizer, and a per-step warmup + decay learning-rate scheduler.
    lr = pl_module.hparams.config["learning_rate"]
    wd = pl_module.hparams.config["weight_decay"]

    no_decay = [
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight",
        "norm.bias",
        "norm.weight",
        "norm1.bias",
        "norm1.weight",
        "norm2.bias",
        "norm2.weight",
    ]
    head_names = ["vqa_classifier", "nlvr2_classifier"]
    lr_mult = pl_module.hparams.config["lr_mult"]
    end_lr = pl_module.hparams.config["end_lr"]
    decay_power = pl_module.hparams.config["decay_power"]
    optim_type = pl_module.hparams.config["optim_type"]

    names = [n for n, p in pl_module.named_parameters()]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if not any(nd in n for nd in no_decay)
                and not any(bb in n for bb in head_names)
            ],
            "weight_decay": wd,
            "lr": lr,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if any(nd in n for nd in no_decay)
                and not any(bb in n for bb in head_names)
            ],
            "weight_decay": 0.0,
            "lr": lr,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if not any(nd in n for nd in no_decay)
                and any(bb in n for bb in head_names)
            ],
            "weight_decay": wd,
            "lr": lr * lr_mult,
        },
        {
            "params": [
                p
                for n, p in pl_module.named_parameters()
                if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names)
            ],
            "weight_decay": 0.0,
            "lr": lr * lr_mult,
        },
    ]

    if optim_type == "adamw":
        optimizer = AdamW(
            optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)
        )
    elif optim_type == "adam":
        optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
    elif optim_type == "sgd":
        optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9)

    if pl_module.trainer.max_steps is None or pl_module.trainer.max_steps == -1:
        max_steps = (
            len(pl_module.trainer.datamodule.train_dataloader())
            * pl_module.trainer.max_epochs
            // pl_module.trainer.accumulate_grad_batches
        )
    else:
        max_steps = pl_module.trainer.max_steps

    warmup_steps = pl_module.hparams.config["warmup_steps"]
    if isinstance(pl_module.hparams.config["warmup_steps"], float):
        # A float warmup value is interpreted as a fraction of the total training steps.
        warmup_steps = int(max_steps * warmup_steps)
    rank_zero_info("Warmup_steps:{} \t Max_steps:{}".format(warmup_steps, max_steps))

    if decay_power == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=max_steps,
        )
    else:
        scheduler = get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=max_steps,
            lr_end=end_lr,
            power=decay_power,
        )

    sched = {"scheduler": scheduler, "interval": "step"}

    return (
        [optimizer],
        [sched],
    )
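

# Usage sketch (illustrative, not part of this commit): a LightningModule would
# typically delegate its optimizer/scheduler setup to set_schedule from
# configure_optimizers. The class name below is hypothetical.
import pytorch_lightning as pl


class _ExampleLitModule(pl.LightningModule):
    def configure_optimizers(self):
        # Reuse the grouped parameters, optimizer, and per-step scheduler built above.
        return set_schedule(self)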
from .pixelbert import (
    pixelbert_transform,
    pixelbert_transform_randaug,
)
from .square_transform import (
    square_transform,
    square_transform_randaug,
)

_transforms = {
    "pixelbert": pixelbert_transform,
    "pixelbert_randaug": pixelbert_transform_randaug,
    "square_transform": square_transform,
    "square_transform_randaug": square_transform_randaug,
}


def keys_to_transforms(keys: list, size=224):
    return [_transforms[key](size=size) for key in keys]
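

# Usage sketch (illustrative, not part of this commit): datamodules can resolve
# their configured transform keys here, e.g.
#   train_tfms = keys_to_transforms(["square_transform_randaug"], size=224)
#   val_tfms = keys_to_transforms(["square_transform"], size=224)
# Each returned transform maps a PIL image to a normalized (C, H, W) tensor.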
from .utils import (
    inception_normalize,
    MinMaxResize,
)
from torchvision import transforms
from .randaug import RandAugment


def pixelbert_transform(size=800):
    # PixelBERT-style resize: shorter side scaled to `size`, longer side capped at
    # size * 1333 / 800, followed by tensor conversion and inception normalization.
    longer = int((1333 / 800) * size)
    return transforms.Compose(
        [
            MinMaxResize(shorter=size, longer=longer),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )


def pixelbert_transform_randaug(size=800):
    longer = int((1333 / 800) * size)
    trs = transforms.Compose(
        [
            MinMaxResize(shorter=size, longer=longer),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )
    trs.transforms.insert(0, RandAugment(2, 9))
    return trs
# code in this file is adapted from the ALBEF repo (https://github.com/salesforce/ALBEF)
from .utils import (
    inception_normalize,
)
from torchvision import transforms
from .randaugment import RandomAugment
from PIL import Image


def square_transform(size=224):
    return transforms.Compose(
        [
            transforms.Resize((size, size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )


def square_transform_randaug(size=224):
    return transforms.Compose(
        [
            transforms.RandomResizedCrop(size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                2, 7, isPIL=True,
                augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                      'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'],
            ),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )
from torchvision import transforms
from PIL import Image


class MinMaxResize:
    # Resize so the shorter side equals `shorter` while capping the longer side at
    # `longer`, then snap both sides down to a multiple of 32.
    def __init__(self, shorter=800, longer=1333):
        self.min = shorter
        self.max = longer

    def __call__(self, x):
        w, h = x.size
        scale = self.min / min(w, h)
        if h < w:
            newh, neww = self.min, scale * w
        else:
            newh, neww = scale * h, self.min

        if max(newh, neww) > self.max:
            scale = self.max / max(newh, neww)
            newh = newh * scale
            neww = neww * scale

        newh, neww = int(newh + 0.5), int(neww + 0.5)
        newh, neww = newh // 32 * 32, neww // 32 * 32

        return x.resize((neww, newh), resample=Image.BICUBIC)


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized tensor image of size (C, H, W) to be un-normalized.
        Returns:
            Tensor: Un-normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The inverse of the normalize step: t.sub_(m).div_(s)
        return tensor


# This is the simple maximum-entropy normalization used in the Inception paper
inception_normalize = transforms.Compose(
    [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
)

# ViT uses simple non-biased inception normalization
# https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132
inception_unnormalize = transforms.Compose(
    [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
)
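

# Usage sketch (illustrative, not part of this commit): inception_unnormalize maps a
# tensor normalized to [-1, 1] back to [0, 1] for visualization. UnNormalize works
# in place, so clone first:
#   img = transforms.ToTensor()(pil_image)            # values in [0, 1]
#   normed = inception_normalize(img)                 # values in [-1, 1]
#   restored = inception_unnormalize(normed.clone())  # back to [0, 1]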
import json
import os
import pandas as pd
import pyarrow as pa
import random

from tqdm import tqdm
from glob import glob
from collections import defaultdict


def path2rest(path, iid2captions, iid2split):
    # Pack one image into a row: raw bytes, its captions, filename, and split.
    name = path.split("/")[-1]
    with open(path, "rb") as fp:
        binary = fp.read()
    captions = iid2captions[name]
    split = iid2split[name]
    return [binary, captions, name, split]


def make_arrow(root, dataset_root):
    # Read the Karpathy COCO caption annotations and write one Arrow file per split.
    with open(f"{root}/karpathy/dataset_coco.json", "r") as fp:
        captions = json.load(fp)

    captions = captions["images"]

    iid2captions = defaultdict(list)
    iid2split = dict()

    for cap in tqdm(captions):
        filename = cap["filename"]
        iid2split[filename] = cap["split"]
        for c in cap["sentences"]:
            iid2captions[filename].append(c["raw"])

    paths = list(glob(f"{root}/train2014/*.jpg")) + list(glob(f"{root}/val2014/*.jpg"))
    random.shuffle(paths)
    caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions]

    if len(paths) == len(caption_paths):
        print("all images have caption annotations")
    else:
        print("not all images have caption annotations")
    print(
        len(paths), len(caption_paths), len(iid2captions),
    )

    bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)]

    for split in ["train", "val", "restval", "test"]:
        batches = [b for b in bs if b[-1] == split]

        dataframe = pd.DataFrame(
            batches, columns=["image", "caption", "image_id", "split"],
        )

        table = pa.Table.from_pandas(dataframe)
        os.makedirs(dataset_root, exist_ok=True)
        with pa.OSFile(
            f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb"
        ) as sink:
            with pa.RecordBatchFileWriter(sink, table.schema) as writer:
                writer.write_table(table)
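

# Usage sketch (illustrative, not part of this commit; the paths are placeholders for
# a COCO root containing karpathy/dataset_coco.json, train2014/, and val2014/):
if __name__ == "__main__":
    make_arrow(root="/path/to/coco", dataset_root="/path/to/arrow_output")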
import json
import pandas as pd
import pyarrow as pa
import gc
import random
import os

from tqdm import tqdm
from glob import glob


def path2rest(line):
    # Text-only row: no image payload, a single-sentence "caption", fixed source and split.
    return [
        "None",
        [line],
        "wikibk",
        "train",
    ]


def make_arrow(root, dataset_root):
    # Convert the 50 wikibk text shards into one Arrow file each.
    for index in range(0, 50):
        file_path = f"{root}/wikibk.{index}.txt"
        all_sents = []
        with open(file_path, "r", encoding="utf-8") as fp:
            for line in fp:
                all_sents.append(line.strip())

        print(file_path)
        print("Number of sentences: {}".format(len(all_sents)))

        bs = [path2rest(line) for line in tqdm(all_sents)]

        dataframe = pd.DataFrame(bs, columns=["image", "caption", "source", "split"],)
        table = pa.Table.from_pandas(dataframe)

        os.makedirs(dataset_root, exist_ok=True)
        with pa.OSFile(f"{dataset_root}/wikibk_train_{index}.arrow", "wb") as sink:
            with pa.RecordBatchFileWriter(sink, table.schema) as writer:
                writer.write_table(table)

        del dataframe
        del table
        del bs
        gc.collect()
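

# Usage sketch (illustrative, not part of this commit; the root is assumed to hold
# the text shards wikibk.0.txt through wikibk.49.txt):
if __name__ == "__main__":
    make_arrow(root="/path/to/wikibk_text", dataset_root="/path/to/arrow_output")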