close all;clear all;
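% Evaluate SIDD denoising quality: compare the predicted blocks in
% Idenoised.mat against the ground-truth blocks in ValidationGtBlocksSrgb.mat
% and report the mean PSNR/SSIM over all 40 images x 32 patches.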
denoised = load('/MATLAB Drive/painter/sidd/Idenoised.mat');
gt = load('/MATLAB Drive/painter/sidd/ValidationGtBlocksSrgb.mat');
denoised = denoised.Idenoised;
gt = gt.ValidationGtBlocksSrgb;
gt = im2single(gt);
total_psnr = 0;
total_ssim = 0;
for i = 1:40
for k = 1:32
denoised_patch = squeeze(denoised(i,k,:,:,:));
gt_patch = squeeze(gt(i,k,:,:,:));
ssim_val = ssim(denoised_patch, gt_patch);
psnr_val = psnr(denoised_patch, gt_patch);
total_ssim = total_ssim + ssim_val;
total_psnr = total_psnr + psnr_val;
end
end
qm_psnr = total_psnr / (40*32);
qm_ssim = total_ssim / (40*32);
fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------
import sys
import os
import warnings
import requests
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import glob
import tqdm
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import scipy.io as sio
sys.path.append('.')
import models_painter
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def get_args_parser():
parser = argparse.ArgumentParser('SIDD denoising', add_help=False)
parser.add_argument('--ckpt_path', type=str, help='path to ckpt',
default='')
parser.add_argument('--model', type=str, help='dir to ckpt',
default='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1')
parser.add_argument('--prompt', type=str, help='prompt image in train set',
default='9_9')
parser.add_argument('--input_size', type=int, default=448)
parser.add_argument('--save', action='store_true', help='save predictions',
default=False)
return parser.parse_args()
def prepare_model(chkpt_dir, arch='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1'):
# build model
model = getattr(models_painter, arch)()
# load model
checkpoint = torch.load(chkpt_dir, map_location='cuda:0')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)
return model
def run_one_image(img, tgt, size, model, out_path, device):
x = torch.tensor(img)
x = x.unsqueeze(dim=0)
x = torch.einsum('nhwc->nchw', x)
tgt = torch.tensor(tgt)
tgt = tgt.unsqueeze(dim=0)
tgt = torch.einsum('nhwc->nchw', tgt)
bool_masked_pos = torch.zeros(model.patch_embed.num_patches)
bool_masked_pos[model.patch_embed.num_patches//2:] = 1
bool_masked_pos = bool_masked_pos.unsqueeze(dim=0)
valid = torch.ones_like(tgt)
loss, y, mask = model(x.float().to(device), tgt.float().to(device), bool_masked_pos.to(device), valid.float().to(device))
y = model.unpatchify(y)
y = torch.einsum('nchw->nhwc', y).detach().cpu()
output = y[0, y.shape[1]//2:, :, :]
output = output * imagenet_std + imagenet_mean
output = F.interpolate(
output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0]
return output.numpy()
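# run_one_image feeds the model a vertically stacked [prompt | query] pair:
# bool_masked_pos marks the second half of the patches as masked, so the model
# reconstructs the denoised query in the bottom half of the canvas. Only that
# bottom half is kept, de-normalized with the ImageNet statistics, and resized
# back to the original patch size.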
if __name__ == '__main__':
args = get_args_parser()
ckpt_path = args.ckpt_path
model = args.model
prompt = args.prompt
input_size = args.input_size
path_splits = ckpt_path.split('/')
ckpt_dir, ckpt_file = path_splits[-2], path_splits[-1]
dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1],
"sidd_inference_{}_{}".format(ckpt_file, os.path.basename(prompt).split(".")[0]))
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
print("output_dir: {}".format(dst_dir))
model_painter = prepare_model(ckpt_path, model)
print('Model loaded.')
device = torch.device("cuda")
model_painter.to(device)
img_src_dir = "datasets/denoise/val/"
filepath = os.path.join(img_src_dir, 'ValidationNoisyBlocksSrgb.mat')
img = sio.loadmat(filepath)
Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb'])) # (40, 32, 256, 256, 3)
Inoisy /= 255.
img2_path = "datasets/denoise/train/input/{}.png".format(prompt)
tgt2_path = "datasets/denoise/train/groundtruth/{}.png".format(prompt)
# load the shared prompt image pair
img2 = Image.open(img2_path).convert("RGB")
img2 = img2.resize((input_size, input_size))
img2 = np.array(img2) / 255.
tgt2 = Image.open(tgt2_path)
tgt2 = tgt2.resize((input_size, input_size))
tgt2 = np.array(tgt2) / 255.
model_painter.eval()
restored = np.zeros_like(Inoisy)
for img_idx in tqdm.tqdm(range(40)):
for patch_idx in range(32):
""" Load an image """
img_org = Inoisy[img_idx, patch_idx, :, :, :]
img = cv2.resize(img_org, (input_size, input_size))
# img = img_org.resize((input_size, input_size))
img = np.concatenate((img2, img), axis=0)
assert img.shape == (input_size * 2, input_size, 3)
# normalize by ImageNet mean and std
img = img - imagenet_mean
img = img / imagenet_std
tgt = tgt2 # tgt is not available
tgt = np.concatenate((tgt2, tgt), axis=0)
assert tgt.shape == (input_size * 2, input_size, 3)
# normalize by ImageNet mean and std
tgt = tgt - imagenet_mean
tgt = tgt / imagenet_std
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
output = run_one_image(img, tgt, size=(256, 256), model=model_painter, out_path=None, device=device)
rgb_restored = output
rgb_restored = np.clip(rgb_restored, 0, 1)
restored[img_idx, patch_idx, :, :, :] = rgb_restored
# optionally save images
if args.save:
out_path = os.path.join(dst_dir, '%04d_%02d.png' % (img_idx + 1, patch_idx + 1))
output = rgb_restored * 255
output = Image.fromarray(output.astype(np.uint8))
output.save(out_path)
# save denoised data
sio.savemat(os.path.join(dst_dir, 'Idenoised.mat'), {"Idenoised": restored, })
print(os.path.join(dst_dir, 'Idenoised.mat'))
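# Illustrative invocation (the script filename and checkpoint path are
# assumptions, not part of this commit):
#   python painter_inference_sidd.py \
#       --ckpt_path path/to/painter_vit_large.pth --prompt 9_9 --save
# The saved Idenoised.mat can then be scored with the MATLAB snippet above.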
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import timm
assert timm.__version__ == "0.3.2" # version check
import util.lr_decay as lrd
import util.misc as misc
from util.misc import get_parameter_groups
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.pos_embed import interpolate_pos_embed
import models_painter
from engine_train import train_one_epoch, evaluate_pt
from data.pairdataset import PairDataset
import data.pair_transforms as pair_transforms
from util.masking_generator import MaskingGenerator
from data.sampler import DistributedSamplerWrapper
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
def get_args_parser():
parser = argparse.ArgumentParser('Painter pre-training', add_help=False)
parser.add_argument('--batch_size', default=2, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
parser.add_argument('--epochs', default=15, type=int)
parser.add_argument('--accum_iter', default=16, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=224, type=int, nargs='+',
help='images input size')
parser.add_argument('--mask_ratio', default=0.5, type=float,
help='Masking ratio (percentage of removed patches).')
parser.add_argument('--norm_pix_loss', action='store_true',
help='Use (per-patch) normalized pixels as targets for computing loss')
parser.set_defaults(norm_pix_loss=False)
parser.add_argument('--num_mask_patches', default=784, type=int,
help='number of visual tokens/patches to be masked')
parser.add_argument('--max_mask_patches_per_block', type=int, default=None)
parser.add_argument('--min_mask_patches_per_block', type=int, default=16)
parser.add_argument('--stop_grad_patch_embed', action='store_true',
help='stop-grad after first conv, or patch embedding')
parser.set_defaults(stop_grad_patch_embed=False)
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
parser.add_argument('--drop_path', default=0., type=float,
help='Drop path rate (default: 0.)')
parser.add_argument('--min_random_scale', default=0.3, type=float,
help='Minimal random scale for randomresizecrop (default: 0.3)')
parser.add_argument('--last_norm_instance', action='store_true', default=False,
help='use instance norm to normalize each channel map before the decoder layer')
parser.add_argument('--half_mask_ratio', default=0.1, type=float,
help='ratio of using half mask during training (default: 0.1)')
parser.add_argument('--use_checkpoint', action='store_true', default=False,
help='use checkpoint to save GPU memory')
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0.1,
help='weight decay (default: 0.1)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
help='epochs to warmup LR')
parser.add_argument('--save_freq', type=int, default=100,
help='frequency (in epochs) at which to save checkpoints')
parser.add_argument('--clip_grad', type=float, default=3.0, metavar='NORM',
help='Clip gradient norm (default: 3.0)')
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: 0.9 0.999)')
parser.add_argument('--layer_decay', type=float, default=1.0, metavar='LRD',
help='Learning rate layer decay')
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
#parser.add_argument('--json_path', default='./', type=str,
parser.add_argument('--json_path', default='./', nargs='+', type=str,
help='json path')
parser.add_argument('--val_json_path', default='./', nargs='+', type=str,
help='validation json path')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--auto_resume', action='store_true')
parser.set_defaults(auto_resume=False)
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
parser.add_argument('--use_two_pairs', action='store_true',
help='concatenate two pairs of images')
parser.set_defaults(use_two_pairs=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
parser.add_argument('--enable_deepspeed',
action='store_true', default=False)
parser.add_argument('--zero_stage', default=0, type=int,
help='ZeRO optimizer stage (default: 0)')
# misc
parser.add_argument('--log_wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
known_args, _ = parser.parse_known_args()
if known_args.enable_deepspeed:
try:
import deepspeed
from deepspeed import DeepSpeedConfig
parser = deepspeed.add_config_arguments(parser)
ds_init = deepspeed.initialize
except:
print("Please 'pip install deepspeed==0.4.0'")
exit(0)
else:
ds_init = None
return parser.parse_args(), ds_init
def main(args, ds_init):
misc.init_distributed_mode(args)
if ds_init is not None:
misc.create_ds_config(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# define the model
model = models_painter.__dict__[args.model]()
if args.finetune:
checkpoint = torch.load(args.finetune, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
rm_key_list = ['decoder_embed.weight', 'decoder_embed.bias', 'mask_token']
if args.last_norm_instance:
rm_key_list.extend(['norm.weight', 'norm.bias'])
for k in rm_key_list:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate patch embedding
if "patch32" in args.model:
patch_weight = checkpoint['model']['patch_embed.proj.weight']
new_patch_weight = torch.nn.functional.interpolate(patch_weight, size=(32, 32), mode='bicubic', align_corners=False)
checkpoint['model']['patch_embed.proj.weight'] = new_patch_weight
# interpolate position embedding
if "painter" not in args.model:
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
patch_size = model.patch_size
print("Patch size = %s" % str(patch_size))
args.window_size = (args.input_size[0] // patch_size, args.input_size[1] // patch_size)
args.patch_size = patch_size
# simple augmentation
transform_train = pair_transforms.Compose([
pair_transforms.RandomResizedCrop(args.input_size[1], scale=(args.min_random_scale, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.RandomApply([
pair_transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)
], p=0.8),
pair_transforms.RandomHorizontalFlip(),
pair_transforms.ToTensor(),
pair_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
transform_train2 = pair_transforms.Compose([
pair_transforms.RandomResizedCrop(args.input_size[1], scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.ToTensor(),
pair_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
transform_train3 = pair_transforms.Compose([
pair_transforms.RandomResizedCrop(args.input_size[1], scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.ToTensor(),
pair_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
transform_train_seccrop = pair_transforms.Compose([
pair_transforms.RandomResizedCrop(args.input_size, scale=(args.min_random_scale, 1.0), ratio=(0.3, 0.7), interpolation=3), # 3 is bicubic
])
transform_val = pair_transforms.Compose([
pair_transforms.RandomResizedCrop(args.input_size[1], scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic
pair_transforms.ToTensor(),
pair_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
masked_position_generator = MaskingGenerator(
args.window_size, num_masking_patches=args.num_mask_patches,
max_num_patches=args.max_mask_patches_per_block,
min_num_patches=args.min_mask_patches_per_block,
)
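# Note: with the 896x448 input used by the default Painter model and patch
# size 16, the window is (896/16) x (448/16) = 56 x 28 = 1568 patches, so the
# default num_mask_patches=784 masks exactly half of them (mask_ratio = 0.5).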
dataset_train = PairDataset(args.data_path, args.json_path, transform=transform_train, transform2=transform_train2, transform3=transform_train3, transform_seccrop=transform_train_seccrop, masked_position_generator=masked_position_generator, use_two_pairs=args.use_two_pairs, half_mask_ratio=args.half_mask_ratio)
dataset_val = PairDataset(args.data_path, args.val_json_path, transform=transform_val, transform2=None, transform3=None, masked_position_generator=masked_position_generator, use_two_pairs=args.use_two_pairs, half_mask_ratio=1.0)
print(dataset_train)
print(dataset_val)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
num_samples_train = len(dataset_train)
weights_train = dataset_train.weights
sampler_train = torch.utils.data.WeightedRandomSampler(weights_train, num_samples_train, replacement=True)
sampler_train = DistributedSamplerWrapper(sampler_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
print("Sampler_train = %s" % str(sampler_train))
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
if global_rank == 0 and args.log_wandb:
experiment = args.log_dir.split('/')[-2]
if args.resume == '':
wandb.init(project="Painter", name=experiment, config=args)
else:
wandb.init(project="Painter", name=experiment, config=args, resume=True)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=int(1.5 * args.batch_size),
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False,
)
model.to(device)
model_without_ddp = model
print("Model = %s" % str(model_without_ddp))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.enable_deepspeed:
loss_scaler = None
optimizer_params = get_parameter_groups(
model, args.weight_decay, model.no_weight_decay()
)
model, optimizer, _, _ = ds_init(
args=args, model=model, model_parameters=optimizer_params,
dist_init_required=not args.distributed,
)
print("model.gradient_accumulation_steps() = %d" %
model.gradient_accumulation_steps())
assert model.gradient_accumulation_steps() == args.accum_iter
else:
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
# following timm: set wd as 0 for bias and norm layers
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=model_without_ddp.no_weight_decay(),
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=args.opt_betas)
print(optimizer)
loss_scaler = NativeScaler()
misc.auto_load_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
global_rank=global_rank,
args=args
)
if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs):
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
test_stats = evaluate_pt(data_loader_val, model, device, epoch=epoch, global_rank=global_rank, args=args)
print(f"Val loss of the network on the {len(dataset_val)} test images: {test_stats['loss']:.3f}")
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if global_rank == 0 and args.log_wandb:
wandb.finish()
if __name__ == '__main__':
args, ds_init = get_args_parser()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, ds_init)
# Unique model identifier
modelCode=515
# Model name
modelName=painter_pytorch
# Model description
modelDescription=Discretizes the continuous output space of vision tasks and uses language or specially designed discrete tokens as task prompts, turning vision problems into NLP problems.
# Application scenarios
appScenario=inference, training, image super-resolution, image segmentation, transportation, healthcare, government, manufacturing
# Framework type
frameType=PyTorch
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
##########################
import fvcore.nn.weight_init as weight_init
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from util.vitdet_utils import (
PatchEmbed,
add_decomposed_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
LayerNorm2D,
)
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
if not rel_pos_zero_init:
trunc_normal_(self.rel_pos_h, std=0.02)
trunc_normal_(self.rel_pos_w, std=0.02)
def forward(self, x):
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
class ResBottleneckBlock(CNNBlockBase):
"""
The standard bottleneck residual block without the last activation layer.
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
"""
def __init__(
self,
in_channels,
out_channels,
bottleneck_channels,
norm="LN",
act_layer=nn.GELU,
):
"""
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
bottleneck_channels (int): number of output channels for the 3x3
"bottleneck" conv layers.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
act_layer (callable): activation for all conv layers.
"""
super().__init__(in_channels, out_channels, 1)
self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
self.norm1 = get_norm(norm, bottleneck_channels)
self.act1 = act_layer()
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
3,
padding=1,
bias=False,
)
self.norm2 = get_norm(norm, bottleneck_channels)
self.act2 = act_layer()
self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
self.norm3 = get_norm(norm, out_channels)
for layer in [self.conv1, self.conv2, self.conv3]:
weight_init.c2_msra_fill(layer)
for layer in [self.norm1, self.norm2]:
layer.weight.data.fill_(1.0)
layer.bias.data.zero_()
# zero init last norm layer.
self.norm3.weight.data.zero_()
self.norm3.bias.data.zero_()
def forward(self, x):
out = x
for layer in self.children():
out = layer(out)
out = x + out
return out
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
use_residual_block=False,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, window
attention is not used.
use_residual_block (bool): If True, use a residual block after the MLP block.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
self.window_size = window_size
self.use_residual_block = use_residual_block
if use_residual_block:
# Use a residual block with bottleneck channel as dim // 2
self.residual = ResBottleneckBlock(
in_channels=dim,
out_channels=dim,
bottleneck_channels=dim // 2,
norm="LN",
act_layer=act_layer,
)
def forward(self, x):
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
if self.use_residual_block:
x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
return x
class Painter(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_abs_pos=True,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
window_block_indexes=(),
residual_block_indexes=(),
use_act_checkpoint=False,
pretrain_img_size=224,
pretrain_use_cls_token=True,
out_feature="last_feat",
decoder_embed_dim=128,
loss_func="smoothl1",
):
super().__init__()
# --------------------------------------------------------------------------
self.pretrain_use_cls_token = pretrain_use_cls_token
self.patch_size = patch_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.patch_embed.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
self.segment_token_x = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
self.segment_token_y = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim), requires_grad=True)
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i in window_block_indexes else 0,
use_residual_block=i in residual_block_indexes,
input_size=(img_size[0] // patch_size, img_size[1] // patch_size),
)
if use_act_checkpoint:
block = checkpoint_wrapper(block)
self.blocks.append(block)
self._out_feature_channels = {out_feature: embed_dim}
self._out_feature_strides = {out_feature: patch_size}
self._out_features = [out_feature]
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
self.decoder_embed_dim = decoder_embed_dim
self.decoder_embed = nn.Linear(embed_dim*4, patch_size ** 2 * self.decoder_embed_dim, bias=True)
self.decoder_pred = nn.Sequential(
nn.Conv2d(self.decoder_embed_dim, self.decoder_embed_dim, kernel_size=3, padding=1, ),
LayerNorm2D(self.decoder_embed_dim),
nn.GELU(),
nn.Conv2d(self.decoder_embed_dim, 3, kernel_size=1, bias=True),
)
# --------------------------------------------------------------------------
self.loss_func = loss_func
torch.nn.init.normal_(self.mask_token, std=.02)
torch.nn.init.normal_(self.segment_token_x, std=.02)
torch.nn.init.normal_(self.segment_token_y, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_size
assert imgs.shape[2] == 2 * imgs.shape[3] and imgs.shape[2] % p == 0
w = imgs.shape[3] // p
h = w * 2
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_size
w = int((x.shape[1]*0.5)**.5)
h = w * 2
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
return imgs
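# Shape round trip, assuming the default 896x448 input and patch_size=16:
# imgs (N, 3, 896, 448) -> patchify -> x (N, 1568, 768) with h=56, w=28 and
# 768 = 16*16*3; unpatchify maps x back to (N, 3, 896, 448).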
def forward_encoder(self, imgs, tgts, bool_masked_pos):
# embed patches
x = self.patch_embed(imgs)
y = self.patch_embed(tgts)
batch_size, Hp, Wp, _ = x.size()
seq_len = Hp * Wp
mask_token = self.mask_token.expand(batch_size, Hp, Wp, -1)
# replace the masked visual tokens by mask_token
w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, Hp, Wp, 1)
y = y * (1 - w) + mask_token * w
# add pos embed w/o cls token
x = x + self.segment_token_x
y = y + self.segment_token_y
if self.pos_embed is not None:
x = x + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
)
y = y + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, (y.shape[1], y.shape[2])
)
merge_idx = 2
x = torch.cat((x, y), dim=0)
# apply Transformer blocks
out = []
for idx, blk in enumerate(self.blocks):
x = blk(x)
if idx == merge_idx:
x = (x[:x.shape[0]//2] + x[x.shape[0]//2:]) * 0.5
if idx in [5, 11, 17, 23]:
out.append(self.norm(x))
return out
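# The stacked input/target batch is averaged into a single stream after block
# merge_idx (= 2); the normalized features of blocks 5, 11, 17 and 23 are
# collected and later concatenated channel-wise in forward_decoder, which is
# why decoder_embed expects embed_dim * 4 input channels.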
def forward_decoder(self, x):
# predictor projection
x = torch.cat(x, dim=-1)
x = self.decoder_embed(x)
p = self.patch_size
h, w = x.shape[1], x.shape[2]
x = x.reshape(shape=(x.shape[0], h, w, p, p, self.decoder_embed_dim))
x = torch.einsum('nhwpqc->nchpwq', x)
x = x.reshape(shape=(x.shape[0], -1, h * p, w * p))
x = self.decoder_pred(x) # Bx3xHxW
return x
def forward_loss(self, pred, tgts, mask, valid):
"""
tgts: [N, 3, H, W]
pred: [N, 3, H, W]
mask: [N, L], 0 is keep, 1 is remove,
valid: [N, 3, H, W]
"""
mask = mask[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
mask = self.unpatchify(mask)
# ignore if the unmasked pixels are all zeros
imagenet_mean=torch.tensor([0.485, 0.456, 0.406]).to(tgts.device)[None, :, None, None]
imagenet_std=torch.tensor([0.229, 0.224, 0.225]).to(tgts.device)[None, :, None, None]
inds_ign = ((tgts * imagenet_std + imagenet_mean) * (1 - 1.*mask)).sum((1, 2, 3)) < 100*3
if inds_ign.sum() > 0:
valid[inds_ign] = 0.
mask = mask * valid
target = tgts
if self.loss_func == "l1l2":
loss = ((pred - target).abs() + (pred - target) ** 2.) * 0.5
elif self.loss_func == "l1":
loss = (pred - target).abs()
elif self.loss_func == "l2":
loss = (pred - target) ** 2.
elif self.loss_func == "smoothl1":
loss = F.smooth_l1_loss(pred, target, reduction="none", beta=0.01)
loss = (loss * mask).sum() / (mask.sum() + 1e-2) # mean loss on removed patches
return loss
def forward(self, imgs, tgts, bool_masked_pos=None, valid=None):
if bool_masked_pos is None:
bool_masked_pos = torch.zeros((imgs.shape[0], self.patch_embed.num_patches), dtype=torch.bool).to(imgs.device)
else:
bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
latent = self.forward_encoder(imgs, tgts, bool_masked_pos)
pred = self.forward_decoder(latent) # [N, L, p*p*3]
loss = self.forward_loss(pred, tgts, bool_masked_pos, valid)
return loss, self.patchify(pred), bool_masked_pos
def painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1(**kwargs):
model = Painter(
img_size=(896, 448), patch_size=16, embed_dim=1024, depth=24, num_heads=16,
drop_path_rate=0.1, window_size=14, qkv_bias=True,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
window_block_indexes=(list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + \
list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23))),
residual_block_indexes=[], use_rel_pos=True, out_feature="last_feat",
decoder_embed_dim=64,
loss_func="smoothl1",
**kwargs)
return model
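# Minimal usage sketch (dummy tensors; for illustration only):
#   model = painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1()
#   imgs = tgts = torch.zeros(1, 3, 896, 448)
#   mask = torch.zeros(1, model.patch_embed.num_patches)   # 0 = keep, 1 = mask
#   mask[:, model.patch_embed.num_patches // 2:] = 1       # mask the bottom half
#   loss, pred_patches, mask = model(imgs, tgts, mask, torch.ones_like(tgts))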
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
"""
Calculate lr decay rate for different ViT blocks.
Args:
name (string): parameter name.
lr_decay_rate (float): base lr decay rate.
num_layers (int): number of ViT blocks.
Returns:
lr decay rate for the given parameter.
"""
layer_id = num_layers + 1
if name.startswith("backbone"):
if ".pos_embed" in name or ".patch_embed" in name:
layer_id = 0
elif ".blocks." in name and ".residual." not in name:
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
return lr_decay_rate ** (num_layers + 1 - layer_id)
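# Example: for "backbone.blocks.10.attn.qkv.weight" with num_layers=24,
# layer_id = 10 + 1 = 11 and the returned rate is lr_decay_rate ** (25 - 11)
# = lr_decay_rate ** 14; pos/patch embeddings (layer_id = 0) receive the
# smallest multiplier, lr_decay_rate ** 25.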
timm==0.3.2
git+https://github.com/cocodataset/panopticapi.git
h5py # for depth
xtcocotools # for pose
natsort # for denoising
wandb
scikit-image==0.18.0
git+https://github.com/svenkreiss/poseval.git
tensorboard
fvcore==0.1.5
yapf==0.40.1
fairscale==0.4.13
[{"image_path": "ade20k/images/training/ADE_train_00014165.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00014165.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00017885.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00017885.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00010565.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00010565.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00010441.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00010441.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00007380.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00007380.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00013301.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00013301.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00009464.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00009464.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00016499.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00016499.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00001098.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00001098.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/training/ADE_train_00003821.jpg", "target_path": "ade20k/annotations_with_color/training/ADE_train_00003821.png", "type": "ade20k_image2semantic"}]
[{"image_path": "ade20k/images/validation/ADE_val_00001256.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001256.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00000077.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00000077.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001364.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001364.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001556.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001556.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001014.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001014.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00000011.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00000011.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001680.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001680.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001952.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001952.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00001152.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00001152.png", "type": "ade20k_image2semantic"}, {"image_path": "ade20k/images/validation/ADE_val_00000998.jpg", "target_path": "ade20k/annotations_with_color/validation/ADE_val_00000998.png", "type": "ade20k_image2semantic"}]