import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
from basicsr.utils import scandir
def main():
"""A multi-thread tool to crop large images to sub-images for faster IO.
It is used for DIV2K dataset.
Args:
opt (dict): Configuration dict. It contains:
n_thread (int): Thread number.
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and
longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
Usage:
For each folder, run this script.
Typically, there are four folders to be processed for DIV2K dataset.
* DIV2K_train_HR
* DIV2K_train_LR_bicubic/X2
* DIV2K_train_LR_bicubic/X3
* DIV2K_train_LR_bicubic/X4
    After processing, each sub-folder should contain the same number of sub-images.
    Remember to modify opt configurations according to your settings.
"""
opt = {}
opt['n_thread'] = 20
opt['compression_level'] = 3
# HR images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_HR'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
opt['crop_size'] = 480
opt['step'] = 240
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx2 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
opt['crop_size'] = 240
opt['step'] = 120
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx3 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
opt['crop_size'] = 160
opt['step'] = 80
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx4 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
opt['crop_size'] = 120
opt['step'] = 60
opt['thresh_size'] = 0
extract_subimages(opt)
def extract_subimages(opt):
"""Crop images to subimages.
Args:
opt (dict): Configuration dict. It contains:
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
n_thread (int): Thread number.
"""
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print(f'mkdir {save_folder} ...')
else:
print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1)
img_list = list(scandir(input_folder, full_path=True))
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
pbar.close()
print('All processes done.')
def worker(path, opt):
"""Worker for each process.
Args:
path (str): Image path.
opt (dict): Configuration dict. It contains:
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
save_folder (str): Path to save folder.
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
Returns:
process_info (str): Process information displayed in progress bar.
"""
crop_size = opt['crop_size']
step = opt['step']
thresh_size = opt['thresh_size']
img_name, extension = osp.splitext(osp.basename(path))
# remove the x2, x3, x4 and x8 in the filename for DIV2K
img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
h, w = img.shape[0:2]
h_space = np.arange(0, h - crop_size + 1, step)
if h - (h_space[-1] + crop_size) > thresh_size:
h_space = np.append(h_space, h - crop_size)
w_space = np.arange(0, w - crop_size + 1, step)
if w - (w_space[-1] + crop_size) > thresh_size:
w_space = np.append(w_space, w - crop_size)
index = 0
for x in h_space:
for y in w_space:
index += 1
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
cropped_img = np.ascontiguousarray(cropped_img)
cv2.imwrite(
osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
process_info = f'Processing {img_name} ...'
return process_info
if __name__ == '__main__':
main()
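# Hedged worked example (not part of the original script) of the sliding-window
# grid computed in worker() above; the 1356x2040 image size is a hypothetical
# DIV2K-like assumption.
def _demo_patch_grid():
    import numpy as np
    h, w, crop_size, step, thresh_size = 1356, 2040, 480, 240, 0
    h_space = np.arange(0, h - crop_size + 1, step)  # [0, 240, 480, 720]
    if h - (h_space[-1] + crop_size) > thresh_size:  # 1356 - 1200 = 156 > 0
        h_space = np.append(h_space, h - crop_size)  # append 876
    w_space = np.arange(0, w - crop_size + 1, step)  # [0, 240, ..., 1440]
    if w - (w_space[-1] + crop_size) > thresh_size:  # 2040 - 1920 = 120 > 0
        w_space = np.append(w_space, w - crop_size)  # append 1560
    return len(h_space) * len(w_space)  # 5 * 8 = 40 overlapping patches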
from os import path as osp
from PIL import Image
from basicsr.utils import scandir
def generate_meta_info_div2k():
"""Generate meta info for DIV2K dataset.
"""
gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'
img_list = sorted(list(scandir(gt_folder)))
with open(meta_info_txt, 'w') as f:
for idx, img_path in enumerate(img_list):
img = Image.open(osp.join(gt_folder, img_path)) # lazy load
width, height = img.size
mode = img.mode
if mode == 'RGB':
n_channel = 3
elif mode == 'L':
n_channel = 1
else:
raise ValueError(f'Unsupported mode {mode}.')
info = f'{img_path} ({height},{width},{n_channel})'
print(idx + 1, info)
f.write(f'{info}\n')
if __name__ == '__main__':
generate_meta_info_div2k()
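# Hedged sketch (a hypothetical helper, not part of basicsr): reading back one
# line of the meta info written above, e.g. '0001_s001.png (480,480,3)'.
def _parse_meta_info_line(line):
    path, shape = line.strip().split(' ')
    h, w, c = (int(v) for v in shape.strip('()').split(','))
    return path, (h, w, c)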
import cv2
import os
from tqdm import tqdm
class Mosaic16x:
"""
Mosaic16x: A customized image augmentor for 16-pixel mosaic
By default it replaces each pixel value with the mean value
of its 16x16 neighborhood
"""
def augment_image(self, x):
h, w = x.shape[:2]
x = x.astype('float') # avoid overflow for uint8
irange, jrange = (h + 15) // 16, (w + 15) // 16
for i in range(irange):
for j in range(jrange):
mean = x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16].mean(axis=(0, 1))
x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16] = mean
return x.astype('uint8')
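# Hedged usage sketch for Mosaic16x (the array shape below is an arbitrary
# assumption): after augmentation, every 16x16 tile holds one mean color.
def _demo_mosaic16x():
    import numpy as np
    img = np.random.randint(0, 256, size=(48, 40, 3), dtype=np.uint8)
    out = Mosaic16x().augment_image(img)
    assert out.shape == img.shape
    tile = out[:16, :16]
    assert (tile == tile[0, 0]).all()  # the first tile is constant per channel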
class DegradationSimulator:
"""
Generating training/testing data pairs on the fly.
The degradation script is aligned with HiFaceGAN paper settings.
Args:
opt(str | op): Config for degradation script, with degradation type and parameters
Custom degradation is possible by passing an inherited class from ia.augmentors
"""
    def __init__(self):
import imgaug.augmenters as ia
self.default_deg_templates = {
'sr4x':
ia.Sequential([
# It's almost like a 4x bicubic downsampling
ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
ia.Resize({
'height': 512,
'width': 512
}, cv2.INTER_CUBIC),
]),
'sr4x8x':
ia.Sequential([
ia.Resize((0.125, 0.25), cv2.INTER_AREA),
ia.Resize({
'height': 512,
'width': 512
}, cv2.INTER_CUBIC),
]),
'denoise':
ia.OneOf([
ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
]),
'deblur':
ia.OneOf([
ia.MotionBlur(k=(10, 20)),
ia.GaussianBlur((3.0, 8.0)),
]),
'jpeg':
ia.JpegCompression(compression=(50, 85)),
'16x':
Mosaic16x(),
}
rand_deg_list = [
self.default_deg_templates['deblur'],
self.default_deg_templates['denoise'],
self.default_deg_templates['jpeg'],
self.default_deg_templates['sr4x8x'],
]
self.default_deg_templates['face_renov'] = ia.Sequential(rand_deg_list, random_order=True)
    def create_training_dataset(self, deg, gt_folder, lq_folder=None):
        """Create a degradation simulator and apply it to GT images on the fly.

        The degraded results are saved in lq_folder (if None, the name is
        derived from gt_folder).
        """
        from imgaug.augmenters.meta import Augmenter  # base class of all imgaug augmenters
if not lq_folder:
suffix = deg if isinstance(deg, str) else 'custom'
lq_folder = '_'.join([gt_folder.replace('gt', 'lq'), suffix])
print(lq_folder)
os.makedirs(lq_folder, exist_ok=True)
        if isinstance(deg, str):
            assert deg in self.default_deg_templates, (
                f'Degradation type {deg} not recognized. '
                f'Available: {"|".join(self.default_deg_templates.keys())}')
            deg = self.default_deg_templates[deg]
        else:
            assert isinstance(deg, Augmenter), f'deg must be either str or Augmenter, got {type(deg)}'
names = os.listdir(gt_folder)
for name in tqdm(names):
gt = cv2.imread(os.path.join(gt_folder, name))
lq = deg.augment_image(gt)
# pack = np.concatenate([lq, gt], axis=0)
cv2.imwrite(os.path.join(lq_folder, name), lq)
print('Dataset prepared.')
if __name__ == '__main__':
    simulator = DegradationSimulator()
    gt_folder = 'datasets/FFHQ_512_gt'
    deg = 'sr4x'
    simulator.create_training_dataset(deg, gt_folder)
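# Hedged sketch (not part of the original script): create_training_dataset()
# also accepts a custom imgaug Augmenter in place of a template name. The
# pipeline below is an arbitrary illustration, not a HiFaceGAN setting.
def _demo_custom_degradation():
    import imgaug.augmenters as ia
    custom_deg = ia.Sequential([
        ia.GaussianBlur((1.0, 3.0)),
        ia.JpegCompression(compression=(30, 60)),
    ])
    DegradationSimulator().create_training_dataset(custom_deg, 'datasets/FFHQ_512_gt')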
import glob
import os
def regroup_reds_dataset(train_path, val_path):
"""Regroup original REDS datasets.
We merge train and validation data into one folder, and separate the
validation clips in reds_dataset.py.
There are 240 training clips (starting from 0 to 239),
so we name the validation clip index starting from 240 to 269 (total 30
validation clips).
Args:
train_path (str): Path to the train folder.
val_path (str): Path to the validation folder.
"""
# move the validation data to the train folder
val_folders = glob.glob(os.path.join(val_path, '*'))
for folder in val_folders:
new_folder_idx = int(folder.split('/')[-1]) + 240
os.system(f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}')
if __name__ == '__main__':
# train_sharp
train_path = 'datasets/REDS/train_sharp'
val_path = 'datasets/REDS/val_sharp'
regroup_reds_dataset(train_path, val_path)
# train_sharp_bicubic
train_path = 'datasets/REDS/train_sharp_bicubic/X4'
val_path = 'datasets/REDS/val_sharp_bicubic/X4'
regroup_reds_dataset(train_path, val_path)
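# Hedged alternative (not the original code) to the os.system('cp -r ...') call
# above, using shutil so the copy also works where 'cp' is unavailable.
def _copy_val_clip(src_folder, train_path, offset=240):
    import os
    import shutil
    new_idx = int(os.path.basename(src_folder.rstrip('/'))) + offset
    shutil.copytree(src_folder, os.path.join(train_path, str(new_idx)))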
#!/usr/bin/env bash
GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}
# usage
if [ $# -ne 2 ]; then
echo "usage:"
echo "./scripts/dist_test.sh [number of gpu] [path to option file]"
exit
fi
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
basicsr/test.py -opt $CONFIG --launcher pytorch
#!/usr/bin/env bash
GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}
# usage
if [ $# -lt 2 ]; then
echo "usage:"
echo "./scripts/dist_train.sh [number of gpu] [path to option file]"
exit
fi
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
basicsr/train.py -opt $CONFIG --launcher pytorch ${@:3}
import argparse
from basicsr.utils.download_util import download_file_from_google_drive
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--id', type=str, help='File id')
parser.add_argument('--output', type=str, help='Save path')
args = parser.parse_args()
    download_file_from_google_drive(args.id, args.output)
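# Usage sketch (the script path is an assumption):
#   python scripts/download_gdrive.py --id <google-drive-file-id> --output <save-path>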
import argparse
import os
from os import path as osp
from basicsr.utils.download_util import download_file_from_google_drive
def download_pretrained_models(method, file_ids):
save_path_root = f'./experiments/pretrained_models/{method}'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
            if user_response.lower() == 'y':
                print(f'Overwriting {file_name} at {save_path}')
download_file_from_google_drive(file_id, save_path)
elif user_response.lower() == 'n':
print(f'Skipping {file_name}')
else:
raise ValueError('Wrong input. Only accepts Y/N.')
else:
print(f'Downloading {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'method',
type=str,
help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
"Set to 'all' to download all the models."))
args = parser.parse_args()
file_ids = {
'ESRGAN': {
'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name
'1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id
'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
},
'EDVR': {
'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy',
'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny',
'EDVR_M_x4_SR_REDS_official-32075921.pth': '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
'EDVR_L_deblur_REDS_official-ca46bd8c.pth': '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
},
'StyleGAN': {
'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth': '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g',
'stylegan2_ffhq_config_f_1024_discriminator_official-a386354a.pth': '1nPqCxm8TkDU3IvXdHCzPUxlBwR5Pd78G',
'stylegan2_cat_config_f_256_official-0a9173ad.pth': '1gfJkX6XO5pJ2J8LyMdvUgGldz7xwWpBJ',
'stylegan2_cat_config_f_256_discriminator_official-2c97fd08.pth': '1hy5FEQQl28XvfqpiWvSBd8YnIzsyDRb7',
'stylegan2_church_config_f_256_official-44ba63bf.pth': '1FCQMZXeOKZyl-xYKbl1Y_x2--rFl-1N_',
'stylegan2_church_config_f_256_discriminator_official-20cd675b.pth': # noqa: E501
'1BS9ODHkUkhfTGFVfR6alCMGtr9nGm9ox',
'stylegan2_car_config_f_512_official-e8fcab4f.pth': '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva',
'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth': '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2',
'stylegan2_horse_config_f_256_official-26d57fee.pth': '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx',
'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4'
},
'EDSR': {
'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5',
'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
},
'DUF': {
'DUF_x2_16L_official-39537cb9.pth': '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
'DUF_x3_16L_official-34ce53ec.pth': '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
'DUF_x4_16L_official-bf8f0cfa.pth': '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
'DUF_x4_28L_official-cbada450.pth': '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
},
'TOF': {
'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'
},
'DFDNet': {
'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
},
'dlib': {
'mmod_human_face_detector-4cb19393.dat': '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
'shape_predictor_5_face_landmarks-c4b1e980.dat': '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
},
'flownet': {
'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'
},
'BasicVSR': {
'BasicVSR_REDS4-543c8261.pth': '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p',
'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth': '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW',
'BasicVSR_Vimeo90K_BIx4-2a29695a.pth': '1ykIu1jv5wo95Kca2TjlieJFxeV4VVfHP',
'EDVR_REDS_pretrained_for_IconVSR-f62a2f1e.pth': '1ShfwddugTmT3_kB8VL6KpCMrIpEO5sBi',
'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth': '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt',
'IconVSR_REDS-aaa5367f.pth': '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz',
'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth': '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH',
'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g'
}
}
if args.method == 'all':
for method in file_ids.keys():
download_pretrained_models(method, file_ids[method])
else:
download_pretrained_models(args.method, file_ids[args.method])
function [im_h] = backprojection(im_h, im_l, maxIter)
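% Descriptive note (not part of the original code): this is iterative
% back-projection. Each iteration downsamples the current HR estimate to the
% LR grid, forms the residual against im_l, upsamples the residual, and adds
% it back after smoothing with the squared, normalized Gaussian kernel p:
%   im_h <- im_h + conv2(upsample(im_l - downsample(im_h)), p)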
[row_l, col_l,~] = size(im_l);
[row_h, col_h,~] = size(im_h);
p = fspecial('gaussian', 5, 1);
p = p.^2;
p = p./sum(p(:));
im_l = double(im_l);
im_h = double(im_h);
for ii = 1:maxIter
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
im_diff = im_l - im_l_s;
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
end
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20bp';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
%tic
im_out = backprojection(im_out, im_LR, max_iter);
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20if';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
J = imresize(im_LR,4,'bicubic');
%tic
for m = 1:max_iter
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
end
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end
function generate_LR_Vimeo90K()
%% matlab code to generate bicubic-downsampled LR images for the Vimeo90K dataset
up_scale = 4;
mod_scale = 4;
idx = 0;
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
for i = 1 : length(filepaths)
[~,imname,ext] = fileparts(filepaths(i).name);
folder_path = filepaths(i).folder;
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
if ~exist(save_LR_folder, 'dir')
mkdir(save_LR_folder);
end
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_result = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_result);
% read image
img = imread(fullfile(folder_path, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end
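% Example (descriptive note, not part of the original code): modcrop(img, 4)
% trims a 1023x767 image to 1020x764 so both sides are multiples of 4.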
function generate_bicubic_img()
%% matlab code to generate mod images, bicubic-downsampled images and
%% bicubic-upsampled images
%% set configurations
% comment the unnecessary lines
input_folder = '../../datasets/Set5/original';
save_mod_folder = '../../datasets/Set5/GTmod12';
save_lr_folder = '../../datasets/Set5/LRbicx2';
% save_bic_folder = '';
mod_scale = 12;
up_scale = 2;
if exist('save_mod_folder', 'var')
if exist(save_mod_folder, 'dir')
        disp(['It will overwrite ', save_mod_folder]);
else
mkdir(save_mod_folder);
end
end
if exist('save_lr_folder', 'var')
if exist(save_lr_folder, 'dir')
        disp(['It will overwrite ', save_lr_folder]);
else
mkdir(save_lr_folder);
end
end
if exist('save_bic_folder', 'var')
if exist(save_bic_folder, 'dir')
        disp(['It will overwrite ', save_bic_folder]);
else
mkdir(save_bic_folder);
end
end
idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
    [~, img_name, ext] = fileparts(filepaths(i).name);
if isempty(img_name)
disp('Ignore . folder.');
elseif strcmp(img_name, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_result = sprintf('%d\t%s.\n', idx, img_name);
fprintf(str_result);
% read image
img = imread(fullfile(input_folder, [img_name, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
if exist('save_mod_folder', 'var')
imwrite(img, fullfile(save_mod_folder, [img_name, '.png']));
end
% LR
im_lr = imresize(img, 1/up_scale, 'bicubic');
if exist('save_lr_folder', 'var')
imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png']));
end
% Bicubic
if exist('save_bic_folder', 'var')
im_bicubic = imresize(im_lr, up_scale, 'bicubic');
imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_fid_folder():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='Path to the folder.')
parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'SingleImageDataset'
opt['type'] = 'SingleImageDataset'
opt['dataroot_lq'] = args.folder
opt['io_backend'] = dict(type=args.backend)
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
sampler=None,
drop_last=False)
args.num_sample = min(args.num_sample, len(dataset))
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['lq']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
# load the dataset stats
stats = torch.load(args.fid_stats)
real_mean = stats['mean']
real_cov = stats['cov']
# calculate FID metric
fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
print('fid:', fid)
if __name__ == '__main__':
calculate_fid_folder()
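# Hedged sketch of the Frechet distance that calculate_fid() evaluates between
# Gaussians (mu1, cov1) and (mu2, cov2); basicsr's implementation may differ in
# numerical details, so treat this as an assumption:
#   FID = ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 * (cov1 @ cov2)^(1/2))
def _fid_from_stats(mu1, cov1, mu2, cov2):
    import numpy as np
    from scipy import linalg
    covmean, _ = linalg.sqrtm(cov1.dot(cov2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(cov1 + cov2 - 2 * covmean)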
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import extract_inception_features, load_patched_inception_v3
def calculate_stats_from_dataset():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--size', type=int, default=512)
parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'FFHQ'
opt['type'] = 'FFHQDataset'
opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
opt['io_backend'] = dict(type='lmdb')
opt['use_hflip'] = False
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False)
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['gt']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
save_path = f'inception_{opt["name"]}_{args.size}.pth'
torch.save(
dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
if __name__ == '__main__':
calculate_stats_from_dataset()
import cv2
import glob
import numpy as np
import os.path as osp
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor
try:
import lpips
except ImportError:
print('Please install lpips: pip install lpips')
def main():
# Configurations
# -------------------------------------------------------------------------
folder_gt = 'datasets/celeba/celeba_512_validation'
folder_restored = 'datasets/celeba/celeba_512_validation_lq'
# crop_border = 4
suffix = ''
# -------------------------------------------------------------------------
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1]
lpips_all = []
img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
for i, img_path in enumerate(img_list):
basename, ext = osp.splitext(osp.basename(img_path))
img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(
np.float32) / 255.
img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
# norm to [-1, 1]
normalize(img_gt, mean, std, inplace=True)
normalize(img_restored, mean, std, inplace=True)
# calculate lpips
        lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda()).item()
        print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
lpips_all.append(lpips_val)
print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
if __name__ == '__main__':
main()
import argparse
import cv2
import os
import warnings
from basicsr.metrics import calculate_niqe
from basicsr.utils import scandir
def main(args):
niqe_all = []
img_list = sorted(scandir(args.input, recursive=True, full_path=True))
for i, img_path in enumerate(img_list):
basename, _ = os.path.splitext(os.path.basename(img_path))
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
niqe_all.append(niqe_score)
print(args.input)
print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path')
parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
args = parser.parse_args()
main(args)
import argparse
import cv2
import numpy as np
from os import path as osp
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.utils import bgr2ycbcr, scandir
def main(args):
"""Calculate PSNR and SSIM for images.
"""
psnr_all = []
ssim_all = []
img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True)))
if args.test_y_channel:
print('Testing Y channel.')
else:
print('Testing RGB channels.')
for i, img_path in enumerate(img_list_gt):
basename, ext = osp.splitext(osp.basename(img_path))
img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
if args.suffix == '':
img_path_restored = img_list_restored[i]
else:
img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
if args.correct_mean_var:
mean_l = []
std_l = []
for j in range(3):
mean_l.append(np.mean(img_gt[:, :, j]))
std_l.append(np.std(img_gt[:, :, j]))
for j in range(3):
# correct twice
mean = np.mean(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
std = np.std(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
mean = np.mean(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
std = np.std(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
img_gt = bgr2ycbcr(img_gt, y_only=True)
img_restored = bgr2ycbcr(img_restored, y_only=True)
# calculate PSNR and SSIM
psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
psnr_all.append(psnr)
ssim_all.append(ssim)
print(args.gt)
print(args.restored)
print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)')
parser.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images')
parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
parser.add_argument(
'--test_y_channel',
action='store_true',
help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
args = parser.parse_args()
main(args)
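# Hedged reference (an assumption, not basicsr's exact code) for the PSNR used
# above on [0, 255] inputs: PSNR = 20 * log10(255 / sqrt(MSE)). basicsr's
# calculate_psnr additionally handles crop_border and input_order.
def _psnr(img1, img2):
    import numpy as np
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64))**2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))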
import argparse
import math
import numpy as np
import torch
from torch import nn
from basicsr.archs.stylegan2_arch import StyleGAN2Generator
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_stylegan2_fid():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.')
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--truncation', type=float, default=1)
parser.add_argument('--truncation_mean', type=int, default=4096)
args = parser.parse_args()
# create stylegan2 model
generator = StyleGAN2Generator(
out_size=args.size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=args.channel_multiplier,
resample_kernel=(1, 3, 3, 1))
generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
generator = nn.DataParallel(generator).eval().to(device)
if args.truncation < 1:
with torch.no_grad():
truncation_latent = generator.mean_latent(args.truncation_mean)
else:
truncation_latent = None
# inception model
inception = load_patched_inception_v3(device)
total_batch = math.ceil(args.num_sample / args.batch_size)
def sample_generator(total_batch):
for _ in range(total_batch):
with torch.no_grad():
latent = torch.randn(args.batch_size, 512, device=device)
samples, _ = generator([latent], truncation=args.truncation, truncation_latent=truncation_latent)
yield samples
features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
# load the dataset stats
stats = torch.load(args.fid_stats)
real_mean = stats['mean']
real_cov = stats['cov']
# calculate FID metric
fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
print('fid:', fid)
if __name__ == '__main__':
calculate_stylegan2_fid()
import torch
from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.archs.vgg_arch import NAMES
def convert_net(ori_net, crt_net):
for crt_k, _ in crt_net.items():
# vgg feature extractor
if 'vgg_extractor' in crt_k:
ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
if 'mean' in crt_k:
ori_k = ori_k.replace('mean', 'RGB_mean')
elif 'std' in crt_k:
ori_k = ori_k.replace('std', 'RGB_std')
else:
idx = NAMES['vgg19'].index(crt_k.split('.')[2])
if 'weight' in crt_k:
ori_k = f'VggExtract.model.features.{idx}.weight'
else:
ori_k = f'VggExtract.model.features.{idx}.bias'
elif 'attn_blocks' in crt_k:
if 'left_eye' in crt_k:
ori_k = crt_k.replace('attn_blocks.left_eye', 'le')
elif 'right_eye' in crt_k:
ori_k = crt_k.replace('attn_blocks.right_eye', 're')
elif 'mouth' in crt_k:
ori_k = crt_k.replace('attn_blocks.mouth', 'mo')
elif 'nose' in crt_k:
ori_k = crt_k.replace('attn_blocks.nose', 'no')
else:
                raise ValueError(f'Unrecognized attn_blocks key: {crt_k}')
elif 'multi_scale_dilation' in crt_k:
if 'conv_blocks' in crt_k:
_, _, c, d, e = crt_k.split('.')
ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}'
else:
ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')
elif crt_k.startswith('upsample'):
ori_k = crt_k.replace('upsample', 'up')
if 'scale_block' in crt_k:
ori_k = ori_k.replace('scale_block', 'ScaleModel1')
elif 'shift_block' in crt_k:
ori_k = ori_k.replace('shift_block', 'ShiftModel1')
elif 'upsample4' in crt_k and 'body' in crt_k:
ori_k = ori_k.replace('body', 'Model')
else:
            print('unprocessed key: ', crt_k)
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
raise ValueError('Wrong tensor size: \n'
f'crt_net: {crt_net[crt_k].size()}\n'
f'ori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
return crt_net
if __name__ == '__main__':
ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth')
dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
crt_net = dfd_net.state_dict()
crt_net_params = convert_net(ori_net, crt_net)
torch.save(
dict(params=crt_net_params),
'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
_use_new_zipfile_serialization=False)