# Introduction
FLAVR is a deep learning model for video frame interpolation that converts low-frame-rate video into high-frame-rate video. It processes the low-frame-rate video frame by frame and uses a deep network to infer the missing intermediate frames, producing smoother high-frame-rate output. Compared with traditional interpolation algorithms, FLAVR generates more natural and realistic results and can handle complex scenes and motion.
# Test procedure
## Install the toolkit
PyTorch 1.10 [1.10.0a0+git2040069-dtk2210]
## Load environment variables
```
export PATH={PYTHON3_install_dir}/bin:$PATH
export LD_LIBRARY_PATH={PYTHON3_install_dir}/lib:$LD_LIBRARY_PATH
```
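To confirm that the toolkit is picked up after exporting the paths, a quick check can be run (a minimal sketch; the exact version string depends on your DTK build):
```
import torch

# Expect a build string like 1.10.0a0+git2040069 and at least 2 visible DCUs for training.
print(torch.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())
```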
## Download the dataset
Vimeo-90K septuplet dataset
[vimeo_triplet](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip)
# Run commands
## Train the model on Vimeo-90K septuplets
To train your own model on the Vimeo-90K dataset, use the following command. You can download the dataset from [this link](http://toflow.csail.mit.edu/) (at least 2 DCUs are required).
```
python main.py --batch_size 32 --test_batch_size 32 --dataset vimeo90K_septuplet --loss 1*L1 --max_epoch 200 --lr 0.0002 --data_root <dataset_path> --n_outputs 1
```
Training on the GoPro dataset is similar; change `n_outputs` to 7 for 8x interpolation.
## Download pretrained models directly
You can download the pretrained FLAVR models from the following links.
| Method | Trained model |
| ------- | ------------------------------------------------------------ |
| **2x** | [Link](https://drive.google.com/drive/folders/1M6ec7t59exOSlx_Wp6K9_njBlLH2IPBC?usp=sharing) |
| **4x** | [Link](https://drive.google.com/file/d/1btmNm4LkHVO9gjAaKKN9CXf5vP7h4hCy/view?usp=sharing) |
| **8x** | [Link](https://drive.google.com/drive/folders/1Gd2l69j7UC1Zua7StbUNcomAAhmE-xFb?usp=sharing) |
### Middleburry evaluation
To evaluate on the public Middleburry benchmark, run the following command.
```
python Middleburry_Test.py --data_root <data_path> --load_from <model_path>
```
# Reference
[https://github.com/tarun005/FLAVR](https://github.com/tarun005/FLAVR)
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
import os
import sys
import time
import copy
import shutil
import random
import pdb
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataset.transforms import Resize
import config
import myutils
from torch.utils.data import DataLoader
args, unparsed = config.get_args()
cwd = os.getcwd()
device = torch.device('cuda' if args.cuda else 'cpu')
torch.manual_seed(args.random_seed)
if args.cuda:
torch.cuda.manual_seed(args.random_seed)
from dataset.Middleburry import get_loader
test_loader = get_loader(args.data_root, 1, shuffle=False, num_workers=args.num_workers)
from model.FLAVR_arch import UNet_3D_3D
print("Building model: %s"%args.model.lower())
model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType)
# Just make every model to DataParallel
model = torch.nn.DataParallel(model).to(device)
print("#params" , sum([p.numel() for p in model.parameters()]))
def make_image(img):
# img = F.interpolate(img.unsqueeze(0) , (720,1280) , mode="bilinear").squeeze(0)
q_im = img.data.mul(255.).clamp(0,255).round()
im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
return im
folderList = ['Backyard', 'Basketball', 'Dumptruck', 'Evergreen', 'Mequon', 'Schefflera', 'Teddy', 'Urban']
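# folderList: the eight sequences evaluated on the Middlebury interpolation benchmark.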
def test(args):
time_taken = []
img_save_id = 0
losses, psnrs, ssims = myutils.init_meters(args.loss)
model.eval()
psnr_list = []
with torch.no_grad():
for i, (images, name) in enumerate(test_loader):
if name[0] not in folderList:
continue
images = torch.stack(images , dim=1).squeeze(0)
# images = [img_.cuda() for img_ in images]
H,W = images[0].shape[-2:]
resizes = 8*(H//8) , 8*(W//8)
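# Round H and W down to the nearest multiple of 8 before feeding the network.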
import torchvision
transform = Resize(resizes)
rev_transforms = Resize((H,W))
images = transform(images).unsqueeze(0).cuda()# [transform(img_.squeeze(0)).unsqueeze(0).cuda() for img_ in images]
images = torch.unbind(images, dim=1)
start_time = time.time()
out = model(images)
print("Time Taken" , time.time() - start_time)
out = torch.cat(out)
out = rev_transforms(out)
output_image = make_image(out.squeeze(0))
import imageio
os.makedirs("Middleburry/%s/"%name[0])
imageio.imwrite("Middleburry/%s/frame10i11.png"%name[0], output_image)
return
def main(args):
assert args.load_from is not None
model_dict = model.state_dict()
model.load_state_dict(torch.load(args.load_from)["state_dict"] , strict=True)
test(args)
if __name__ == "__main__":
main(args)
# FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation
## WACV 2023 (Best Paper Finalist)
![Eg1](./figures/baloon.gif)
![Eg2](./figures/sprite.gif)
[[project page](https://tarun005.github.io/FLAVR/)] [[paper](https://arxiv.org/pdf/2012.08512.pdf)] [[Project Video](https://youtu.be/HFOY7CGpJRM)]
FLAVR is a fast, flow-free frame interpolation method capable of single-shot multi-frame prediction. It uses a customized encoder-decoder architecture with spatio-temporal convolutions and channel gating to capture and interpolate complex motion trajectories between frames, generating realistic high-frame-rate videos. This repository contains the original source code.
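As a rough illustration of the two ingredients named above, the snippet below sketches a spatio-temporal (3D) convolution followed by channel gating. This is only a hedged toy example with made-up layer sizes, not the actual `UNet_3D_3D` defined in `model/FLAVR_arch.py`.
```python
import torch
import torch.nn as nn

class GatedConv3dBlock(nn.Module):
    """Toy block: joint space-time convolution followed by per-channel gating."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)  # spatio-temporal conv
        self.gate = nn.Sequential(                                      # squeeze-and-excite style gate
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(out_ch, out_ch, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):                 # x: (B, C, T, H, W)
        feat = torch.relu(self.conv(x))
        return feat * self.gate(feat)     # gate each channel of the space-time features

clip = torch.randn(1, 3, 4, 64, 64)         # a batch of 4 RGB frames
print(GatedConv3dBlock(3, 64)(clip).shape)  # torch.Size([1, 64, 4, 64, 64])
```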
## Inference Times
FLAVR delivers a better trade-off between speed and accuracy compared to prior frame interpolation methods.
Method | FPS on 512x512 image |
| ------------- |:-------------:|
| FLAVR | 3.10 |
| SuperSloMo | 3.33 |
| QVI | 1.02 |
| DAIN | 0.77 |
## Dependencies
We used the following to train and test the model.
- Ubuntu 18.04
- Python==3.7.4
- numpy==1.19.2
- [PyTorch](http://pytorch.org/)==1.5.0, torchvision==0.6.0, cudatoolkit==10.1
## Model
<center><img src="./figures/arch_dia.png" width="90%"></center>
## Training model on Vimeo-90K septuplets
To train your own model on the Vimeo-90K dataset, use the following command. You can download the dataset from [this link](http://toflow.csail.mit.edu/). The results reported in the paper were obtained by training on 8 GPUs.
``` bash
python main.py --batch_size 32 --test_batch_size 32 --dataset vimeo90K_septuplet --loss 1*L1 --max_epoch 200 --lr 0.0002 --data_root <dataset_path> --n_outputs 1
```
Training on the GoPro dataset is similar; change `n_outputs` to 7 for 8x interpolation.
## Testing using trained model.
### Trained Models.
You can download the pretrained FLAVR models from the following links.
Method | Trained Model |
| ------------- |:-----|
| **2x** | [Link](https://drive.google.com/drive/folders/1M6ec7t59exOSlx_Wp6K9_njBlLH2IPBC?usp=sharing) |
| **4x** | [Link](https://drive.google.com/file/d/1btmNm4LkHVO9gjAaKKN9CXf5vP7h4hCy/view?usp=sharing) |
| **8x** | [Link](https://drive.google.com/drive/folders/1Gd2l69j7UC1Zua7StbUNcomAAhmE-xFb?usp=sharing) |
### 2x Interpolation
For testing a pretrained model on the Vimeo-90K septuplet validation set, you can run the following command:
```bash
python test.py --dataset vimeo90K_septuplet --data_root <data_path> --load_from <saved_model> --n_outputs 1
```
### 8x Interpolation
For testing a multiframe interpolation model, use the same command as above with a multiframe FLAVR model and `n_outputs` changed accordingly.
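In code, the 8x setting corresponds to constructing the network with `n_outputs=7` (seven intermediate frames per window of four inputs). The snippet below is a minimal, hedged sketch of loading such a checkpoint outside the test script, following the constructor and state-dict handling used in `Middleburry_Test.py` and `interpolate.py`; `<model_path>` is a placeholder.
```python
import torch
from model.FLAVR_arch import UNet_3D_3D

# 8x FLAVR: 4 input frames in, 7 interpolated frames out per window.
model = UNet_3D_3D("unet_18", n_inputs=4, n_outputs=7, joinType="concat")

state = torch.load("<model_path>")["state_dict"]
state = {k.partition("module.")[-1]: v for k, v in state.items()}  # strip any DataParallel "module." prefix
model.load_state_dict(state)
model = model.cuda().eval()
```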
### Time Benchmarking
The testing script, in addition to computing PSNR and SSIM values, will also output the inference time and speed for interpolation.
### Evaluation on Middleburry
To evaluate on the public Middleburry benchmark, run the following command.
```bash
python Middleburry_Test.py --data_root <data_path> --load_from <model_path>
```
The interpolated images will be saved to the folder `Middleburry` in a format that can be readily uploaded to the [leaderboard](https://vision.middlebury.edu/flow/eval/results/results-i2.php).
## SloMo-Filter on custom video
You can use our trained models to apply the slo-mo filter on your own video (requires OpenCV 4.2.0). For example, to convert a 30FPS video into a 240FPS video, run
```bash
python interpolate.py --input_video <input_video> --factor 8 --load_model <model_path>
```
using our [pretrained model](https://drive.google.com/drive/folders/1Gd2l69j7UC1Zua7StbUNcomAAhmE-xFb?usp=sharing) for 8x interpolation. To convert a 30FPS video to 60FPS instead, use a 2x model with `--factor 2`.
## Baseline Models
We also trained models for several prior methods in our setting and provide models for all of these methods. Complete benchmarking scripts will also be released soon.
Method | PSNR on Vimeo | Trained Model |
| ------------- |:-------------:| -----:|
| FLAVR | 36.3 | [Model](https://drive.google.com/drive/folders/1M6ec7t59exOSlx_Wp6K9_njBlLH2IPBC?usp=sharing)
| AdaCoF | 35.3 | [Model](https://drive.google.com/file/d/19Y2TDZkSbRgNu-OItvqk3qn5cBWGg1RT/view?usp=sharing) |
| QVI* | 35.15 | [Model](https://drive.google.com/file/d/1v2u5diGcvdTLhck8Xwu0baI4zm0JBJhI/view?usp=sharing) |
| DAIN | 34.19 | [Model](https://drive.google.com/file/d/1RfrwrHoSX_3RIdsoQgPg9IfGAJRhOoEp/view?usp=sharing) |
| SuperSloMo* | 32.90 | [Model](https://drive.google.com/file/d/1dR2at5DQO7w5s2tA5stC95Nmu_ezsPth/view?usp=sharing)
* SuperSloMo is implemented using the code repository from [here](https://github.com/avinashpaliwal/Super-SloMo). Other baselines are implemented using their official codebases.
* The numbers presented here for the baselines are slightly better than those reported in the paper.
## Google Colab
A Colab notebook to try 2x slow-motion filtering on custom videos is available in the *notebooks* directory of this repo.
## Model for Motion-Magnification
Unfortunately, we cannot provide the trained models for motion-magnification at this time. We are working towards making a model available soon.
## Acknowledgement
The code is heavily borrowed from Facebook's official [PyTorch video repository](https://github.com/facebookresearch/VMZ) and [CAIN](https://github.com/myungsub/CAIN).
## Cite
If this code helps in your work, please consider citing us.
``` text
@article{kalluri2021flavr,
title={FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation},
author={Kalluri, Tarun and Pathak, Deepak and Chandraker, Manmohan and Tran, Du},
booktitle={arxiv},
year={2021}
}
```
import argparse
arg_lists = []
parser = argparse.ArgumentParser()
def add_argument_group(name):
arg = parser.add_argument_group(name)
arg_lists.append(arg)
return arg
# Dataset
data_arg = add_argument_group('Dataset')
data_arg.add_argument('--dataset', type=str, default='vimeo90k')
data_arg.add_argument('--data_root', type=str)
# Model
model_choices = ["unet_18", "unet_34"]
model_arg = add_argument_group('Model')
model_arg.add_argument('--model', choices=model_choices, type=str, default="unet_18")
model_arg.add_argument('--nbr_frame' , type=int , default=4)
model_arg.add_argument('--nbr_width' , type=int , default=1)
model_arg.add_argument('--joinType' , choices=["concat" , "add" , "none"], default="concat")
model_arg.add_argument('--upmode' , choices=["transpose","upsample"], type=str, default="transpose")
model_arg.add_argument('--n_outputs', type=int, default=1,
help="For Kx FLAVR, use n_outputs = K-1")
# Training / test parameters
learn_arg = add_argument_group('Learning')
learn_arg.add_argument('--loss', type=str, default='1*L1')
learn_arg.add_argument('--lr', type=float, default=2e-4)
learn_arg.add_argument('--beta1', type=float, default=0.9)
learn_arg.add_argument('--beta2', type=float, default=0.99)
learn_arg.add_argument('--batch_size', type=int, default=16)
learn_arg.add_argument('--test_batch_size', type=int, default=1)
learn_arg.add_argument('--start_epoch', type=int, default=0)
learn_arg.add_argument('--max_epoch', type=int, default=200)
learn_arg.add_argument('--resume', action='store_true')
learn_arg.add_argument('--resume_exp', type=str, default=None)
learn_arg.add_argument('--checkpoint_dir', type=str ,default=".")
learn_arg.add_argument("--load_from" ,type=str , default=None)
learn_arg.add_argument("--pretrained" , type=str,
help="Load from a pretrained model.")
# Misc
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--exp_name', type=str, default='exp')
misc_arg.add_argument('--log_iter', type=int, default=60)
misc_arg.add_argument('--num_gpu', type=int, default=1)
misc_arg.add_argument('--random_seed', type=int, default=12345)
misc_arg.add_argument('--num_workers', type=int, default=16)
misc_arg.add_argument('--use_tensorboard', action='store_true')
misc_arg.add_argument('--val_freq', type=int, default=1)
def get_args():
"""Parses all of the arguments above
"""
args, unparsed = parser.parse_known_args()
if args.num_gpu > 0:
setattr(args, 'cuda', True)
else:
setattr(args, 'cuda', False)
if len(unparsed) > 1:
print("Unparsed args: {}".format(unparsed))
return args, unparsed
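# Example usage: a minimal sketch mirroring how main.py and Middleburry_Test.py consume
# this module; run `python config.py --n_outputs 1` to inspect the parsed defaults.
if __name__ == "__main__":
    example_args, example_unparsed = get_args()
    print(example_args)
    print("cuda enabled:", example_args.cuda)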
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class Davis(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.images_sets = []
for label_id in os.listdir(self.data_root):
ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root , label_id)))
ctg_imgs_ = [os.path.join(self.data_root , label_id , img_id) for img_id in ctg_imgs_]
for start_idx in range(0,len(ctg_imgs_)-6,2):
add_files = ctg_imgs_[start_idx : start_idx+7 : 2]
add_files = add_files[:2] + [ctg_imgs_[start_idx+3]] + add_files[2:]
self.images_sets.append(add_files)
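# Each appended set holds 4 input frames at offsets 0, 2, 4, 6 plus the ground-truth frame at offset 3 in the middle.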
self.transforms = transforms.Compose([
transforms.CenterCrop((480,840)),
transforms.ToTensor()
])
print(len(self.images_sets))
def __getitem__(self, idx):
imgpaths = self.images_sets[idx]
images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]
return images[:2] + images[3:] , [images[2]]
def __len__(self):
return len(self.images_sets)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = Davis(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = Davis("./Davis_test/")
print(len(dataset))
dataloader = DataLoader(dataset , batch_size=1, shuffle=True, num_workers=0)
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
import pdb
class GoPro(Dataset):
def __init__(self, data_root , mode="train", interFrames=3, n_inputs=4, ext="png"):
super().__init__()
self.interFrames = interFrames
self.n_inputs = n_inputs
self.setLength = (n_inputs-1)*(interFrames+1)+1 ## Total number of consecutive frames needed to interpolate `interFrames`
## intermediate frames between each pair of the `n_inputs` input frames.
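## e.g. with the defaults n_inputs=4, interFrames=3: setLength = (4-1)*(3+1)+1 = 13 frames,
## of which frames 0, 4, 8, 12 are the inputs and frames 5, 6, 7 are the ground truth.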
self.data_root = os.path.join(data_root , mode)
video_list = os.listdir(self.data_root)
self.frames_list = []
self.file_list = []
for video in video_list:
frames = sorted(os.listdir(os.path.join(self.data_root , video)))
n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1
videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength ] for i in range(n_sets)]
videoInputs = [[os.path.join(video , f) for f in group] for group in videoInputs]
self.file_list.extend(videoInputs)
self.transforms = transforms.Compose([
transforms.CenterCrop(512),
transforms.ToTensor()
])
def __getitem__(self, idx):
imgpaths = [os.path.join(self.data_root , fp) for fp in self.file_list[idx]]
pick_idxs = list(range(0,self.setLength,self.interFrames+1))
rem = self.interFrames%2
gt_idx = list(range(self.setLength//2-self.interFrames//2 , self.setLength//2+self.interFrames//2+rem))
input_paths = [imgpaths[idx] for idx in pick_idxs]
gt_paths = [imgpaths[idx] for idx in gt_idx]
images = [Image.open(pth_) for pth_ in input_paths]
images = [self.transforms(img_) for img_ in images]
gt_images = [Image.open(pth_) for pth_ in gt_paths]
gt_images = [self.transforms(img_) for img_ in gt_images]
return images , gt_images
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True, interFrames=3, n_inputs=4):
if test_mode:
mode = "test"
else:
mode = "train"
dataset = GoPro(data_root , mode, interFrames=interFrames, n_inputs=n_inputs)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = GoPro("./GoPro" , mode="train", interFrames=3, n_inputs=4)
print(len(dataset))
dataloader = DataLoader(dataset , batch_size=1, shuffle=True, num_workers=0)
alldirs/GOPR0384_11_00
alldirs/GOPR0385_11_01
alldirs/GOPR0410_11_00
alldirs/GOPR0862_11_00
alldirs/GOPR0869_11_00
alldirs/GOPR0881_11_01
alldirs/GOPR0384_11_05
alldirs/GOPR0396_11_00
alldirs/GOPR0854_11_00
alldirs/GOPR0868_11_00
alldirs/GOPR0871_11_00
alldirs/GOPR0372_07_00
alldirs/GOPR0374_11_01
alldirs/GOPR0378_13_00
alldirs/GOPR0384_11_01
alldirs/GOPR0384_11_04
alldirs/GOPR0477_11_00
alldirs/GOPR0868_11_02
alldirs/GOPR0884_11_00
alldirs/GOPR0372_07_01
alldirs/GOPR0374_11_02
alldirs/GOPR0379_11_00
alldirs/GOPR0384_11_02
alldirs/GOPR0385_11_00
alldirs/GOPR0857_11_00
alldirs/GOPR0871_11_01
alldirs/GOPR0374_11_00
alldirs/GOPR0374_11_03
alldirs/GOPR0380_11_00
alldirs/GOPR0384_11_03
alldirs/GOPR0386_11_00
alldirs/GOPR0868_11_01
alldirs/GOPR0881_11_00
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class Middelburry(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.file_list = os.listdir(self.data_root)
self.transforms = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, idx):
imgpath = os.path.join(self.data_root , self.file_list[idx])
name = self.file_list[idx]
if name == "Teddy": ## Handle inputs with just two inout frames. FLAVR takes atleast 4.
imgpaths = [os.path.join(imgpath , "frame10.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame11.png") ]
else:
imgpaths = [os.path.join(imgpath , "frame09.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame12.png") ]
images = [Image.open(img).convert('RGB') for img in imgpaths]
images = [self.transforms(img) for img in images]
sizes = images[0].shape
return images , name
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = Middelburry(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
# from https://github.com/facebookresearch/VMZ
import torch
import numbers
import random
from torchvision.transforms import RandomCrop, RandomResizedCrop
__all__ = [
"RandomCropVideo",
"RandomResizedCropVideo",
"CenterCropVideo",
"NormalizeVideo",
"ToTensorVideo",
"RandomHorizontalFlipVideo",
"Resize",
"TemporalCenterCrop",
"RandomTemporalFlipVideo",
"RandomVerticalFlipVideo"
]
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tesnor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
return clip[..., i : i + h, j : j + w]
def temporal_center_crop(clip, clip_len):
"""
Args:
clip (torch.tensor): Video clip to be
cropped along the temporal axis. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
assert clip.size(1) >= clip_len, "clip is shorter than the proposed length"
middle = int(clip.size(1) // 2)
start = middle - clip_len // 2
return clip[:, start : start + clip_len, ...]
def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
return torch.nn.functional.interpolate(
clip, size=target_size, mode=interpolation_mode, align_corners=False
)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
assert h >= th and w >= tw, "height and width must be >= than crop_size"
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError(
"clip tensor should have data type uint8. Got %s" % str(clip.dtype)
)
return clip.float().permute(3, 0, 1, 2) / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-1))
def vflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-2))
def tflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-3))
class RandomCropVideo(RandomCrop):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, OH, OW)
"""
i, j, h, w = self.get_params(clip, self.size)
return crop(clip, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + "(size={0})".format(self.size)
class RandomResizedCropVideo(RandomResizedCrop):
def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
assert len(size) == 2, "size should be tuple (height, width)"
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
self.scale = scale
self.ratio = ratio
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, H, W)
"""
i, j, h, w = self.get_params(clip, self.scale, self.ratio)
return resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
def __repr__(self):
return (
self.__class__.__name__
+ "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
self.size, self.interpolation_mode, self.scale, self.ratio
)
)
class CenterCropVideo(object):
def __init__(self, crop_size):
if isinstance(crop_size, numbers.Number):
self.crop_size = (int(crop_size), int(crop_size))
else:
self.crop_size = crop_size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: central cropping of video clip. Size is
(C, T, crop_size, crop_size)
"""
return center_crop(clip, self.crop_size)
def __repr__(self):
r = self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)
return r
class TemporalCenterCrop(object):
def __init__(self, clip_len):
self.clip_len = clip_len
def __call__(self, clip):
return temporal_center_crop(clip, self.clip_len)
class UnfoldClips(object):
def __init__(self, clip_len, overlap):
self.clip_len = clip_len
assert overlap > 0 and overlap <= 1
self.step = round(clip_len * overlap)
def __call__(self, clip):
if clip.size(1) < self.clip_len:
return clip.unfold(1, clip.size(1), clip.size(1)).permute(1, 0, 4, 2, 3)
results = clip.unfold(1, self.clip_len, self.clip_len).permute(1, 0, 4, 2, 3)
return results
class TempPadClip(object):
def __init__(self, clip_len):
self.num_frames = clip_len
def __call__(self, clip):
if clip.size(1) == 0:
return clip
if clip.size(1) < self.num_frames:
# do something and return
step = clip.size(1) / self.num_frames
idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
idxs = idxs.floor().to(torch.int64)
return clip[:, idxs, ...]
step = clip.size(1) / self.num_frames
if step.is_integer():
# optimization: if step is integer, don't need to perform
# advanced indexing
step = int(step)
return clip[:, slice(None, None, step), ...]
idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
idxs = idxs.floor().to(torch.int64)
return clip[:, idxs, ...]
class NormalizeVideo(object):
"""
Normalize the video clip by mean subtraction
and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be
normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(
self.mean, self.std, self.inplace
)
class ToTensorVideo(object):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
return to_tensor(clip)
def __repr__(self):
return self.__class__.__name__
class RandomHorizontalFlipVideo(object):
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomVerticalFlipVideo(object):
"""
Flip the video clip along the vertical direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = vflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomTemporalFlipVideo(object):
"""
Reverse the video clip along the temporal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = tflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size, interpolation_mode="bilinear")
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class UCF(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.file_list = sorted(os.listdir(self.data_root))
self.transforms = transforms.Compose([
transforms.CenterCrop((224,224)),
transforms.ToTensor(),
])
def __getitem__(self, idx):
imgpath = os.path.join(self.data_root , self.file_list[idx])
imgpaths = [os.path.join(imgpath , "frame0.png") , os.path.join(imgpath , "frame1.png") ,os.path.join(imgpath , "frame2.png") ,os.path.join(imgpath , "frame3.png") ,os.path.join(imgpath , "framet.png")]
images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]
return images[:-1] , [images[-1]]
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = UCF(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = UCF("./ucf_test/")
print(len(dataset))
dataloader = DataLoader(dataset , batch_size=1, shuffle=True, num_workers=0)
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
class VimeoSepTuplet(Dataset):
def __init__(self, data_root, is_training , input_frames="1357"):
"""
Creates a Vimeo-90K septuplet dataset object.
Inputs:
data_root: Root path for the Vimeo dataset containing the septuplet sequences.
is_training: Whether to load the train split or the test split.
input_frames: Which frames to feed to the frame interpolation network.
"""
self.data_root = data_root
self.image_root = os.path.join(self.data_root, 'sequences')
self.training = is_training
self.inputs = input_frames
train_fn = os.path.join(self.data_root, 'sep_trainlist.txt')
test_fn = os.path.join(self.data_root, 'sep_testlist.txt')
with open(train_fn, 'r') as f:
self.trainlist = f.read().splitlines()
with open(test_fn, 'r') as f:
self.testlist = f.read().splitlines()
if self.training:
self.transforms = transforms.Compose([
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomVerticalFlip(0.5),
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
transforms.ToTensor()
])
else:
self.transforms = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, index):
if self.training:
imgpath = os.path.join(self.image_root, self.trainlist[index])
else:
imgpath = os.path.join(self.image_root, self.testlist[index])
imgpaths = [imgpath + f'/im{i}.png' for i in range(1,8)]
pth_ = imgpaths
# Load images
images = [Image.open(pth) for pth in imgpaths]
## Select only relevant inputs
inputs = [int(e)-1 for e in list(self.inputs)]
inputs = inputs[:len(inputs)//2] + [3] + inputs[len(inputs)//2:]
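## e.g. input_frames="1357" -> inputs = [0, 2, 4, 6]; index 3 (im4.png) is slotted into the middle and later split off as the ground-truth frame.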
images = [images[i] for i in inputs]
imgpaths = [imgpaths[i] for i in inputs]
# Data augmentation
if self.training:
seed = random.randint(0, 2**32)
images_ = []
for img_ in images:
random.seed(seed)
images_.append(self.transforms(img_))
images = images_
# Random Temporal Flip
if random.random() >= 0.5:
images = images[::-1]
imgpaths = imgpaths[::-1]
else:
T = self.transforms
images = [T(img_) for img_ in images]
gt = images[len(images)//2]
images = images[:len(images)//2] + images[len(images)//2+1:]
return images, [gt]
def __len__(self):
if self.training:
return len(self.trainlist)
else:
return len(self.testlist)
def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
if mode == 'train':
is_training = True
else:
is_training = False
dataset = VimeoSepTuplet(data_root, is_training=is_training)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = VimeoSepTuplet("./vimeo_septuplet/", is_training=True)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=32, pin_memory=True)
import os
import torch
import cv2
import pdb
import time
import sys
import torchvision
from PIL import Image
import numpy as np
import tqdm
from torchvision.io import read_video , write_video
from dataset.transforms import ToTensorVideo , Resize
import torch.nn.functional as F
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--input_video" , type=str , required=True , help="Path/WebURL to input video")
parser.add_argument("--youtube-dl" , type=str , help="Path to youtube_dl" , default=".local/bin/youtube-dl")
parser.add_argument("--factor" , type=int , required=True , choices=[2,4,8] , help="How much interpolation needed. 2x/4x/8x.")
parser.add_argument("--codec" , type=str , help="video codec" , default="mpeg4")
parser.add_argument("--load_model" , required=True , type=str , help="path for stored model")
parser.add_argument("--up_mode" , type=str , help="Upsample Mode" , default="transpose")
parser.add_argument("--output_ext" , type=str , help="Output video format" , default=".avi")
parser.add_argument("--input_ext" , type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale" , type=float , help="Downscale input res. for memory" , default=1)
parser.add_argument("--output_fps" , type=int , help="Target FPS" , default=30)
parser.add_argument("--is_folder" , action="store_true" )
args = parser.parse_args()
input_video = args.input_video
input_ext = args.input_ext
from os import path
if not args.is_folder and not path.exists(input_video):
print("Invalid input file path!")
exit()
if args.is_folder and not path.exists(input_video):
print("Invalid input directory path!")
exit()
if args.output_ext != ".avi":
print("Currently supporting only writing to avi. Try using ffmpeg for conversion to mp4 etc.")
if input_video.endswith("/"):
video_name = input_video.split("/")[-2].split(input_ext)[0]
else:
video_name = input_video.split("/")[-1].split(input_ext)[0]
output_video = os.path.join(video_name + f"_{args.factor}x" + str(args.output_ext))
n_outputs = args.factor - 1
model_name = "unet_18"
nbr_frame = 4
joinType = "concat"
if input_video.startswith("http"):
assert args.youtube_dl is not None
youtube_dl_path = args.youtube_dl
cmd = f"{youtube_dl_path} -i -o video.mp4 {input_video}"
os.system(cmd)
input_video = "video.mp4"
output_video = "video" + str(args.output_ext)
def loadModel(model, checkpoint):
saved_state_dict = torch.load(checkpoint)['state_dict']
saved_state_dict = {k.partition("module.")[-1]:v for k,v in saved_state_dict.items()}
model.load_state_dict(saved_state_dict)
checkpoint = args.load_model
from model.FLAVR_arch import UNet_3D_3D
model = UNet_3D_3D(model_name.lower() , n_inputs=4, n_outputs=n_outputs, joinType=joinType , upmode=args.up_mode)
loadModel(model , checkpoint)
model = model.cuda()
def write_video_cv2(frames , video_name , fps , sizes):
out = cv2.VideoWriter(video_name,cv2.CAP_OPENCV_MJPEG,cv2.VideoWriter_fourcc('M','J','P','G'), fps, sizes)
for frame in frames:
out.write(frame)
out.release()
def make_image(img):
q_im = img.data.mul(255.).clamp(0,255).round()
im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
return im
def files_to_videoTensor(path , downscale=1.):
from PIL import Image
files = sorted(os.listdir(path))
print(len(files))
images = [torch.Tensor(np.asarray(Image.open(os.path.join(path , f)))).type(torch.uint8) for f in files]
print(images[0].shape)
videoTensor = torch.stack(images)
return videoTensor
def video_to_tensor(video):
videoTensor , _ , md = read_video(video)
fps = md["video_fps"]
print(fps)
return videoTensor
def video_transform(videoTensor , downscale=1):
T , H , W = videoTensor.size(0), videoTensor.size(1) , videoTensor.size(2)
downscale = int(downscale * 8)
resizes = 8*(H//downscale) , 8*(W//downscale)
transforms = torchvision.transforms.Compose([ToTensorVideo() , Resize(resizes)])
videoTensor = transforms(videoTensor)
# resizes = 720,1280
print("Resizing to %dx%d"%(resizes[0] , resizes[1]) )
return videoTensor , resizes
if args.is_folder:
videoTensor = files_to_videoTensor(input_video , args.downscale)
else:
videoTensor = video_to_tensor(input_video)
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1,-1).unfold(1,size=nbr_frame,step=1).squeeze(0)
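# Sliding windows of 4 consecutive frame indices (stride 1): each window supplies the model's 4 input frames, and the predicted frames are inserted between the 2nd and 3rd frames of the window.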
videoTensor , resizes = video_transform(videoTensor , args.downscale)
print("Video tensor shape is , " , videoTensor.shape)
frames = torch.unbind(videoTensor , 1)
n_inputs = len(frames)
width = n_outputs + 1
outputs = [] ## store the input and interpolated frames
outputs.append(frames[idxs[0][1]])
model = model.eval()
for i in tqdm.tqdm(range(len(idxs))):
idxSet = idxs[i]
inputs = [frames[idx_].cuda().unsqueeze(0) for idx_ in idxSet]
with torch.no_grad():
outputFrame = model(inputs)
outputFrame = [of.squeeze(0).cpu().data for of in outputFrame]
outputs.extend(outputFrame)
outputs.append(inputs[2].squeeze(0).cpu().data)
new_video = [make_image(im_) for im_ in outputs]
write_video_cv2(new_video , output_video , args.output_fps , (resizes[1] , resizes[0]))
print("Writing to " , output_video.split(".")[0] + ".mp4")
os.system('ffmpeg -hide_banner -loglevel warning -i %s %s'%(output_video , output_video.split(".")[0] + ".mp4"))
os.remove(output_video)
# Sourced from https://github.com/myungsub/CAIN/blob/master/loss.py, who sourced from https://github.com/thstkdgus35/EDSR-PyTorch/tree/master/src/loss
# Added Huber loss in addition.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import pytorch_msssim
class MeanShift(nn.Conv2d):
def __init__(self, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1))
self.bias.data = sign * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
self.requires_grad = False
class HuberLoss(nn.Module):
def __init__(self , delta=1):
super().__init__()
self.delta = delta
def forward(self , sr , hr):
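# Quadratic (0.5*e^2) where |e| < delta, linear (delta*(|e| - 0.5*delta)) elsewhere, blended via the mask below.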
l1 = torch.abs(sr - hr)
mask = l1<self.delta
sq_loss = .5*(l1**2)
abs_loss = self.delta*(l1 - .5*self.delta)
return torch.mean(mask*sq_loss + (~mask)*(abs_loss))
class VGG(nn.Module):
def __init__(self, loss_type):
super(VGG, self).__init__()
vgg_features = models.vgg19(pretrained=True).features
modules = [m for m in vgg_features]
conv_index = loss_type[-2:]
if conv_index == '22':
self.vgg = nn.Sequential(*modules[:8])
elif conv_index == '33':
self.vgg = nn.Sequential(*modules[:16])
elif conv_index == '44':
self.vgg = nn.Sequential(*modules[:26])
elif conv_index == '54':
self.vgg = nn.Sequential(*modules[:35])
elif conv_index == 'P':
self.vgg = nn.ModuleList([
nn.Sequential(*modules[:8]),
nn.Sequential(*modules[8:16]),
nn.Sequential(*modules[16:26]),
nn.Sequential(*modules[26:35])
])
self.vgg = nn.DataParallel(self.vgg).cuda()
vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229, 0.224, 0.225)
self.sub_mean = MeanShift(vgg_mean, vgg_std)
self.vgg.requires_grad = False
# self.criterion = nn.L1Loss()
self.conv_index = conv_index
def forward(self, sr, hr):
def _forward(x):
x = self.sub_mean(x)
x = self.vgg(x)
return x
def _forward_all(x):
feats = []
x = self.sub_mean(x)
for module in self.vgg.module:
x = module(x)
feats.append(x)
return feats
if self.conv_index == 'P':
vgg_sr_feats = _forward_all(sr)
with torch.no_grad():
vgg_hr_feats = _forward_all(hr.detach())
loss = 0
for i in range(len(vgg_sr_feats)):
loss_f = F.mse_loss(vgg_sr_feats[i], vgg_hr_feats[i])
#print(loss_f)
loss += loss_f
#print()
else:
vgg_sr = _forward(sr)
with torch.no_grad():
vgg_hr = _forward(hr.detach())
loss = F.mse_loss(vgg_sr, vgg_hr)
return loss
# For Adversarial loss
class BasicBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)):
m = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), stride=stride, bias=bias)]
if bn: m.append(nn.BatchNorm2d(out_channels))
if act is not None: m.append(act)
super(BasicBlock, self).__init__(*m)
class Discriminator(nn.Module):
def __init__(self, args, gan_type='GAN'):
super(Discriminator, self).__init__()
in_channels = 3
out_channels = 64
depth = 7
#bn = not gan_type == 'WGAN_GP'
bn = True
act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
m_features = [
BasicBlock(in_channels, out_channels, 3, bn=bn, act=act)
]
for i in range(depth):
in_channels = out_channels
if i % 2 == 1:
stride = 1
out_channels *= 2
else:
stride = 2
m_features.append(BasicBlock(
in_channels, out_channels, 3, stride=stride, bn=bn, act=act
))
self.features = nn.Sequential(*m_features)
self.patch_size = args.patch_size
feature_patch_size = self.patch_size // (2**((depth + 1) // 2))
#patch_size = 256 // (2**((depth + 1) // 2))
m_classifier = [
nn.Linear(out_channels * feature_patch_size**2, 1024),
act,
nn.Linear(1024, 1)
]
self.classifier = nn.Sequential(*m_classifier)
def forward(self, x):
if x.size(2) != self.patch_size or x.size(3) != self.patch_size:
midH, midW = x.size(2) // 2, x.size(3) // 2
p = self.patch_size // 2
x = x[:, :, (midH - p):(midH - p + self.patch_size), (midW - p):(midW - p + self.patch_size)]
features = self.features(x)
output = self.classifier(features.view(features.size(0), -1))
return output
import torch.optim as optim
class Adversarial(nn.Module):
def __init__(self, args, gan_type):
super(Adversarial, self).__init__()
self.gan_type = gan_type
self.gan_k = 1 #args.gan_k
self.discriminator = torch.nn.DataParallel(Discriminator(args, gan_type))
if gan_type != 'WGAN_GP':
self.optimizer = optim.Adam(
self.discriminator.parameters(),
betas=(0.9, 0.99), eps=1e-8, lr=1e-4
)
else:
self.optimizer = optim.Adam(
self.discriminator.parameters(),
betas=(0, 0.9), eps=1e-8, lr=1e-5
)
# self.scheduler = utility.make_scheduler(args, self.optimizer)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='min', factor=0.5, patience=3, verbose=True)
def forward(self, fake, real, fake_input0=None, fake_input1=None, fake_input_mean=None):
# def forward(self, fake, real):
fake_detach = fake.detach()
if fake_input0 is not None:
fake0, fake1 = fake_input0.detach(), fake_input1.detach()
if fake_input_mean is not None:
fake_m = fake_input_mean.detach()
# print(fake.size(), fake_input0.size(), fake_input1.size(), fake_input_mean.size())
self.loss = 0
for _ in range(self.gan_k):
self.optimizer.zero_grad()
d_fake = self.discriminator(fake_detach)
if fake_input0 is not None and fake_input1 is not None:
d_fake0 = self.discriminator(fake0)
d_fake1 = self.discriminator(fake1)
if fake_input_mean is not None:
d_fake_m = self.discriminator(fake_m)
# print(d_fake.size(), d_fake0.size(), d_fake1.size(), d_fake_m.size())
d_real = self.discriminator(real)
if self.gan_type == 'GAN':
label_fake = torch.zeros_like(d_fake)
label_real = torch.ones_like(d_real)
loss_d \
= F.binary_cross_entropy_with_logits(d_fake, label_fake) \
+ F.binary_cross_entropy_with_logits(d_real, label_real)
if fake_input0 is not None and fake_input1 is not None:
loss_d += F.binary_cross_entropy_with_logits(d_fake0, label_fake) \
+ F.binary_cross_entropy_with_logits(d_fake1, label_fake)
if fake_input_mean is not None:
loss_d += F.binary_cross_entropy_with_logits(d_fake_m, label_fake)
elif self.gan_type.find('WGAN') >= 0:
loss_d = (d_fake - d_real).mean()
if self.gan_type.find('GP') >= 0:
epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
hat.requires_grad = True
d_hat = self.discriminator(hat)
gradients = torch.autograd.grad(
outputs=d_hat.sum(), inputs=hat,
retain_graph=True, create_graph=True, only_inputs=True
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
loss_d += gradient_penalty
# Discriminator update
self.loss += loss_d.item()
if self.training:
loss_d.backward()
self.optimizer.step()
if self.gan_type == 'WGAN':
for p in self.discriminator.parameters():
p.data.clamp_(-1, 1)
self.loss /= self.gan_k
d_fake_for_g = self.discriminator(fake)
if self.gan_type == 'GAN':
loss_g = F.binary_cross_entropy_with_logits(
d_fake_for_g, label_real
)
elif self.gan_type.find('WGAN') >= 0:
loss_g = -d_fake_for_g.mean()
# Generator loss
return loss_g
def state_dict(self, *args, **kwargs):
state_discriminator = self.discriminator.state_dict(*args, **kwargs)
state_optimizer = self.optimizer.state_dict()
return dict(**state_discriminator, **state_optimizer)
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
# Wrapper of loss functions
class Loss(nn.modules.loss._Loss):
def __init__(self, args):
super(Loss, self).__init__()
print('Preparing loss function:')
self.loss = []
self.loss_module = nn.ModuleList()
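# args.loss has the form "<weight>*<TYPE>[+<weight>*<TYPE>+...]", e.g. "1*L1" as used in the README.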
for loss in args.loss.split('+'):
weight, loss_type = loss.split('*')
if loss_type == 'MSE':
loss_function = nn.MSELoss()
elif loss_type == 'Huber':
loss_function = HuberLoss(delta=.5)
elif loss_type == 'L1':
loss_function = nn.L1Loss()
elif loss_type.find('VGG') >= 0:
loss_function = VGG(loss_type[3:])
elif loss_type == 'SSIM':
loss_function = pytorch_msssim.SSIM(val_range=1.)
elif loss_type.find('GAN') >= 0:
loss_function = Adversarial(args, loss_type)
self.loss.append({
'type': loss_type,
'weight': float(weight),
'function': loss_function}
)
if loss_type.find('GAN') >= 0:
self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
if len(self.loss) > 1:
self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
for l in self.loss:
if l['function'] is not None:
print('{:.3f} * {}'.format(l['weight'], l['type']))
self.loss_module.append(l['function'])
device = torch.device('cuda' if args.cuda else 'cpu')
self.loss_module.to(device)
#if args.precision == 'half': self.loss_module.half()
if args.cuda:# and args.n_GPUs > 1:
self.loss_module = nn.DataParallel(self.loss_module)
def forward(self, sr, hr, fake_imgs=None):
loss = 0
losses = {}
for i, l in enumerate(self.loss):
if l['function'] is not None:
if l['type'] == 'GAN':
if fake_imgs is None:
fake_imgs = [None, None, None]
_loss = l['function'](sr, hr, fake_imgs[0], fake_imgs[1], fake_imgs[2])
else:
_loss = l['function'](sr, hr)
effective_loss = l['weight'] * _loss
losses[l['type']] = effective_loss
loss += effective_loss
elif l['type'] == 'DIS':
losses[l['type']] = self.loss[i - 1]['function'].loss
return loss, losses