Commit e2696ece authored by mashun1's avatar mashun1
Browse files

controlnet

parents
Pipeline #643 canceled with stages
import argparse
from os import path as osp
from basicsr.utils import scandir
from basicsr.utils.lmdb_util import make_lmdb_from_imgs
def create_lmdb_for_div2k():
"""Create lmdb files for DIV2K dataset.
Usage:
Before run this script, please run `extract_subimages.py`.
Typically, there are four folders to be processed for DIV2K dataset.
* DIV2K_train_HR_sub
* DIV2K_train_LR_bicubic/X2_sub
* DIV2K_train_LR_bicubic/X3_sub
* DIV2K_train_LR_bicubic/X4_sub
Remember to modify opt configurations according to your settings.
"""
# HR images
folder_path = 'datasets/DIV2K/DIV2K_train_HR_sub'
lmdb_path = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
img_path_list, keys = prepare_keys_div2k(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
# LRx2 images
folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb'
img_path_list, keys = prepare_keys_div2k(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
# LRx3 images
folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb'
img_path_list, keys = prepare_keys_div2k(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
# LRx4 images
folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'
img_path_list, keys = prepare_keys_div2k(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
def prepare_keys_div2k(folder_path):
"""Prepare image path list and keys for DIV2K dataset.
Args:
folder_path (str): Folder path.
Returns:
list[str]: Image path list.
list[str]: Key list.
"""
print('Reading image path list ...')
img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False)))
keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)]
return img_path_list, keys
def create_lmdb_for_reds():
"""Create lmdb files for REDS dataset.
Usage:
Before run this script, please run :file:`merge_reds_train_val.py`.
We take two folders for example:
* train_sharp
* train_sharp_bicubic
Remember to modify opt configurations according to your settings.
"""
# train_sharp
folder_path = 'datasets/REDS/train_sharp'
lmdb_path = 'datasets/REDS/train_sharp_with_val.lmdb'
img_path_list, keys = prepare_keys_reds(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
# train_sharp_bicubic
folder_path = 'datasets/REDS/train_sharp_bicubic'
lmdb_path = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb'
img_path_list, keys = prepare_keys_reds(folder_path)
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
def prepare_keys_reds(folder_path):
"""Prepare image path list and keys for REDS dataset.
Args:
folder_path (str): Folder path.
Returns:
list[str]: Image path list.
list[str]: Key list.
"""
print('Reading image path list ...')
img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=True)))
keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000
return img_path_list, keys
def create_lmdb_for_vimeo90k():
"""Create lmdb files for Vimeo90K dataset.
Usage:
Remember to modify opt configurations according to your settings.
"""
# GT
folder_path = 'datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_path = 'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb'
train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'gt')
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
# LQ
folder_path = 'datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
lmdb_path = 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'lq')
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
"""Prepare image path list and keys for Vimeo90K dataset.
Args:
folder_path (str): Folder path.
train_list_path (str): Path to the official train list.
mode (str): One of 'gt' or 'lq'.
Returns:
list[str]: Image path list.
list[str]: Key list.
"""
print('Reading image path list ...')
with open(train_list_path, 'r') as fin:
train_list = [line.strip() for line in fin]
img_path_list = []
keys = []
for line in train_list:
folder, sub_folder = line.split('/')
img_path_list.extend([osp.join(folder, sub_folder, f'im{j + 1}.png') for j in range(7)])
keys.extend([f'{folder}/{sub_folder}/im{j + 1}' for j in range(7)])
if mode == 'gt':
print('Only keep the 4th frame for the gt mode.')
img_path_list = [v for v in img_path_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('/im4')]
return img_path_list, keys
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset',
type=str,
help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' You may need to modify the corresponding configurations in codes."))
args = parser.parse_args()
dataset = args.dataset.lower()
if dataset == 'div2k':
create_lmdb_for_div2k()
elif dataset == 'reds':
create_lmdb_for_reds()
elif dataset == 'vimeo90k':
create_lmdb_for_vimeo90k()
else:
raise ValueError('Wrong dataset.')
import argparse
import glob
import os
from os import path as osp
from basicsr.utils.download_util import download_file_from_google_drive
def download_dataset(dataset, file_ids):
save_path_root = './datasets/'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
if user_response.lower() == 'y':
print(f'Covering {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
elif user_response.lower() == 'n':
print(f'Skipping {file_name}')
else:
raise ValueError('Wrong input. Only accepts Y/N.')
else:
print(f'Downloading {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
# unzip
if save_path.endswith('.zip'):
extracted_path = save_path.replace('.zip', '')
print(f'Extract {save_path} to {extracted_path}')
import zipfile
with zipfile.ZipFile(save_path, 'r') as zip_ref:
zip_ref.extractall(extracted_path)
file_name = file_name.replace('.zip', '')
subfolder = osp.join(extracted_path, file_name)
if osp.isdir(subfolder):
print(f'Move {subfolder} to {extracted_path}')
import shutil
for path in glob.glob(osp.join(subfolder, '*')):
shutil.move(path, extracted_path)
shutil.rmtree(subfolder)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'dataset',
type=str,
help=("Options: 'Set5', 'Set14'. "
"Set to 'all' if you want to download all the dataset."))
args = parser.parse_args()
file_ids = {
'Set5': {
'Set5.zip': # file name
'1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id
},
'Set14': {
'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E',
}
}
if args.dataset == 'all':
for dataset in file_ids.keys():
download_dataset(dataset, file_ids[dataset])
else:
download_dataset(args.dataset, file_ids[args.dataset])
import argparse
import cv2
import glob
import numpy as np
import os
from basicsr.utils.lmdb_util import LmdbMaker
def convert_celeba_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
"""Convert CelebA tfrecords to images or lmdb files.
Args:
tf_file (str): Input tfrecords file in glob pattern.
Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords' # noqa:E501
log_resolution (int): Log scale of resolution.
save_root (str): Path root to save.
save_type (str): Save type. Options: img | lmdb. Default: img.
compress_level (int): Compress level when encoding images. Default: 1.
"""
if 'validation' in tf_file:
phase = 'validation'
else:
phase = 'train'
if save_type == 'lmdb':
save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}.lmdb')
lmdb_maker = LmdbMaker(save_path)
elif save_type == 'img':
save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}')
else:
raise ValueError('Wrong save type.')
os.makedirs(save_path, exist_ok=True)
idx = 0
for record in sorted(glob.glob(tf_file)):
print('Processing record: ', record)
record_iterator = tf.python_io.tf_record_iterator(record)
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
# label = example.features.feature['label'].int64_list.value[0]
# attr = example.features.feature['attr'].int64_list.value
# male = attr[20]
# young = attr[39]
shape = example.features.feature['shape'].int64_list.value
h, w, c = shape
img_str = example.features.feature['data'].bytes_list.value[0]
img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c))
img = img[:, :, [2, 1, 0]]
if save_type == 'img':
cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
elif save_type == 'lmdb':
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
key = f'{idx:08d}/r{log_resolution:02d}'
lmdb_maker.put(img_byte, key, (h, w, c))
idx += 1
print(idx)
if save_type == 'lmdb':
lmdb_maker.close()
def convert_ffhq_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
"""Convert FFHQ tfrecords to images or lmdb files.
Args:
tf_file (str): Input tfrecords file.
log_resolution (int): Log scale of resolution.
save_root (str): Path root to save.
save_type (str): Save type. Options: img | lmdb. Default: img.
compress_level (int): Compress level when encoding images. Default: 1.
"""
if save_type == 'lmdb':
save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb')
lmdb_maker = LmdbMaker(save_path)
elif save_type == 'img':
save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}')
else:
raise ValueError('Wrong save type.')
os.makedirs(save_path, exist_ok=True)
idx = 0
for record in sorted(glob.glob(tf_file)):
print('Processing record: ', record)
record_iterator = tf.python_io.tf_record_iterator(record)
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
shape = example.features.feature['shape'].int64_list.value
c, h, w = shape
img_str = example.features.feature['data'].bytes_list.value[0]
img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w))
img = img.transpose(1, 2, 0)
img = img[:, :, [2, 1, 0]]
if save_type == 'img':
cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
elif save_type == 'lmdb':
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
key = f'{idx:08d}/r{log_resolution:02d}'
lmdb_maker.put(img_byte, key, (h, w, c))
idx += 1
print(idx)
if save_type == 'lmdb':
lmdb_maker.close()
def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='lmdb', compress_level=1):
"""Make FFHQ lmdb from images.
Args:
folder_path (str): Folder path.
log_resolution (int): Log scale of resolution.
save_root (str): Path root to save.
save_type (str): Save type. Options: img | lmdb. Default: img.
compress_level (int): Compress level when encoding images. Default: 1.
"""
if save_type == 'lmdb':
save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}_crop1.2.lmdb')
lmdb_maker = LmdbMaker(save_path)
else:
raise ValueError('Wrong save type.')
os.makedirs(save_path, exist_ok=True)
img_list = sorted(glob.glob(os.path.join(folder_path, '*')))
for idx, img_path in enumerate(img_list):
print(f'Processing {idx}: ', img_path)
img = cv2.imread(img_path)
h, w, c = img.shape
if save_type == 'lmdb':
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
key = f'{idx:08d}/r{log_resolution:02d}'
lmdb_maker.put(img_byte, key, (h, w, c))
if save_type == 'lmdb':
lmdb_maker.close()
if __name__ == '__main__':
"""Read tfrecords w/o define a graph.
We have tested it on TensorFlow 1.15
References: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset', type=str, default='ffhq', help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.")
parser.add_argument(
'--tf_file',
type=str,
default='datasets/ffhq/ffhq-r10.tfrecords',
help=(
'Input tfrecords file. For celeba, it should be glob pattern. '
'Put quotes around the wildcard argument to prevent the shell '
'from expanding it.'
"Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'" # noqa:E501
))
parser.add_argument('--log_resolution', type=int, default=10, help='Log scale of resolution.')
parser.add_argument('--save_root', type=str, default='datasets/ffhq/', help='Save root path.')
parser.add_argument(
'--save_type', type=str, default='img', help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.")
parser.add_argument(
'--compress_level', type=int, default=1, help='Compress level when encoding images. Default: 1.')
args = parser.parse_args()
try:
import tensorflow as tf
except Exception:
raise ImportError('You need to install tensorflow to read tfrecords.')
if args.dataset == 'ffhq':
convert_ffhq_tfrecords(
args.tf_file,
args.log_resolution,
args.save_root,
save_type=args.save_type,
compress_level=args.compress_level)
else:
convert_celeba_tfrecords(
args.tf_file,
args.log_resolution,
args.save_root,
save_type=args.save_type,
compress_level=args.compress_level)
import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
from basicsr.utils import scandir
def main():
"""A multi-thread tool to crop large images to sub-images for faster IO.
It is used for DIV2K dataset.
Args:
opt (dict): Configuration dict. It contains:
n_thread (int): Thread number.
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and
longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
Usage:
For each folder, run this script.
Typically, there are four folders to be processed for DIV2K dataset.
* DIV2K_train_HR
* DIV2K_train_LR_bicubic/X2
* DIV2K_train_LR_bicubic/X3
* DIV2K_train_LR_bicubic/X4
After process, each sub_folder should have the same number of subimages.
Remember to modify opt configurations according to your settings.
"""
opt = {}
opt['n_thread'] = 20
opt['compression_level'] = 3
# HR images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_HR'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
opt['crop_size'] = 480
opt['step'] = 240
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx2 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
opt['crop_size'] = 240
opt['step'] = 120
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx3 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
opt['crop_size'] = 160
opt['step'] = 80
opt['thresh_size'] = 0
extract_subimages(opt)
# LRx4 images
opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
opt['crop_size'] = 120
opt['step'] = 60
opt['thresh_size'] = 0
extract_subimages(opt)
def extract_subimages(opt):
"""Crop images to subimages.
Args:
opt (dict): Configuration dict. It contains:
input_folder (str): Path to the input folder.
save_folder (str): Path to save folder.
n_thread (int): Thread number.
"""
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print(f'mkdir {save_folder} ...')
else:
print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1)
img_list = list(scandir(input_folder, full_path=True))
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
pbar.close()
print('All processes done.')
def worker(path, opt):
"""Worker for each process.
Args:
path (str): Image path.
opt (dict): Configuration dict. It contains:
crop_size (int): Crop size.
step (int): Step for overlapped sliding window.
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
save_folder (str): Path to save folder.
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
Returns:
process_info (str): Process information displayed in progress bar.
"""
crop_size = opt['crop_size']
step = opt['step']
thresh_size = opt['thresh_size']
img_name, extension = osp.splitext(osp.basename(path))
# remove the x2, x3, x4 and x8 in the filename for DIV2K
img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
h, w = img.shape[0:2]
h_space = np.arange(0, h - crop_size + 1, step)
if h - (h_space[-1] + crop_size) > thresh_size:
h_space = np.append(h_space, h - crop_size)
w_space = np.arange(0, w - crop_size + 1, step)
if w - (w_space[-1] + crop_size) > thresh_size:
w_space = np.append(w_space, w - crop_size)
index = 0
for x in h_space:
for y in w_space:
index += 1
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
cropped_img = np.ascontiguousarray(cropped_img)
cv2.imwrite(
osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
process_info = f'Processing {img_name} ...'
return process_info
if __name__ == '__main__':
main()
from os import path as osp
from PIL import Image
from basicsr.utils import scandir
def generate_meta_info_div2k():
"""Generate meta info for DIV2K dataset.
"""
gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'
img_list = sorted(list(scandir(gt_folder)))
with open(meta_info_txt, 'w') as f:
for idx, img_path in enumerate(img_list):
img = Image.open(osp.join(gt_folder, img_path)) # lazy load
width, height = img.size
mode = img.mode
if mode == 'RGB':
n_channel = 3
elif mode == 'L':
n_channel = 1
else:
raise ValueError(f'Unsupported mode {mode}.')
info = f'{img_path} ({height},{width},{n_channel})'
print(idx + 1, info)
f.write(f'{info}\n')
if __name__ == '__main__':
generate_meta_info_div2k()
import cv2
import os
from tqdm import tqdm
class Mosaic16x:
"""
Mosaic16x: A customized image augmentor for 16-pixel mosaic
By default it replaces each pixel value with the mean value
of its 16x16 neighborhood
"""
def augment_image(self, x):
h, w = x.shape[:2]
x = x.astype('float') # avoid overflow for uint8
irange, jrange = (h + 15) // 16, (w + 15) // 16
for i in range(irange):
for j in range(jrange):
mean = x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16].mean(axis=(0, 1))
x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16] = mean
return x.astype('uint8')
class DegradationSimulator:
"""
Generating training/testing data pairs on the fly.
The degradation script is aligned with HiFaceGAN paper settings.
Args:
opt(str | op): Config for degradation script, with degradation type and parameters
Custom degradation is possible by passing an inherited class from ia.augmentors
"""
def __init__(self, ):
import imgaug.augmenters as ia
self.default_deg_templates = {
'sr4x':
ia.Sequential([
# It's almost like a 4x bicubic downsampling
ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
ia.Resize({
'height': 512,
'width': 512
}, cv2.INTER_CUBIC),
]),
'sr4x8x':
ia.Sequential([
ia.Resize((0.125, 0.25), cv2.INTER_AREA),
ia.Resize({
'height': 512,
'width': 512
}, cv2.INTER_CUBIC),
]),
'denoise':
ia.OneOf([
ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
]),
'deblur':
ia.OneOf([
ia.MotionBlur(k=(10, 20)),
ia.GaussianBlur((3.0, 8.0)),
]),
'jpeg':
ia.JpegCompression(compression=(50, 85)),
'16x':
Mosaic16x(),
}
rand_deg_list = [
self.default_deg_templates['deblur'],
self.default_deg_templates['denoise'],
self.default_deg_templates['jpeg'],
self.default_deg_templates['sr4x8x'],
]
self.default_deg_templates['face_renov'] = ia.Sequential(rand_deg_list, random_order=True)
def create_training_dataset(self, deg, gt_folder, lq_folder=None):
from imgaug.augmenters.meta import Augmenter # baseclass
"""
Create a degradation simulator and apply it to GT images on the fly
Save the degraded result in the lq_folder (if None, name it as GT_deg)
"""
if not lq_folder:
suffix = deg if isinstance(deg, str) else 'custom'
lq_folder = '_'.join([gt_folder.replace('gt', 'lq'), suffix])
print(lq_folder)
os.makedirs(lq_folder, exist_ok=True)
if isinstance(deg, str):
assert deg in self.default_deg_templates, (
f'Degration type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}')
deg = self.default_deg_templates[deg]
else:
assert isinstance(deg, Augmenter), f'Deg must be either str|Augmenter, got {deg}'
names = os.listdir(gt_folder)
for name in tqdm(names):
gt = cv2.imread(os.path.join(gt_folder, name))
lq = deg.augment_image(gt)
# pack = np.concatenate([lq, gt], axis=0)
cv2.imwrite(os.path.join(lq_folder, name), lq)
print('Dataset prepared.')
if __name__ == '__main__':
simuator = DegradationSimulator()
gt_folder = 'datasets/FFHQ_512_gt'
deg = 'sr4x'
simuator.create_training_dataset(deg, gt_folder)
import glob
import os
def regroup_reds_dataset(train_path, val_path):
"""Regroup original REDS datasets.
We merge train and validation data into one folder, and separate the
validation clips in reds_dataset.py.
There are 240 training clips (starting from 0 to 239),
so we name the validation clip index starting from 240 to 269 (total 30
validation clips).
Args:
train_path (str): Path to the train folder.
val_path (str): Path to the validation folder.
"""
# move the validation data to the train folder
val_folders = glob.glob(os.path.join(val_path, '*'))
for folder in val_folders:
new_folder_idx = int(folder.split('/')[-1]) + 240
os.system(f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}')
if __name__ == '__main__':
# train_sharp
train_path = 'datasets/REDS/train_sharp'
val_path = 'datasets/REDS/val_sharp'
regroup_reds_dataset(train_path, val_path)
# train_sharp_bicubic
train_path = 'datasets/REDS/train_sharp_bicubic/X4'
val_path = 'datasets/REDS/val_sharp_bicubic/X4'
regroup_reds_dataset(train_path, val_path)
#!/usr/bin/env bash
GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}
# usage
if [ $# -ne 2 ] ;then
echo "usage:"
echo "./scripts/dist_test.sh [number of gpu] [path to option file]"
exit
fi
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
basicsr/test.py -opt $CONFIG --launcher pytorch
#!/usr/bin/env bash
GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}
# usage
if [ $# -lt 2 ] ;then
echo "usage:"
echo "./scripts/dist_train.sh [number of gpu] [path to option file]"
exit
fi
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
basicsr/train.py -opt $CONFIG --launcher pytorch ${@:3}
import argparse
from basicsr.utils.download_util import download_file_from_google_drive
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--id', type=str, help='File id')
parser.add_argument('--output', type=str, help='Save path')
args = parser.parse_args()
download_file_from_google_drive(args.id, args.save_path)
import argparse
import os
from os import path as osp
from basicsr.utils.download_util import download_file_from_google_drive
def download_pretrained_models(method, file_ids):
save_path_root = f'./experiments/pretrained_models/{method}'
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
if user_response.lower() == 'y':
print(f'Covering {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
elif user_response.lower() == 'n':
print(f'Skipping {file_name}')
else:
raise ValueError('Wrong input. Only accepts Y/N.')
else:
print(f'Downloading {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'method',
type=str,
help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
"Set to 'all' to download all the models."))
args = parser.parse_args()
file_ids = {
'ESRGAN': {
'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name
'1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id
'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
},
'EDVR': {
'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy',
'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny',
'EDVR_M_x4_SR_REDS_official-32075921.pth': '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
'EDVR_L_deblur_REDS_official-ca46bd8c.pth': '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
},
'StyleGAN': {
'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth': '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g',
'stylegan2_ffhq_config_f_1024_discriminator_official-a386354a.pth': '1nPqCxm8TkDU3IvXdHCzPUxlBwR5Pd78G',
'stylegan2_cat_config_f_256_official-0a9173ad.pth': '1gfJkX6XO5pJ2J8LyMdvUgGldz7xwWpBJ',
'stylegan2_cat_config_f_256_discriminator_official-2c97fd08.pth': '1hy5FEQQl28XvfqpiWvSBd8YnIzsyDRb7',
'stylegan2_church_config_f_256_official-44ba63bf.pth': '1FCQMZXeOKZyl-xYKbl1Y_x2--rFl-1N_',
'stylegan2_church_config_f_256_discriminator_official-20cd675b.pth': # noqa: E501
'1BS9ODHkUkhfTGFVfR6alCMGtr9nGm9ox',
'stylegan2_car_config_f_512_official-e8fcab4f.pth': '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva',
'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth': '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2',
'stylegan2_horse_config_f_256_official-26d57fee.pth': '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx',
'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4'
},
'EDSR': {
'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5',
'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
},
'DUF': {
'DUF_x2_16L_official-39537cb9.pth': '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
'DUF_x3_16L_official-34ce53ec.pth': '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
'DUF_x4_16L_official-bf8f0cfa.pth': '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
'DUF_x4_28L_official-cbada450.pth': '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
},
'TOF': {
'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'
},
'DFDNet': {
'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
},
'dlib': {
'mmod_human_face_detector-4cb19393.dat': '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
'shape_predictor_5_face_landmarks-c4b1e980.dat': '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
},
'flownet': {
'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'
},
'BasicVSR': {
'BasicVSR_REDS4-543c8261.pth': '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p',
'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth': '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW',
'BasicVSR_Vimeo90K_BIx4-2a29695a.pth': '1ykIu1jv5wo95Kca2TjlieJFxeV4VVfHP',
'EDVR_REDS_pretrained_for_IconVSR-f62a2f1e.pth': '1ShfwddugTmT3_kB8VL6KpCMrIpEO5sBi',
'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth': '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt',
'IconVSR_REDS-aaa5367f.pth': '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz',
'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth': '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH',
'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g'
}
}
if args.method == 'all':
for method in file_ids.keys():
download_pretrained_models(method, file_ids[method])
else:
download_pretrained_models(args.method, file_ids[args.method])
function [im_h] = backprojection(im_h, im_l, maxIter)
[row_l, col_l,~] = size(im_l);
[row_h, col_h,~] = size(im_h);
p = fspecial('gaussian', 5, 1);
p = p.^2;
p = p./sum(p(:));
im_l = double(im_l);
im_h = double(im_h);
for ii = 1:maxIter
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
im_diff = im_l - im_l_s;
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
end
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20bp';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
%tic
im_out = backprojection(im_out, im_LR, max_iter);
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end
clear; close all; clc;
LR_folder = './LR'; % LR
preout_folder = './results'; % pre output
save_folder = './results_20if';
filepaths = dir(fullfile(preout_folder, '*.png'));
max_iter = 20;
if ~ exist(save_folder, 'dir')
mkdir(save_folder);
end
for idx_im = 1:length(filepaths)
fprintf([num2str(idx_im) '\n']);
im_name = filepaths(idx_im).name;
im_LR = im2double(imread(fullfile(LR_folder, im_name)));
im_out = im2double(imread(fullfile(preout_folder, im_name)));
J = imresize(im_LR,4,'bicubic');
%tic
for m = 1:max_iter
im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
end
%toc
imwrite(im_out, fullfile(save_folder, im_name));
end
function generate_LR_Vimeo90K()
%% matlab code to genetate bicubic-downsampled for Vimeo90K dataset
up_scale = 4;
mod_scale = 4;
idx = 0;
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
for i = 1 : length(filepaths)
[~,imname,ext] = fileparts(filepaths(i).name);
folder_path = filepaths(i).folder;
save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
if ~exist(save_LR_folder, 'dir')
mkdir(save_LR_folder);
end
if isempty(imname)
disp('Ignore . folder.');
elseif strcmp(imname, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_result = sprintf('%d\t%s.\n', idx, imname);
fprintf(str_result);
% read image
img = imread(fullfile(folder_path, [imname, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
% LR
im_LR = imresize(img, 1/up_scale, 'bicubic');
if exist('save_LR_folder', 'var')
imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end
function generate_bicubic_img()
%% matlab code to genetate mod images, bicubic-downsampled images and
%% bicubic_upsampled images
%% set configurations
% comment the unnecessary lines
input_folder = '../../datasets/Set5/original';
save_mod_folder = '../../datasets/Set5/GTmod12';
save_lr_folder = '../../datasets/Set5/LRbicx2';
% save_bic_folder = '';
mod_scale = 12;
up_scale = 2;
if exist('save_mod_folder', 'var')
if exist(save_mod_folder, 'dir')
disp(['It will cover ', save_mod_folder]);
else
mkdir(save_mod_folder);
end
end
if exist('save_lr_folder', 'var')
if exist(save_lr_folder, 'dir')
disp(['It will cover ', save_lr_folder]);
else
mkdir(save_lr_folder);
end
end
if exist('save_bic_folder', 'var')
if exist(save_bic_folder, 'dir')
disp(['It will cover ', save_bic_folder]);
else
mkdir(save_bic_folder);
end
end
idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
[paths, img_name, ext] = fileparts(filepaths(i).name);
if isempty(img_name)
disp('Ignore . folder.');
elseif strcmp(img_name, '.')
disp('Ignore .. folder.');
else
idx = idx + 1;
str_result = sprintf('%d\t%s.\n', idx, img_name);
fprintf(str_result);
% read image
img = imread(fullfile(input_folder, [img_name, ext]));
img = im2double(img);
% modcrop
img = modcrop(img, mod_scale);
if exist('save_mod_folder', 'var')
imwrite(img, fullfile(save_mod_folder, [img_name, '.png']));
end
% LR
im_lr = imresize(img, 1/up_scale, 'bicubic');
if exist('save_lr_folder', 'var')
imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png']));
end
% Bicubic
if exist('save_bic_folder', 'var')
im_bicubic = imresize(im_lr, up_scale, 'bicubic');
imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png']));
end
end
end
end
%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
sz = size(img);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2));
else
tmpsz = size(img);
sz = tmpsz(1:2);
sz = sz - mod(sz, modulo);
img = img(1:sz(1), 1:sz(2),:);
end
end
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_fid_folder():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='Path to the folder.')
parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'SingleImageDataset'
opt['type'] = 'SingleImageDataset'
opt['dataroot_lq'] = args.folder
opt['io_backend'] = dict(type=args.backend)
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
sampler=None,
drop_last=False)
args.num_sample = min(args.num_sample, len(dataset))
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['lq']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
# load the dataset stats
stats = torch.load(args.fid_stats)
real_mean = stats['mean']
real_cov = stats['cov']
# calculate FID metric
fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
print('fid:', fid)
if __name__ == '__main__':
calculate_fid_folder()
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import extract_inception_features, load_patched_inception_v3
def calculate_stats_from_dataset():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--size', type=int, default=512)
parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'FFHQ'
opt['type'] = 'FFHQDataset'
opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
opt['io_backend'] = dict(type='lmdb')
opt['use_hflip'] = False
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False)
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['gt']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
save_path = f'inception_{opt["name"]}_{args.size}.pth'
torch.save(
dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
if __name__ == '__main__':
calculate_stats_from_dataset()
import cv2
import glob
import numpy as np
import os.path as osp
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor
try:
import lpips
except ImportError:
print('Please install lpips: pip install lpips')
def main():
# Configurations
# -------------------------------------------------------------------------
folder_gt = 'datasets/celeba/celeba_512_validation'
folder_restored = 'datasets/celeba/celeba_512_validation_lq'
# crop_border = 4
suffix = ''
# -------------------------------------------------------------------------
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1]
lpips_all = []
img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
for i, img_path in enumerate(img_list):
basename, ext = osp.splitext(osp.basename(img_path))
img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(
np.float32) / 255.
img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
# norm to [-1, 1]
normalize(img_gt, mean, std, inplace=True)
normalize(img_restored, mean, std, inplace=True)
# calculate lpips
lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda())
print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
lpips_all.append(lpips_val)
print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
if __name__ == '__main__':
main()
import argparse
import cv2
import os
import warnings
from basicsr.metrics import calculate_niqe
from basicsr.utils import scandir
def main(args):
niqe_all = []
img_list = sorted(scandir(args.input, recursive=True, full_path=True))
for i, img_path in enumerate(img_list):
basename, _ = os.path.splitext(os.path.basename(img_path))
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
niqe_all.append(niqe_score)
print(args.input)
print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path')
parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
args = parser.parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment