"src/vscode:/vscode.git/clone" did not exist on "9f10306b3fd8168a100e749716e99b75b769e3ef"
Commit c2f37e29 authored by yongshk's avatar yongshk
Browse files

Initial commit

parents
import os
import sys
import time
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import config
import myutils
from loss import Loss
from torch.utils.data import DataLoader
def load_checkpoint(args, model, optimizer, path):
print("loading checkpoint %s" % path)
checkpoint = torch.load(path)
args.start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = checkpoint.get("lr" , args.lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
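# Note: load_checkpoint is not invoked anywhere below; a resume path would
# presumably be wired up along these lines (a hypothetical `args.resume`
# checkpoint path is assumed here, shown only as a sketch):
#
#   if getattr(args, "resume", None):
#       load_checkpoint(args, model, optimizer, args.resume)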
##### Parse CmdLine Arguments #####
args, unparsed = config.get_args()
cwd = os.getcwd()
print(args)
save_loc = os.path.join(args.checkpoint_dir , "saved_models_final" , args.dataset , args.exp_name)
if not os.path.exists(save_loc):
os.makedirs(save_loc)
opts_file = os.path.join(save_loc , "opts.txt")
with open(opts_file , "w") as fh:
fh.write(str(args))
##### TensorBoard & Misc Setup #####
writer_loc = os.path.join(args.checkpoint_dir , 'tensorboard_logs_%s_final/%s' % (args.dataset , args.exp_name))
writer = SummaryWriter(writer_loc)
device = torch.device('cuda' if args.cuda else 'cpu')
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.random_seed)
if args.cuda:
torch.cuda.manual_seed(args.random_seed)
if args.dataset == "vimeo90K_septuplet":
from dataset.vimeo90k_septuplet import get_loader
train_loader = get_loader('train', args.data_root, args.batch_size, shuffle=True, num_workers=args.num_workers)
test_loader = get_loader('test', args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "gopro":
from dataset.GoPro import get_loader
train_loader = get_loader(args.data_root, args.batch_size, shuffle=True, num_workers=args.num_workers, test_mode=False, interFrames=args.n_outputs, n_inputs=args.nbr_frame)
test_loader = get_loader(args.data_root, args.batch_size, shuffle=False, num_workers=args.num_workers, test_mode=True, interFrames=args.n_outputs, n_inputs=args.nbr_frame)
else:
raise NotImplementedError
from model.FLAVR_arch import UNet_3D_3D
print("Building model: %s"%args.model.lower())
model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType, upmode=args.upmode)
model = torch.nn.DataParallel(model).to(device)
##### Define Loss & Optimizer #####
criterion = Loss(args)
## ToDo: Different learning rate schemes for different parameters
from torch.optim import Adam
optimizer = Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
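# A minimal sketch of the per-parameter learning-rate idea from the ToDo above
# (hypothetical split: a lower rate for the encoder than for the decoder);
# illustrative only, not used below:
#
#   encoder_params = [p for n, p in model.named_parameters() if "encoder" in n]
#   decoder_params = [p for n, p in model.named_parameters() if "encoder" not in n]
#   optimizer = Adam([
#       {"params": encoder_params, "lr": args.lr * 0.1},
#       {"params": decoder_params, "lr": args.lr},
#   ], betas=(args.beta1, args.beta2))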
def train(args, epoch):
losses, psnrs, ssims = myutils.init_meters(args.loss)
model.train()
criterion.train()
t = time.time()
for i, (images, gt_image) in enumerate(train_loader):
# Build input batch
images = [img_.cuda() for img_ in images]
gt = [gt_.cuda() for gt_ in gt_image]
# Forward
optimizer.zero_grad()
out = model(images)
out = torch.cat(out)
gt = torch.cat(gt)
loss, loss_specific = criterion(out, gt)
# Save loss values
for k, v in losses.items():
if k != 'total':
v.update(loss_specific[k].item())
losses['total'].update(loss.item())
loss.backward()
optimizer.step()
# Calc metrics & print logs
if i % args.log_iter == 0:
myutils.eval_metrics(out, gt, psnrs, ssims)
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}'.format(
epoch, i, len(train_loader), losses['total'].avg, psnrs.avg), flush=True)
# Log to TensorBoard
timestep = epoch * len(train_loader) + i
writer.add_scalar('Loss/train', loss.data.item(), timestep)
writer.add_scalar('PSNR/train', psnrs.avg, timestep)
writer.add_scalar('SSIM/train', ssims.avg, timestep)
writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], timestep)
# Reset metrics
losses, psnrs, ssims = myutils.init_meters(args.loss)
t = time.time()
def test(args, epoch):
print('Evaluating for epoch = %d' % epoch)
losses, psnrs, ssims = myutils.init_meters(args.loss)
model.eval()
criterion.eval()
t = time.time()
with torch.no_grad():
for i, (images, gt_image) in enumerate(tqdm(test_loader)):
images = [img_.cuda() for img_ in images]
gt = [gt_.cuda() for gt_ in gt_image]
out = model(images) ## images is a list of neighboring frames
out = torch.cat(out)
gt = torch.cat(gt)
# Save loss values
loss, loss_specific = criterion(out, gt)
for k, v in losses.items():
if k != 'total':
v.update(loss_specific[k].item())
losses['total'].update(loss.item())
# Evaluate metrics
myutils.eval_metrics(out, gt, psnrs, ssims)
# Print progress
print("Loss: %f, PSNR: %f, SSIM: %f\n" %
(losses['total'].avg, psnrs.avg, ssims.avg))
# Save psnr & ssim
save_fn = os.path.join(save_loc, 'results.txt')
with open(save_fn, 'a') as f:
f.write('For epoch=%d\t' % epoch)
f.write("PSNR: %f, SSIM: %f\n" %
(psnrs.avg, ssims.avg))
# Log to TensorBoard
timestep = epoch + 1
writer.add_scalar('Loss/test', loss.data.item(), timestep)
writer.add_scalar('PSNR/test', psnrs.avg, timestep)
writer.add_scalar('SSIM/test', ssims.avg, timestep)
return losses['total'].avg, psnrs.avg, ssims.avg
""" Entry Point """
def main(args):
if args.pretrained:
## For low data, it is better to load from a supervised pretrained model
loadStateDict = torch.load(args.pretrained)['state_dict']
modelStateDict = model.state_dict()
for k, v in loadStateDict.items():
if k in modelStateDict and v.shape == modelStateDict[k].shape:
print("Loading", k)
modelStateDict[k] = v
else:
print("Not loading", k)
model.load_state_dict(modelStateDict)
best_psnr = 0
for epoch in range(args.start_epoch, args.max_epoch):
train(args, epoch)
test_loss, psnr, _ = test(args, epoch)
# save checkpoint
is_best = psnr > best_psnr
best_psnr = max(psnr, best_psnr)
myutils.save_checkpoint({
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'best_psnr': best_psnr,
'lr' : optimizer.param_groups[-1]['lr']
}, save_loc, is_best, args.exp_name)
# update optimizer policy
scheduler.step(test_loss)
if __name__ == "__main__":
main(args)
import math
import numpy as np
import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet_3D import SEGating
def joinTensors(X1 , X2 , type="concat"):
if type == "concat":
return torch.cat([X1 , X2] , dim=1)
elif type == "add":
return X1 + X2
else:
return X1
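# Example: for X1 and X2 of shape (B, C, T, H, W), "concat" returns
# (B, 2C, T, H, W), "add" returns (B, C, T, H, W), and any other value falls
# back to returning X1 unchanged; this is why the `growth` factor in
# UNet_3D_3D below is 2 only for the "concat" join.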
class Conv_2d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
super().__init__()
self.conv = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
if batchnorm:
self.conv += [nn.BatchNorm2d(out_ch)]
self.conv = nn.Sequential(*self.conv)
def forward(self, x):
return self.conv(x)
class upConv3D(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose" , batchnorm=False):
super().__init__()
self.upmode = upmode
if self.upmode=="transpose":
self.upconv = nn.ModuleList(
[nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
SEGating(out_ch)
]
)
else:
self.upconv = nn.ModuleList(
[nn.Upsample(mode='trilinear', scale_factor=(1,2,2), align_corners=False),
nn.Conv3d(in_ch, out_ch , kernel_size=1 , stride=1),
SEGating(out_ch)
]
)
if batchnorm:
self.upconv += [nn.BatchNorm3d(out_ch)]
self.upconv = nn.Sequential(*self.upconv)
def forward(self, x):
return self.upconv(x)
class Conv_3d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
super().__init__()
self.conv = [nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
SEGating(out_ch)
]
if batchnorm:
self.conv += [nn.BatchNorm3d(out_ch)]
self.conv = nn.Sequential(*self.conv)
def forward(self, x):
return self.conv(x)
class upConv2D(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose" , batchnorm=False):
super().__init__()
self.upmode = upmode
if self.upmode=="transpose":
self.upconv = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding)]
else:
self.upconv = [
nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
nn.Conv2d(in_ch, out_ch , kernel_size=1 , stride=1)
]
if batchnorm:
self.upconv += [nn.BatchNorm2d(out_ch)]
self.upconv = nn.Sequential(*self.upconv)
def forward(self, x):
return self.upconv(x)
class UNet_3D_3D(nn.Module):
def __init__(self, block , n_inputs, n_outputs, batchnorm=False , joinType="concat" , upmode="transpose"):
super().__init__()
nf = [512 , 256 , 128 , 64]
out_channels = 3*n_outputs
self.joinType = joinType
self.n_outputs = n_outputs
growth = 2 if joinType == "concat" else 1
self.lrelu = nn.LeakyReLU(0.2, True)
unet_3D = importlib.import_module(".resnet_3D" , "model")
if n_outputs > 1:
unet_3D.useBias = True
self.encoder = getattr(unet_3D , block)(pretrained=False , bn=batchnorm)
self.decoder = nn.Sequential(
Conv_3d(nf[0], nf[1] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
Conv_3d(nf[3]*growth, nf[3] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
upConv3D(nf[3]*growth , nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm)
)
self.feature_fuse = Conv_2d(nf[3]*n_inputs , nf[3] , kernel_size=1 , stride=1, batchnorm=batchnorm)
self.outconv = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(nf[3], out_channels , kernel_size=7 , stride=1, padding=0)
)
def forward(self, images):
images = torch.stack(images , dim=2)
## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4,keepdim=True)
images = images-mean_
x_0 , x_1 , x_2 , x_3 , x_4 = self.encoder(images)
dx_3 = self.lrelu(self.decoder[0](x_4))
dx_3 = joinTensors(dx_3 , x_3 , type=self.joinType)
dx_2 = self.lrelu(self.decoder[1](dx_3))
dx_2 = joinTensors(dx_2 , x_2 , type=self.joinType)
dx_1 = self.lrelu(self.decoder[2](dx_2))
dx_1 = joinTensors(dx_1 , x_1 , type=self.joinType)
dx_0 = self.lrelu(self.decoder[3](dx_1))
dx_0 = joinTensors(dx_0 , x_0 , type=self.joinType)
dx_out = self.lrelu(self.decoder[4](dx_0))
dx_out = torch.cat(torch.unbind(dx_out , 2) , 1) # fold the temporal axis into channels: (B, C, T, H, W) -> (B, C*T, H, W)
out = self.lrelu(self.feature_fuse(dx_out))
out = self.outconv(out)
out = torch.split(out, dim=1, split_size_or_sections=3)
mean_ = mean_.squeeze(2)
out = [o+mean_ for o in out]
return out
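# A minimal smoke test for the architecture above (the shapes are assumptions:
# four 256x256 input frames interpolating a single output frame; "unet_18"
# names the encoder defined in resnet_3D.py):
if __name__ == "__main__":
    net = UNet_3D_3D("unet_18", n_inputs=4, n_outputs=1)
    frames = [torch.randn(1, 3, 256, 256) for _ in range(4)]
    outs = net(frames)    # list with n_outputs tensors
    print(outs[0].shape)  # expected: torch.Size([1, 3, 256, 256])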
# Modified from https://github.com/pytorch/vision/tree/master/torchvision/models/video
import torch
import torch.nn as nn
__all__ = ['unet_18', 'unet_34']
useBias = False
# Note: `batchnorm` is a module-level global assigned by unet_18/unet_34 below
# (nn.BatchNorm3d or identity) before any of these modules are constructed.
class identity(nn.Module):
def __init__(self , *args , **kwargs):
super().__init__()
def forward(self , x):
return x
class Conv3DSimple(nn.Conv3d):
def __init__(self,
in_planes,
out_planes,
midplanes=None,
stride=1,
padding=1):
super(Conv3DSimple, self).__init__(
in_channels=in_planes,
out_channels=out_planes,
kernel_size=(3, 3, 3),
stride=stride,
padding=padding,
bias=useBias)
@staticmethod
def get_downsample_stride(stride , temporal_stride):
if temporal_stride:
return (temporal_stride, stride, stride)
else:
return (stride , stride , stride)
class BasicStem(nn.Sequential):
"""The default conv-batchnorm-relu stem
"""
def __init__(self):
super().__init__(
nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
padding=(1, 3, 3), bias=useBias),
batchnorm(64),
nn.ReLU(inplace=False))
class Conv2Plus1D(nn.Sequential):
def __init__(self,
in_planes,
out_planes,
midplanes,
stride=1,
padding=1):
if not isinstance(stride , int):
# get_downsample_stride returns (temporal, spatial, spatial) with equal
# spatial entries, so the double assignment to `stride` below is safe.
temporal_stride , stride , stride = stride
else:
temporal_stride = stride
super(Conv2Plus1D, self).__init__(
nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
stride=(1, stride, stride), padding=(0, padding, padding),
bias=False),
# batchnorm(midplanes),
nn.ReLU(inplace=True),
nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
stride=(temporal_stride, 1, 1), padding=(padding, 0, 0),
bias=False))
@staticmethod
def get_downsample_stride(stride , temporal_stride):
if temporal_stride:
return (temporal_stride, stride, stride)
else:
return (stride , stride , stride)
class R2Plus1dStem(nn.Sequential):
"""R(2+1)D stem is different than the default one as it uses separated 3D convolution
"""
def __init__(self):
super().__init__(
nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
stride=(1, 2, 2), padding=(0, 3, 3),
bias=False),
batchnorm(45),
nn.ReLU(inplace=True),
nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
stride=(1, 1, 1), padding=(1, 0, 0),
bias=False),
batchnorm(64),
nn.ReLU(inplace=True))
class SEGating(nn.Module):
def __init__(self , inplanes , reduction=16):
super().__init__()
self.pool = nn.AdaptiveAvgPool3d(1)
self.attn_layer = nn.Sequential(
nn.Conv3d(inplanes , inplanes , kernel_size=1 , stride=1 , bias=True),
nn.Sigmoid()
)
def forward(self , x):
out = self.pool(x)
y = self.attn_layer(out)
return x * y
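# SEGating is a squeeze-and-excitation style channel gate: features of shape
# (B, C, T, H, W) are global-average-pooled to (B, C, 1, 1, 1), passed through
# a 1x1x1 convolution and a sigmoid, and the result rescales the input
# channel-wise. A tiny sanity check (illustrative only):
#
#   gate = SEGating(8)
#   x = torch.randn(2, 8, 4, 16, 16)
#   assert gate(x).shape == x.shape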
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
super(BasicBlock, self).__init__()
self.conv1 = nn.Sequential(
conv_builder(inplanes, planes, midplanes, stride),
batchnorm(planes),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
conv_builder(planes, planes, midplanes),
batchnorm(planes)
)
self.fg = SEGating(planes) ## Feature Gating
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
out = self.fg(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class VideoResNet(nn.Module):
def __init__(self, block, conv_makers, layers,
stem, zero_init_residual=False):
"""Generic resnet video generator.
Args:
block (nn.Module): resnet building block
conv_makers (list(functions)): generator function for each layer
layers (List[int]): number of blocks per layer
stem (callable): constructor for the ResNet stem (e.g. BasicStem).
"""
super(VideoResNet, self).__init__()
self.inplanes = 64
self.stem = stem()
self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1 )
self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2 , temporal_stride=1)
self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2 , temporal_stride=1)
self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)
# init weights
self._initialize_weights()
# Note: Bottleneck is not defined in this file (only BasicBlock is used here),
# so this torchvision-inherited branch would raise NameError if enabled.
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
def forward(self, x):
x_0 = self.stem(x)
x_1 = self.layer1(x_0)
x_2 = self.layer2(x_1)
x_3 = self.layer3(x_2)
x_4 = self.layer4(x_3)
return x_0 , x_1 , x_2 , x_3 , x_4
def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
ds_stride = conv_builder.get_downsample_stride(stride , temporal_stride)
downsample = nn.Sequential(
nn.Conv3d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=ds_stride, bias=False),
batchnorm(planes * block.expansion)
)
stride = ds_stride
layers = []
layers.append(block(self.inplanes, planes, conv_builder, stride, downsample ))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, conv_builder ))
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out',
nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
model = VideoResNet(**kwargs)
## TODO: Other 3D resnet models, like S3D, r(2+1)D.
if pretrained:
# Follows torchvision's video resnets; note that `model_urls` is not defined
# in this file, so a URL mapping must be supplied for pretrained loading.
from torch.hub import load_state_dict_from_url
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
"""
Construct 18 layer Unet3D model as in
https://arxiv.org/abs/1711.11248
Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
nn.Module: R3D-18 encoder
"""
global batchnorm
if bn:
batchnorm = nn.BatchNorm3d
else:
batchnorm = identity
return _video_resnet('r3d_18',
pretrained, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] * 4,
layers=[2, 2, 2, 2],
stem=BasicStem, **kwargs)
def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
"""
Construct 34 layer Unet3D model as in
https://arxiv.org/abs/1711.11248
Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
nn.Module: R3D-34 encoder
"""
global batchnorm
# bn = False
if bn:
batchnorm = nn.BatchNorm3d
else:
batchnorm = identity
return _video_resnet('r3d_34',
pretrained, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] * 4,
layers=[3, 4, 6, 3],
stem=BasicStem, **kwargs)
# from https://github.com/myungsub/CAIN/blob/master/utils.py,
# but with the erroneous normalization and quantization steps removed from the PSNR computation.
from pytorch_msssim import ssim_matlab as calc_ssim
import math
import os
import torch
import shutil
def init_meters(loss_str):
losses = init_losses(loss_str)
psnrs = AverageMeter()
ssims = AverageMeter()
return losses, psnrs, ssims
def eval_metrics(output, gt, psnrs, ssims):
# PSNR should be calculated for each image, since sum(log) =/= log(sum).
for b in range(gt.size(0)):
psnr = calc_psnr(output[b], gt[b])
psnrs.update(psnr)
ssim = calc_ssim(output[b].unsqueeze(0).clamp(0,1), gt[b].unsqueeze(0).clamp(0,1) , val_range=1.)
ssims.update(ssim)
def init_losses(loss_str):
loss_specifics = {}
loss_list = loss_str.split('+')
for l in loss_list:
_, loss_type = l.split('*')
loss_specifics[loss_type] = AverageMeter()
loss_specifics['total'] = AverageMeter()
return loss_specifics
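# init_losses expects a loss specification of the form "<weight>*<name>"
# joined by "+", e.g. "1*L1+0.05*VGG" (the valid names depend on loss.Loss);
# each named term plus a 'total' entry gets its own AverageMeter:
#
#   meters = init_losses("1*L1+0.05*VGG")
#   sorted(meters)  # ['L1', 'VGG', 'total']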
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def calc_psnr(pred, gt):
diff = (pred - gt).pow(2).mean() + 1e-8
return -10 * math.log10(diff)
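# calc_psnr computes PSNR directly from the MSE of float tensors (assumed to
# lie in [0, 1]); the 1e-8 epsilon guards against log10(0) for identical
# images. For example:
#
#   a = torch.zeros(3, 8, 8)
#   b = torch.full((3, 8, 8), 0.1)
#   calc_psnr(a, b)  # MSE = 0.01  ->  about 20 dB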
def save_checkpoint(state, directory, is_best, exp_name, filename='checkpoint.pth'):
"""Saves checkpoint to disk"""
if not os.path.exists(directory):
os.makedirs(directory)
filename = os.path.join(directory , filename)
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, os.path.join(directory , 'model_best.pth'))
def log_tensorboard(writer, loss, psnr, ssim, lpips, lr, timestep, mode='train'):
writer.add_scalar('Loss/%s' % mode, loss, timestep)
writer.add_scalar('PSNR/%s' % mode, psnr, timestep)
writer.add_scalar('SSIM/%s' % mode, ssim, timestep)
if mode == 'train':
writer.add_scalar('lr', lr, timestep)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Fast Frame Interpolation with FLAVR.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Fast Frame Interpolation with FLAVR\n",
"FLAVR is a fast, flow-free frame interpolation method capable of single shot multi-frame prediction. It uses a customized encoder decoder architecture with spatio-temporal convolutions and channel gating to capture and interpolate complex motion trajectories between frames to generate realistic high frame rate videos. This notebook is to apply slow-motion filtering on your own videos. \n",
"A GPU runtime is suggested to execute the code in this notebook. \n",
" \n",
"Credits for the original FLAVR work:\n",
"\n",
"\n",
"```\n",
"@article{kalluri2021flavr,\n",
" title={FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation},\n",
" author={Kalluri, Tarun and Pathak, Deepak and Chandraker, Manmohan and Tran, Du},\n",
" booktitle={arxiv},\n",
" year={2021}\n",
"}\n",
"```\n",
"\n"
],
"metadata": {
"id": "GtNm2bt5m__t"
}
},
{
"cell_type": "markdown",
"source": [
"### Settings"
],
"metadata": {
"id": "Cer3xI_vC8AX"
}
},
{
"cell_type": "markdown",
"source": [
"Clone the official GitHub repository."
],
"metadata": {
"id": "L25AZqD1aqYy"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5geYqIv5ah8_"
},
"outputs": [],
"source": [
"!git clone https://github.com/tarun005/FLAVR.git\n",
"%cd FLAVR"
]
},
{
"cell_type": "markdown",
"source": [
"Install the missing requirements. Almost all the required Python packages for the code in this notebook are available by default in a Colab runtime. Only *PyAV*, a Pythonic binding for the FFmpeg libraries, to be installed really."
],
"metadata": {
"id": "VZ69AA375uby"
}
},
{
"cell_type": "code",
"source": [
"!pip install av"
],
"metadata": {
"id": "L1Bd6U5H5x8X"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Download a pretrained model. The Colab GPU runtime specs allow full completion only for 2X interpolation."
],
"metadata": {
"id": "3idcRJmwa0ss"
}
},
{
"cell_type": "code",
"source": [
"!gdown --id 1XFk9YZP9llTeporF-t_Au1dI-VhDQppG"
],
"metadata": {
"id": "eAjOsOhCbCXB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"If the code cell above doesn't work, please copy the pre-trained model manually to your Google Drive space and then follow the instructions for the next 3 code cells."
],
"metadata": {
"id": "nspERdHKiilc"
}
},
{
"cell_type": "markdown",
"source": [
"Mount your Google Drive. After executing the code in the cell below, a URL will be shown in the cell output. Click on it and follow the instructions that would appear online."
],
"metadata": {
"id": "t1W1cafV0RRB"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"\n",
"drive.mount('/content/gdrive')"
],
"metadata": {
"id": "DPiPftbD0SWC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Copy the pre-trained model to this runtime filesystem."
],
"metadata": {
"id": "-PibNYIlpu4K"
}
},
{
"cell_type": "code",
"source": [
"!cp -av '/content/gdrive/My Drive/FLAVR_2x.pth' './FLAVR_2x.pth'"
],
"metadata": {
"id": "cEcpbDyW0axe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Unmount your Google Drive when done with the pre-trained model copy."
],
"metadata": {
"id": "ZLD3eO790bLP"
}
},
{
"cell_type": "code",
"source": [
"drive.flush_and_unmount()"
],
"metadata": {
"id": "8Meb4kd90eFF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Define a function to upload videos."
],
"metadata": {
"id": "_-ll5UukbWE_"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import shutil\n",
"from google.colab import files\n",
"\n",
"def upload_files(upload_path):\n",
" uploaded = files.upload()\n",
" for filename, content in uploaded.items():\n",
" dst_path = os.path.join(upload_path, filename)\n",
" shutil.move(filename, dst_path)\n",
" return list(uploaded.keys())"
],
"metadata": {
"id": "1R442tTXbcT9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Create a directory for uploaded videos."
],
"metadata": {
"id": "4edydmfoceHc"
}
},
{
"cell_type": "code",
"source": [
"!mkdir ./test_videos\n",
"image_input_dir = '/content/FLAVR/test_videos/'"
],
"metadata": {
"id": "Zl3_EauGcjqE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Slow-Motion Filtering"
],
"metadata": {
"id": "zJieST7OoVEV"
}
},
{
"cell_type": "markdown",
"source": [
"Upload your own video."
],
"metadata": {
"id": "5tko37fpczcH"
}
},
{
"cell_type": "code",
"source": [
"uploaded_videos = upload_files(image_input_dir)"
],
"metadata": {
"id": "btKmYZb3ciJt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"os.environ['UPLOADED_VIDEO_FILENAME'] = os.path.join(image_input_dir, uploaded_videos[0])"
],
"metadata": {
"id": "KtbfV4g24LI7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Execute interpolation on the uploaded video."
],
"metadata": {
"id": "B8oX93NIc60B"
}
},
{
"cell_type": "code",
"source": [
"!python ./interpolate.py --input_video $UPLOADED_VIDEO_FILENAME --factor 2 --load_model ./FLAVR_2x.pth"
],
"metadata": {
"id": "I94Xy1e8dEW_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Display the result."
],
"metadata": {
"id": "LsmXsAVI-c3d"
}
},
{
"cell_type": "code",
"source": [
"from moviepy.editor import VideoFileClip\n",
"\n",
"uploaded_video_filename_tokens = uploaded_videos[0].split('.')\n",
"result_video_path = uploaded_video_filename_tokens[0] + '_2x.' + uploaded_video_filename_tokens[1]\n",
"\n",
"clip = VideoFileClip(result_video_path)\n",
"clip.ipython_display(width=280)"
],
"metadata": {
"id": "Z9CeJL-Dd-Ul"
},
"execution_count": null,
"outputs": []
}
]
}
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
# Note: the kernel is created directly on the GPU, so these metrics assume
# CUDA is available (callers later move it with .to(img1.device)).
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda()
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda()
return window
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
max_val = 255
else:
max_val = 1
if torch.min(img1) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, channel, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window(real_size, channel=channel).to(img1.device)
# Replicate-padding by 5 on each side assumes the default window_size of 11.
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=1):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
max_val = 255
else:
max_val = 1
if torch.min(img1) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, _, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window_3d(real_size, channel=1).to(img1.device)
# Channel is set to 1 since we consider color images as volumetric images
img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
device = img1.device
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
levels = weights.size()[0]
mssim = []
mcs = []
for _ in range(levels):
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
mssim.append(sim)
mcs.append(cs)
img1 = F.avg_pool2d(img1, (2, 2))
img2 = F.avg_pool2d(img2, (2, 2))
mssim = torch.stack(mssim)
mcs = torch.stack(mcs)
# Normalize to avoid NaNs when training unstable models (not compliant with the original MS-SSIM definition)
if normalize:
mssim = (mssim + 1) / 2
mcs = (mcs + 1) / 2
pow1 = mcs ** weights
pow2 = mssim ** weights
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
output = torch.prod(pow1[:-1] * pow2[-1])
return output
# Classes to re-use window
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, val_range=None):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.val_range = val_range
# Assume 3 channel for SSIM
self.channel = 3
self.window = create_window(window_size, channel=self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
self.window = window
self.channel = channel
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
dssim = (1 - _ssim) / 2
return dssim
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(MSSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
def forward(self, img1, img2):
# TODO: store window between calls if possible
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
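# A minimal usage sketch for the metrics above. create_window and
# create_window_3d build their kernels directly on the GPU, so this requires
# CUDA; inputs are assumed to be 4D floats in [0, 1].
if __name__ == "__main__":
    img1 = torch.rand(1, 3, 64, 64).cuda()
    img2 = torch.rand(1, 3, 64, 64).cuda()
    print("2D SSIM:", ssim(img1, img2, val_range=1.).item())
    print("Matlab-style 3D SSIM:", ssim_matlab(img1, img2, val_range=1.).item())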
import os
import sys
import time
import copy
import shutil
import random
import pdb
import torch
import numpy as np
from tqdm import tqdm
import config
import myutils
from torch.utils.data import DataLoader
##### Parse CmdLine Arguments #####
os.environ["CUDA_VISIBLE_DEVICES"]='7'
args, unparsed = config.get_args()
cwd = os.getcwd()
device = torch.device('cuda' if args.cuda else 'cpu')
torch.manual_seed(args.random_seed)
if args.cuda:
torch.cuda.manual_seed(args.random_seed)
if args.dataset == "vimeo90K_septuplet":
from dataset.vimeo90k_septuplet import get_loader
test_loader = get_loader('test', args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "ucf101":
from dataset.ucf101_test import get_loader
test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "gopro":
from dataset.GoPro import get_loader
test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers, test_mode=True, interFrames=args.n_outputs)
else:
raise NotImplementedError
from model.FLAVR_arch import UNet_3D_3D
print("Building model: %s"%args.model.lower())
model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType)
model = torch.nn.DataParallel(model).to(device)
print("#params" , sum([p.numel() for p in model.parameters()]))
def test(args):
time_taken = []
losses, psnrs, ssims = myutils.init_meters(args.loss)
model.eval()
psnr_list = []
with torch.no_grad():
for i, (images, gt_image ) in enumerate(tqdm(test_loader)):
images = [img_.cuda() for img_ in images]
gt = [g_.cuda() for g_ in gt_image]
torch.cuda.synchronize()
start_time = time.time()
out = model(images)
out = torch.cat(out)
gt = torch.cat(gt)
torch.cuda.synchronize()
time_taken.append(time.time() - start_time)
myutils.eval_metrics(out, gt, psnrs, ssims)
print("PSNR: %f, SSIM: %fn" %
(psnrs.avg, ssims.avg))
print("Time , " , sum(time_taken)/len(time_taken))
return psnrs.avg
""" Entry Point """
def main(args):
assert args.load_from is not None
model_dict = model.state_dict()
model.load_state_dict(torch.load(args.load_from)["state_dict"] , strict=True)
test(args)
if __name__ == "__main__":
main(args)