import torch
def convert_edvr():
ori_net = torch.load('experiments/pretrained_models/EDVR_REDS_SR_M.pth')
crt_net = torch.load('xxx/net_g_8.pth')
save_path = './edvr_medium_x4_reds_sr_official.pth'
# for k, v in ori_net.items():
# print(k)
# print('*****')
# for k, v in crt_net.items():
# print(k)
for crt_k, _ in crt_net.items():
# deblur hr in
if 'predeblur.stride_conv_hr1' in crt_k:
ori_k = crt_k.replace('predeblur.stride_conv_hr1', 'pre_deblur.conv_first_2')
elif 'predeblur.stride_conv_hr2' in crt_k:
ori_k = crt_k.replace('predeblur.stride_conv_hr2', 'pre_deblur.conv_first_3')
elif 'predeblur.conv_first' in crt_k:
ori_k = crt_k.replace('predeblur.conv_first', 'pre_deblur.conv_first_1')
# predeblur module
# elif 'predeblur.conv_first' in crt_k:
# ori_k = crt_k.replace('predeblur.conv_first',
# 'pre_deblur.conv_first')
elif 'predeblur.stride_conv_l2' in crt_k:
ori_k = crt_k.replace('predeblur.stride_conv_l2', 'pre_deblur.deblur_L2_conv')
elif 'predeblur.stride_conv_l3' in crt_k:
ori_k = crt_k.replace('predeblur.stride_conv_l3', 'pre_deblur.deblur_L3_conv')
elif 'predeblur.resblock_l3' in crt_k:
ori_k = crt_k.replace('predeblur.resblock_l3', 'pre_deblur.RB_L3_1')
elif 'predeblur.resblock_l2' in crt_k:
ori_k = crt_k.replace('predeblur.resblock_l', 'pre_deblur.RB_L')
elif 'predeblur.resblock_l1' in crt_k:
a, b, c, d, e = crt_k.split('.')
ori_k = f'pre_deblur.RB_L1_{int(c)+1}.{d}.{e}'
elif 'conv_l2' in crt_k:
ori_k = crt_k.replace('conv_l2_', 'fea_L2_conv')
elif 'conv_l3' in crt_k:
ori_k = crt_k.replace('conv_l3_', 'fea_L3_conv')
elif 'pcd_align.dcn_pack' in crt_k:
idx = crt_k.split('.l')[1].split('.')[0]
name = crt_k.split('.l')[1].split('.')[1]
ori_k = f'pcd_align.L{idx}_dcnpack.{name}'
if 'conv_offset' in crt_k:
name = name.replace('conv_offset', 'conv_offset_mask')
weight_bias = crt_k.split('.l')[1].split('.')[2]
ori_k = f'pcd_align.L{idx}_dcnpack.{name}.{weight_bias}'
elif 'pcd_align.offset_conv' in crt_k:
_, b, c, d = crt_k.split('.')
idx = b.split('conv')[1]
level = c.split('l')[1]
ori_k = f'pcd_align.L{level}_offset_conv{idx}.{d}'
elif 'pcd_align.feat_conv' in crt_k:
a, b, c, d = crt_k.split('.')
level = c.split('l')[1]
ori_k = f'pcd_align.L{level}_fea_conv.{d}'
elif 'pcd_align.cas_dcnpack' in crt_k:
ori_k = crt_k.replace('conv_offset', 'conv_offset_mask')
elif ('conv_first' in crt_k or 'feature_extraction' in crt_k or 'pcd_align.cas_offset' in crt_k
or 'upconv' in crt_k or 'conv_last' in crt_k or 'conv_1x1' in crt_k):
ori_k = crt_k
elif 'temporal_attn1' in crt_k:
ori_k = crt_k.replace('fusion.temporal_attn1', 'tsa_fusion.tAtt_2')
elif 'temporal_attn2' in crt_k:
ori_k = crt_k.replace('fusion.temporal_attn2', 'tsa_fusion.tAtt_1')
elif 'fusion.feat_fusion' in crt_k:
ori_k = crt_k.replace('fusion.feat_fusion', 'tsa_fusion.fea_fusion')
elif 'fusion.spatial_attn_add' in crt_k:
ori_k = crt_k.replace('fusion.spatial_attn_add', 'tsa_fusion.sAtt_add_')
elif 'fusion.spatial_attn_l' in crt_k:
ori_k = crt_k.replace('fusion.spatial_attn_l', 'tsa_fusion.sAtt_L')
elif 'fusion.spatial_attn' in crt_k:
ori_k = crt_k.replace('fusion.spatial_attn', 'tsa_fusion.sAtt_')
elif 'reconstruction' in crt_k:
ori_k = crt_k.replace('reconstruction', 'recon_trunk')
elif 'conv_hr' in crt_k:
ori_k = crt_k.replace('conv_hr', 'HRconv')
# for the model without TSA (woTSA)
elif 'fusion' in crt_k:
ori_k = crt_k.replace('fusion', 'tsa_fusion')
else:
print('unprocessed key', crt_k)
# print(ori_k)
crt_net[crt_k] = ori_net[ori_k]
ori_k = None
torch.save(crt_net, save_path)
def convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32):
"""Convert EDSR models in https://github.com/thstkdgus35/EDSR-PyTorch.
It supports converting x2, x3 and x4 models.
Args:
ori_net_path (str): Original network path.
crt_net_path (str): Current network path.
save_path (str): The path to save the converted model.
num_block (int): Number of blocks. Default: 32.
"""
ori_net = torch.load(ori_net_path)
crt_net = torch.load(crt_net_path)
for crt_k, _ in crt_net.items():
if 'conv_first' in crt_k:
ori_k = crt_k.replace('conv_first', 'head.0')
crt_net[crt_k] = ori_net[ori_k]
elif 'conv_after_body' in crt_k:
ori_k = crt_k.replace('conv_after_body', f'body.{num_block}')
elif 'body' in crt_k:
ori_k = crt_k.replace('conv1', 'body.0').replace('conv2', 'body.2')
elif 'upsample.0' in crt_k:
ori_k = crt_k.replace('upsample.0', 'tail.0.0')
elif 'upsample.2' in crt_k:
ori_k = crt_k.replace('upsample.2', 'tail.0.2')
elif 'conv_last' in crt_k:
ori_k = crt_k.replace('conv_last', 'tail.1')
else:
print('unprocessed key', crt_k)
crt_net[crt_k] = ori_net[ori_k]
torch.save(crt_net, save_path)
def convert_rcan_model():
ori_net = torch.load('RCAN_model_best.pt')
crt_net = torch.load('experiments/201_RCANx4_scratch_DIV2K_rand0/models/net_g_5000.pth')
# for ori_k, ori_v in ori_net.items():
# print(ori_k)
for crt_k, _ in crt_net.items():
# print(crt_k)
if 'conv_first' in crt_k:
ori_k = crt_k.replace('conv_first', 'head.0')
crt_net[crt_k] = ori_net[ori_k]
elif 'conv_after_body' in crt_k:
ori_k = crt_k.replace('conv_after_body', 'body.10')
elif 'upsample.0' in crt_k:
ori_k = crt_k.replace('upsample.0', 'tail.0.0')
elif 'upsample.2' in crt_k:
ori_k = crt_k.replace('upsample.2', 'tail.0.2')
elif 'conv_last' in crt_k:
ori_k = crt_k.replace('conv_last', 'tail.1')
elif 'attention' in crt_k:
_, ai, _, bi, _, ci, d, di, e = crt_k.split('.')
ori_k = f'body.{ai}.body.{bi}.body.{ci}.conv_du.{int(di)-1}.{e}'
elif 'rcab' in crt_k:
a, ai, b, bi, c, ci, d = crt_k.split('.')
ori_k = f'body.{ai}.body.{bi}.body.{ci}.{d}'
elif 'body' in crt_k:
ori_k = crt_k.replace('conv.', 'body.20.')
else:
print('unprocessed key', crt_k)
crt_net[crt_k] = ori_net[ori_k]
torch.save(crt_net, 'RCAN_model_best.pth')
def convert_esrgan_model():
from basicsr.archs.rrdbnet_arch import RRDBNet
rrdb = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32)
crt_net = rrdb.state_dict()
# for k, v in crt_net.items():
# print(k)
ori_net = torch.load('experiments/pretrained_models/RRDB_ESRGAN_x4.pth')
# for k, v in ori_net.items():
# print(k)
for crt_k, _ in crt_net.items():
if 'rdb' in crt_k:
ori_k = crt_k.replace('rdb', 'RDB').replace('body', 'RRDB_trunk')
elif 'conv_body' in crt_k:
ori_k = crt_k.replace('conv_body', 'trunk_conv')
elif 'conv_up' in crt_k:
ori_k = crt_k.replace('conv_up', 'upconv')
elif 'conv_hr' in crt_k:
ori_k = crt_k.replace('conv_hr', 'HRconv')
else:
ori_k = crt_k
print(crt_k)
crt_net[crt_k] = ori_net[ori_k]
torch.save(crt_net, 'experiments/pretrained_models/ESRGAN_x4_SR_DF2KOST_official.pth')
def convert_duf_model():
from basicsr.archs.duf_arch import DUF
scale = 2
duf = DUF(scale=scale, num_layer=16, adapt_official_weights=True)
crt_net = duf.state_dict()
# for k, v in crt_net.items():
# print(k)
ori_net = torch.load('experiments/pretrained_models/old_DUF_x2_16L_official.pth')
# print('******')
# for k, v in ori_net.items():
# print(k)
'''
for crt_k, crt_v in crt_net.items():
if 'conv3d1' in crt_k:
ori_k = crt_k.replace('conv3d1', 'conv3d_1')
elif 'conv3d2' in crt_k:
ori_k = crt_k.replace('conv3d2', 'conv3d_2')
elif 'dense_block1.dense_blocks' in crt_k:
# dense_block1.dense_blocks.0.0.weight
a, b, c, d, e = crt_k.split('.')
# dense_block_1.dense_blocks.0.weight
ori_k = f'dense_block_1.dense_blocks.{int(c) * 6 + int(d)}.{e}'
elif 'dense_block2.temporal_reduce1.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.0',
'dense_block_2.bn3d_1')
elif 'dense_block2.temporal_reduce1.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.2',
'dense_block_2.conv3d_1')
elif 'dense_block2.temporal_reduce1.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.3',
'dense_block_2.bn3d_2')
elif 'dense_block2.temporal_reduce1.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.5',
'dense_block_2.conv3d_2')
elif 'dense_block2.temporal_reduce2.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.0',
'dense_block_2.bn3d_3')
elif 'dense_block2.temporal_reduce2.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.2',
'dense_block_2.conv3d_3')
elif 'dense_block2.temporal_reduce2.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.3',
'dense_block_2.bn3d_4')
elif 'dense_block2.temporal_reduce2.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.5',
'dense_block_2.conv3d_4')
elif 'dense_block2.temporal_reduce3.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.0',
'dense_block_2.bn3d_5')
elif 'dense_block2.temporal_reduce3.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.2',
'dense_block_2.conv3d_5')
elif 'dense_block2.temporal_reduce3.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.3',
'dense_block_2.bn3d_6')
elif 'dense_block2.temporal_reduce3.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.5',
'dense_block_2.conv3d_6')
elif 'bn3d2' in crt_k:
ori_k = crt_k.replace('bn3d2', 'bn3d_2')
else:
ori_k = crt_k
print(crt_k)
crt_net[crt_k] = ori_net[ori_k]
'''
# for 16 layers
for crt_k, _ in crt_net.items():
if 'conv3d1' in crt_k:
ori_k = crt_k.replace('conv3d1', 'conv3d_1')
elif 'conv3d2' in crt_k:
ori_k = crt_k.replace('conv3d2', 'conv3d_2')
elif 'dense_block1.dense_blocks.0.0' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.0.0', 'dense_block_1.bn3d_1')
elif 'dense_block1.dense_blocks.0.2' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.0.2', 'dense_block_1.conv3d_1')
elif 'dense_block1.dense_blocks.0.3' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.0.3', 'dense_block_1.bn3d_2')
elif 'dense_block1.dense_blocks.0.5' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.0.5', 'dense_block_1.conv3d_2')
elif 'dense_block1.dense_blocks.1.0' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.1.0', 'dense_block_1.bn3d_3')
elif 'dense_block1.dense_blocks.1.2' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.1.2', 'dense_block_1.conv3d_3')
elif 'dense_block1.dense_blocks.1.3' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.1.3', 'dense_block_1.bn3d_4')
elif 'dense_block1.dense_blocks.1.5' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.1.5', 'dense_block_1.conv3d_4')
elif 'dense_block1.dense_blocks.2.0' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.2.0', 'dense_block_1.bn3d_5')
elif 'dense_block1.dense_blocks.2.2' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.2.2', 'dense_block_1.conv3d_5')
elif 'dense_block1.dense_blocks.2.3' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.2.3', 'dense_block_1.bn3d_6')
elif 'dense_block1.dense_blocks.2.5' in crt_k:
ori_k = crt_k.replace('dense_block1.dense_blocks.2.5', 'dense_block_1.conv3d_6')
elif 'dense_block2.temporal_reduce1.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.0', 'dense_block_2.bn3d_1')
elif 'dense_block2.temporal_reduce1.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.2', 'dense_block_2.conv3d_1')
elif 'dense_block2.temporal_reduce1.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.3', 'dense_block_2.bn3d_2')
elif 'dense_block2.temporal_reduce1.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce1.5', 'dense_block_2.conv3d_2')
elif 'dense_block2.temporal_reduce2.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.0', 'dense_block_2.bn3d_3')
elif 'dense_block2.temporal_reduce2.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.2', 'dense_block_2.conv3d_3')
elif 'dense_block2.temporal_reduce2.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.3', 'dense_block_2.bn3d_4')
elif 'dense_block2.temporal_reduce2.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce2.5', 'dense_block_2.conv3d_4')
elif 'dense_block2.temporal_reduce3.0' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.0', 'dense_block_2.bn3d_5')
elif 'dense_block2.temporal_reduce3.2' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.2', 'dense_block_2.conv3d_5')
elif 'dense_block2.temporal_reduce3.3' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.3', 'dense_block_2.bn3d_6')
elif 'dense_block2.temporal_reduce3.5' in crt_k:
ori_k = crt_k.replace('dense_block2.temporal_reduce3.5', 'dense_block_2.conv3d_6')
elif 'bn3d2' in crt_k:
ori_k = crt_k.replace('bn3d2', 'bn3d_2')
else:
ori_k = crt_k
print(crt_k)
crt_net[crt_k] = ori_net[ori_k]
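# Re-order the 3 * scale**2 output channels of conv3d_r2: the official weights
# interleave the three groups (every 3rd channel belongs to the same group),
# while this implementation expects three contiguous blocks of scale**2 channels
# (presumably one block per color channel of the pixel shuffle).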
x = crt_net['conv3d_r2.weight'].clone()
x1 = x[::3, ...]
x2 = x[1::3, ...]
x3 = x[2::3, ...]
crt_net['conv3d_r2.weight'][:scale**2, ...] = x1
crt_net['conv3d_r2.weight'][scale**2:2 * (scale**2), ...] = x2
crt_net['conv3d_r2.weight'][2 * (scale**2):, ...] = x3
x = crt_net['conv3d_r2.bias'].clone()
x1 = x[::3, ...]
x2 = x[1::3, ...]
x3 = x[2::3, ...]
crt_net['conv3d_r2.bias'][:scale**2, ...] = x1
crt_net['conv3d_r2.bias'][scale**2:2 * (scale**2), ...] = x2
crt_net['conv3d_r2.bias'][2 * (scale**2):, ...] = x3
torch.save(crt_net, 'experiments/pretrained_models/DUF_x2_16L_official.pth')
if __name__ == '__main__':
# convert EDSR models
# ori_net_path = 'path to original model'
# crt_net_path = 'path to current model'
# save_path = 'save path'
# convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32)
convert_duf_model()
import torch
from collections import OrderedDict
from basicsr.archs.ridnet_arch import RIDNet
if __name__ == '__main__':
ori_net_checkpoint = torch.load(
'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage)
rid_net = RIDNet(3, 64, 3)
new_ridnet_dict = OrderedDict()
rid_net_namelist = []
for name, param in rid_net.named_parameters():
rid_net_namelist.append(name)
count = 0
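# Note: parameters are copied by position below: the i-th tensor in the official
# checkpoint is assigned to the i-th named parameter of the new RIDNet, which
# assumes both models enumerate their parameters in the same order.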
for name, param in ori_net_checkpoint.items():
new_ridnet_dict[rid_net_namelist[count]] = param
count += 1
rid_net.load_state_dict(new_ridnet_dict)
torch.save(rid_net.state_dict(), 'experiments/pretrained_models/RIDNet/RIDNet.pth')
import torch
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator, StyleGAN2Generator
def convert_net_g(ori_net, crt_net):
"""Convert network generator."""
for crt_k, crt_v in crt_net.items():
if 'style_mlp' in crt_k:
ori_k = crt_k.replace('style_mlp', 'style')
elif 'constant_input.weight' in crt_k:
ori_k = crt_k.replace('constant_input.weight', 'input.input')
# style conv1
elif 'style_conv1.modulated_conv' in crt_k:
ori_k = crt_k.replace('style_conv1.modulated_conv', 'conv1.conv')
elif 'style_conv1' in crt_k:
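# A 1-element tensor is assumed to be the learnable noise-injection strength
# (mapped to 'conv1.noise'); all other style_conv1 parameters keep the plain
# 'conv1' prefix in the official weights.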
if crt_v.shape == torch.Size([1]):
ori_k = crt_k.replace('style_conv1', 'conv1.noise')
else:
ori_k = crt_k.replace('style_conv1', 'conv1')
# style conv
elif 'style_convs' in crt_k:
ori_k = crt_k.replace('style_convs', 'convs').replace('modulated_conv', 'conv')
if crt_v.shape == torch.Size([1]):
ori_k = ori_k.replace('.weight', '.noise.weight')
# to_rgb1
elif 'to_rgb1.modulated_conv' in crt_k:
ori_k = crt_k.replace('to_rgb1.modulated_conv', 'to_rgb1.conv')
# to_rgbs
elif 'to_rgbs' in crt_k:
ori_k = crt_k.replace('modulated_conv', 'conv')
elif 'noises' in crt_k:
ori_k = crt_k.replace('.noise', '.noise_')
else:
ori_k = crt_k
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
raise ValueError('Wrong tensor size: \n'
f'crt_net: {crt_net[crt_k].size()}\n'
f'ori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
return crt_net
def convert_net_d(ori_net, crt_net):
"""Convert network discriminator."""
for crt_k, _ in crt_net.items():
if 'conv_body' in crt_k:
ori_k = crt_k.replace('conv_body', 'convs')
else:
ori_k = crt_k
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
raise ValueError('Wrong tensor size: \n'
f'crt_net: {crt_net[crt_k].size()}\n'
f'ori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
return crt_net
if __name__ == '__main__':
"""Convert official stylegan2 weights from stylegan2-pytorch."""
# configuration
ori_net = torch.load('experiments/pretrained_models/stylegan2-ffhq.pth')
save_path_g = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official.pth' # noqa: E501
save_path_d = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_discriminator_official.pth' # noqa: E501
out_size = 1024
channel_multiplier = 1
# convert generator
crt_net = StyleGAN2Generator(out_size, num_style_feat=512, num_mlp=8, channel_multiplier=channel_multiplier)
crt_net = crt_net.state_dict()
crt_net_params_ema = convert_net_g(ori_net['g_ema'], crt_net)
torch.save(dict(params_ema=crt_net_params_ema, latent_avg=ori_net['latent_avg']), save_path_g)
# convert discriminator
crt_net = StyleGAN2Discriminator(out_size, channel_multiplier=channel_multiplier)
crt_net = crt_net.state_dict()
crt_net_params = convert_net_d(ori_net['d'], crt_net)
torch.save(dict(params=crt_net_params), save_path_d)
# Plot Figures
We provide the source code for some representative figures, which you can easily modify to fit your needs.
The commonly used helper functions, such as `read_data_from_tensorboard`, `smooth_data`, `read_data_from_txt_2v`, *etc*., are defined in [basicsr/utils/plot_util.py](https://github.com/XPixelGroup/BasicSR/blob/plot/basicsr/utils/plot_util.py).
- [model_complexity_cmp_bsrn.py](model_complexity_cmp_bsrn.py) by Haoming Cai. [[Paper](https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/papers/Li_Blueprint_Separable_Residual_Network_for_Efficient_Image_Super-Resolution_CVPRW_2022_paper.pdf)]
<p align="center">
<img src="../../assets/plot/model_complexity_cmp_bsrn.png" height=250>
</p>
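Below is a minimal usage sketch of how these helpers might be combined to plot a smoothed loss curve from a training log. The log path and regex pattern are placeholders, and the exact helper signatures are assumed from their names, so verify them against `plot_util.py` before reusing this snippet.

```python
import matplotlib.pyplot as plt

from basicsr.utils.plot_util import read_data_from_txt_2v, smooth_data

# Hypothetical log path and regex pattern; adapt them to your own experiment.
log_path = 'experiments/my_experiment/train_my_experiment.log'
pattern = r'iter:\s*([\d,]+).*l_pix: ([\d.e+-]+)'

# read_data_from_txt_2v is assumed to return numeric (steps, values) lists
# parsed from the matched groups of the text log.
steps, values = read_data_from_txt_2v(log_path, pattern)
# smooth_data is assumed to apply exponential smoothing with the given weight.
smoothed = smooth_data(values, 0.9)

fig, ax = plt.subplots()
ax.plot(steps, values, alpha=0.3, label='raw')
ax.plot(steps, smoothed, label='smoothed')
ax.set_xlabel('Iteration')
ax.set_ylabel('L1 loss')
ax.legend()
fig.savefig('loss_curve.png')
```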
def main():
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(15, 10))
radius = 9.5
notation_size = 27
'''0 - 10'''
# BSRN-S, FSRCNN
x = [156, 13]
y = [32.16, 30.71]
area = (30) * radius**2
ax.scatter(x, y, s=area, alpha=0.8, marker='.', c='#4D96FF', edgecolors='white', linewidths=2.0)
plt.annotate('FSRCNN', (13 + 10, 30.71 + 0.1), fontsize=notation_size)
plt.annotate('BSRN-S(Ours)', (156 - 70, 32.16 + 0.15), fontsize=notation_size)
'''10 - 25'''
# BSRN, RFDN
x = [357, 550]
y = [32.30, 32.24]
area = (75) * radius**2
ax.scatter(x, y, s=area, alpha=1.0, marker='.', c='#FFD93D', edgecolors='white', linewidths=2.0)
plt.annotate('BSRN(Ours)', (357 - 70, 32.35 + 0.10), fontsize=notation_size)
plt.annotate('RFDN', (550 - 70, 32.24 + 0.15), fontsize=notation_size)
'''25 - 50'''
# IDN, IMDN, PAN
x = [553, 715, 272]
y = [31.82, 32.21, 32.13]
area = (140) * radius**2
ax.scatter(x, y, s=area, alpha=0.6, marker='.', c='#95CD41', edgecolors='white', linewidths=2.0)
plt.annotate('IDN', (553 - 60, 31.82 + 0.15), fontsize=notation_size)
plt.annotate('IMDN', (715 + 10, 32.21 + 0.15), fontsize=notation_size)
plt.annotate('PAN', (272 - 70, 32.13 - 0.25), fontsize=notation_size)
'''50 - 100'''
# SRCNN, CARN, LAPAR-A
x = [57, 1592, 659]
y = [30.48, 32.13, 32.15]
area = 175 * radius**2
ax.scatter(x, y, s=area, alpha=0.8, marker='.', c='#EAE7C6', edgecolors='white', linewidths=2.0)
plt.annotate('SRCNN', (57 + 30, 30.48 + 0.1), fontsize=notation_size)
plt.annotate('LAPAR-A', (659 - 75, 32.15 + 0.20), fontsize=notation_size)
'''1M+'''
# LapSRCN, VDSR, DRRN, MemNet
x = [502, 666, 298, 678]
y = [31.54, 31.35, 31.68, 31.74]
area = (250) * radius**2
ax.scatter(x, y, s=area, alpha=0.3, marker='.', c='#264653', edgecolors='white', linewidths=2.0)
plt.annotate('LapSRCN', (502 - 90, 31.54 - 0.35), fontsize=notation_size)
plt.annotate('VDSR', (666 - 70, 31.35 - 0.35), fontsize=notation_size)
plt.annotate('DRRN', (298 - 65, 31.68 - 0.35), fontsize=notation_size)
plt.annotate('MemNet', (678 + 15, 31.74 + 0.18), fontsize=notation_size)
'''Ours marker'''
x = [156]
y = [32.16]
ax.scatter(x, y, alpha=1.0, marker='*', c='r', s=300)
x = [357]
y = [32.30]
ax.scatter(x, y, alpha=1.0, marker='*', c='r', s=700)
plt.xlim(0, 800)
plt.ylim(29.75, 32.75)
plt.xlabel('Parameters (K)', fontsize=35)
plt.ylabel('PSNR (dB)', fontsize=35)
plt.title('PSNR vs. Parameters vs. Multi-Adds', fontsize=35)
h = [
plt.plot([], [], color=c, marker='.', ms=i, alpha=a, ls='')[0] for i, c, a in zip(
[40, 60, 80, 95, 110], ['#4D96FF', '#FFD93D', '#95CD41', '#EAE7C6', '#264653'], [0.8, 1.0, 0.6, 0.8, 0.3])
]
ax.legend(
labelspacing=0.1,
handles=h,
handletextpad=1.0,
markerscale=1.0,
fontsize=17,
title='Multi-Adds',
title_fontsize=25,
labels=['<10k', '10k-25k', '25k-50k', '50k-100k', '1M+'],
scatteryoffsets=[0.0],
loc='lower right',
ncol=5,
shadow=True,
handleheight=6)
for size in ax.get_xticklabels(): # Set fontsize for x-axis
size.set_fontsize('30')
for size in ax.get_yticklabels(): # Set fontsize for y-axis
size.set_fontsize('30')
ax.grid(b=True, linestyle='-.', linewidth=0.5)
plt.show()
fig.savefig('model_complexity_cmp_bsrn.png')
if __name__ == '__main__':
main()
import glob
import subprocess
import torch
from os import path as osp
from torch.serialization import _is_zipfile, _open_file_like
def update_sha(paths):
print('# Update sha ...')
for idx, path in enumerate(paths):
print(f'{idx+1:03d}: Processing {path}')
net = torch.load(path, map_location=torch.device('cpu'))
basename = osp.basename(path)
if 'params' not in net and 'params_ema' not in net:
user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. '
'Do you still want to continue? Y/N\n')
if user_response.lower() == 'y':
pass
elif user_response.lower() == 'n':
raise ValueError('Please modify..')
else:
raise ValueError('Wrong input. Only accepts Y/N.')
if '-' in basename:
# check whether the sha is the latest
old_sha = basename.split('-')[1].split('.')[0]
new_sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
if old_sha != new_sha:
final_file = path.split('-')[0] + f'-{new_sha}.pth'
print(f'\tSave from {path} to {final_file}')
subprocess.Popen(['mv', path, final_file])
else:
sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
final_file = path.split('.pth')[0] + f'-{sha}.pth'
print(f'\tSave from {path} to {final_file}')
subprocess.Popen(['mv', path, final_file])
def convert_to_backward_compatible_models(paths):
"""Convert to backward compatible pth files.
PyTorch 1.6 switched torch.save to a zipfile-based format. To stay compatible
with earlier PyTorch versions, re-save the checkpoint with
_use_new_zipfile_serialization=False.
"""
print('# Convert to backward compatible pth files ...')
for idx, path in enumerate(paths):
print(f'{idx+1:03d}: Processing {path}')
flag_need_conversion = False
with _open_file_like(path, 'rb') as opened_file:
if _is_zipfile(opened_file):
flag_need_conversion = True
if flag_need_conversion:
net = torch.load(path, map_location=torch.device('cpu'))
print('\tConverting to compatible pth file...')
torch.save(net, path, _use_new_zipfile_serialization=False)
if __name__ == '__main__':
paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob('experiments/pretrained_models/**/*.pth')
convert_to_backward_compatible_models(paths)
update_sha(paths)
[flake8]
ignore =
# line break before binary operator (W503)
W503,
# line break after binary operator (W504)
W504,
max-line-length=120
[yapf]
based_on_style = pep8
column_limit = 120
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true
[isort]
line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
known_third_party = PIL,cv2,lmdb,numpy,pytest,requests,scipy,skimage,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
[codespell]
skip = .git,./docs/build,*.cfg
count =
quiet-level = 3
ignore-words-list = gool
[aliases]
test=pytest
[tool:pytest]
addopts=tests/
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import time
version_file = 'basicsr/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
# currently ignore this
# elif os.path.exists(version_file):
# try:
# from basicsr.version import __version__
# sha = __version__.split('+')[-1]
# except ImportError:
# raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def make_cuda_ext(name, module, sources, sources_cuda=None):
if sources_cuda is None:
sources_cuda = []
define_macros = []
extra_compile_args = {'cxx': []}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
sources += sources_cuda
else:
print(f'Compiling {name} without CUDA')
extension = CppExtension
return extension(
name=f'{module}.{name}',
sources=[os.path.join(*module.split('.'), p) for p in sources],
define_macros=define_macros,
extra_compile_args=extra_compile_args)
def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext
if cuda_ext == 'True':
try:
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
except ImportError:
raise ImportError('Unable to import torch - torch is needed to build cuda extensions')
ext_modules = [
make_cuda_ext(
name='deform_conv_ext',
module='basicsr.ops.dcn',
sources=['src/deform_conv_ext.cpp'],
sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
make_cuda_ext(
name='fused_act_ext',
module='basicsr.ops.fused_act',
sources=['src/fused_bias_act.cpp'],
sources_cuda=['src/fused_bias_act_kernel.cu']),
make_cuda_ext(
name='upfirdn2d_ext',
module='basicsr.ops.upfirdn2d',
sources=['src/upfirdn2d.cpp'],
sources_cuda=['src/upfirdn2d_kernel.cu']),
]
setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
else:
ext_modules = []
setup_kwargs = dict()
write_version_py()
setup(
name='basicsr',
version=get_version(),
description='Open Source Image and Video Super-Resolution Toolbox',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, restoration, super resolution',
url='https://github.com/xinntao/BasicSR',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='Apache License 2.0',
setup_requires=['cython', 'numpy', 'torch'],
install_requires=get_requirements(),
ext_modules=ext_modules,
zip_safe=False,
**setup_kwargs)
import copy
import random
import torch
from torch import nn as nn
class ToyDiscriminator(nn.Module):
def __init__(self):
super(ToyDiscriminator, self).__init__()
self.conv0 = nn.Conv2d(3, 4, 3, 1, 1, bias=True)
self.bn0 = nn.BatchNorm2d(4, affine=True)
self.conv1 = nn.Conv2d(4, 4, 3, 1, 1, bias=True)
self.bn1 = nn.BatchNorm2d(4, affine=True)
self.linear = nn.Linear(4 * 6 * 6, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
feat = self.lrelu(self.bn0(self.conv0(x)))
feat = self.lrelu(self.bn1(self.conv1(feat)))
feat = feat.view(feat.size(0), -1)
out = torch.sigmoid(self.linear(feat))
return out
def main():
# use fixed random seed
manual_seed = 999
random.seed(manual_seed)
torch.manual_seed(manual_seed)
img_real = torch.rand((1, 3, 6, 6))
img_fake = torch.rand((1, 3, 6, 6))
net_d_1 = ToyDiscriminator()
net_d_2 = copy.deepcopy(net_d_1)
net_d_1.train()
net_d_2.train()
criterion = nn.BCELoss()
real_label = 1
fake_label = 0
for k, v in net_d_1.named_parameters():
print(k, v.size())
###########################
# (1) Backward D network twice as the official tutorial does:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
###########################
net_d_1.zero_grad()
# real
output = net_d_1(img_real).view(-1)
label = output.new_ones(output.size()) * real_label
loss_real = criterion(output, label)
loss_real.backward()
# fake
output = net_d_1(img_fake).view(-1)
label = output.new_ones(output.size()) * fake_label
loss_fake = criterion(output, label)
loss_fake.backward()
###########################
# (2) Backward D network once
###########################
net_d_2.zero_grad()
# real
output = net_d_2(img_real).view(-1)
label = output.new_ones(output.size()) * real_label
loss_real = criterion(output, label)
# fake
output = net_d_2(img_fake).view(-1)
label = output.new_ones(output.size()) * fake_label
loss_fake = criterion(output, label)
loss = loss_real + loss_fake
loss.backward()
###########################
# Compare differences
###########################
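# Since loss.backward() accumulates gradients, backpropagating loss_real and
# loss_fake separately in (1) yields exactly the same gradients as
# backpropagating their sum once in (2), so every printed difference is zero.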
for k1, k2 in zip(net_d_1.parameters(), net_d_2.parameters()):
print(torch.sum(torch.abs(k1.grad - k2.grad)))
if __name__ == '__main__':
main()
r"""Output:
conv0.weight torch.Size([4, 3, 3, 3])
conv0.bias torch.Size([4])
bn0.weight torch.Size([4])
bn0.bias torch.Size([4])
conv1.weight torch.Size([4, 4, 3, 3])
conv1.bias torch.Size([4])
bn1.weight torch.Size([4])
bn1.bias torch.Size([4])
linear.weight torch.Size([1, 144])
linear.bias torch.Size([1])
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
"""
import math
import os
import torch
import torchvision.utils
from basicsr.data import build_dataloader, build_dataset
def main():
"""Test FFHQ dataset."""
opt = {}
opt['dist'] = False
opt['gpu_ids'] = [0]
opt['phase'] = 'train'
opt['name'] = 'FFHQ'
opt['type'] = 'FFHQDataset'
opt['dataroot_gt'] = 'datasets/ffhq/ffhq_256.lmdb'
opt['io_backend'] = dict(type='lmdb')
opt['use_hflip'] = True
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
opt['num_worker_per_gpu'] = 1
opt['batch_size_per_gpu'] = 4
opt['dataset_enlarge_ratio'] = 1
os.makedirs('tmp', exist_ok=True)
dataset = build_dataset(opt)
data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
nrow = int(math.sqrt(opt['batch_size_per_gpu']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(data_loader):
if i > 5:
break
print(i)
gt = data['gt']
print(torch.min(gt), torch.max(gt))
gt_path = data['gt_path']
print(gt_path)
torchvision.utils.save_image(
gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=True, range=(-1, 1))
if __name__ == '__main__':
main()
import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR
try:
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt
from matplotlib import ticker as mtick
except ImportError:
print('Please install matplotlib.')
def main():
optim_params = [
{
'params': [torch.zeros(3, 64, 3, 3)],
'lr': 4e-4
},
{
'params': [torch.zeros(3, 64, 3, 3)],
'lr': 2e-4
},
]
optimizer = torch.optim.Adam(optim_params, lr=2e-4, weight_decay=0, betas=(0.9, 0.99))
period = [50000, 100000, 150000, 150000, 150000]
restart_weights = [1, 1, 0.5, 1, 0.5]
scheduler = CosineAnnealingRestartLR(
optimizer,
period,
restart_weights=restart_weights,
eta_min=1e-7,
)
# draw figure
total_iter = 600000
lr_l = list(range(total_iter))
lr_l2 = list(range(total_iter))
for i in range(total_iter):
optimizer.step()
scheduler.step()
lr_l[i] = optimizer.param_groups[0]['lr']
lr_l2[i] = optimizer.param_groups[1]['lr']
mpl.style.use('default')
plt.figure(1)
plt.subplot(111)
plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
plt.title('Cosine Annealing Restart Learning Rate Scheme', fontsize=16, color='k')
plt.plot(list(range(total_iter)), lr_l, linewidth=1.5, label='learning rate 1')
plt.plot(list(range(total_iter)), lr_l2, linewidth=1.5, label='learning rate 2')
plt.legend(loc='upper right', shadow=False)
ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
ax.set_ylabel('Learning Rate')
ax.set_xlabel('Iteration')
fig = plt.gcf()
fig.savefig('test_lr_scheduler.png')
if __name__ == '__main__':
main()
import cv2
import warnings
from basicsr.metrics import calculate_niqe
def main():
img_path = 'tests/data/baboon.png'
img = cv2.imread(img_path)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
niqe_result = calculate_niqe(img, 0, input_order='HWC', convert_to='y')
print(niqe_result)
if __name__ == '__main__':
main()
import math
import os
import torchvision.utils
from basicsr.data import build_dataloader, build_dataset
def main(mode='folder'):
"""Test paired image dataset.
Args:
mode: There are three modes: 'lmdb', 'folder', 'meta_info_file'.
"""
opt = {}
opt['dist'] = False
opt['phase'] = 'train'
opt['name'] = 'DIV2K'
opt['type'] = 'PairedImageDataset'
if mode == 'folder':
opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
opt['filename_tmpl'] = '{}'
opt['io_backend'] = dict(type='disk')
elif mode == 'meta_info_file':
opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt' # noqa:E501
opt['filename_tmpl'] = '{}'
opt['io_backend'] = dict(type='disk')
elif mode == 'lmdb':
opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb' # noqa:E501
opt['io_backend'] = dict(type='lmdb')
opt['gt_size'] = 128
opt['use_hflip'] = True
opt['use_rot'] = True
opt['num_worker_per_gpu'] = 2
opt['batch_size_per_gpu'] = 16
opt['scale'] = 4
opt['dataset_enlarge_ratio'] = 1
os.makedirs('tmp', exist_ok=True)
dataset = build_dataset(opt)
data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
nrow = int(math.sqrt(opt['batch_size_per_gpu']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(data_loader):
if i > 5:
break
print(i)
lq = data['lq']
gt = data['gt']
lq_path = data['lq_path']
gt_path = data['gt_path']
print(lq_path, gt_path)
torchvision.utils.save_image(lq, f'tmp/lq_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
if __name__ == '__main__':
main()
import math
import os
import torchvision.utils
from basicsr.data import build_dataloader, build_dataset
def main(mode='folder'):
"""Test reds dataset.
Args:
mode: There are two modes: 'lmdb', 'folder'.
"""
opt = {}
opt['dist'] = False
opt['phase'] = 'train'
opt['name'] = 'REDS'
opt['type'] = 'REDSDataset'
if mode == 'folder':
opt['dataroot_gt'] = 'datasets/REDS/train_sharp'
opt['dataroot_lq'] = 'datasets/REDS/train_sharp_bicubic'
opt['dataroot_flow'] = None
opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_REDS_GT.txt'
opt['io_backend'] = dict(type='disk')
elif mode == 'lmdb':
opt['dataroot_gt'] = 'datasets/REDS/train_sharp_with_val.lmdb'
opt['dataroot_lq'] = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb'
opt['dataroot_flow'] = None
opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_REDS_GT.txt'
opt['io_backend'] = dict(type='lmdb')
opt['val_partition'] = 'REDS4'
opt['num_frame'] = 5
opt['gt_size'] = 256
opt['interval_list'] = [1]
opt['random_reverse'] = True
opt['use_hflip'] = True
opt['use_rot'] = True
opt['num_worker_per_gpu'] = 1
opt['batch_size_per_gpu'] = 16
opt['scale'] = 4
opt['dataset_enlarge_ratio'] = 1
os.makedirs('tmp', exist_ok=True)
dataset = build_dataset(opt)
data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
nrow = int(math.sqrt(opt['batch_size_per_gpu']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(data_loader):
if i > 5:
break
print(i)
lq = data['lq']
gt = data['gt']
key = data['key']
print(key)
for j in range(opt['num_frame']):
torchvision.utils.save_image(
lq[:, j, :, :, :], f'tmp/lq_{i:03d}_frame{j}.png', nrow=nrow, padding=padding, normalize=False)
torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
if __name__ == '__main__':
main()
import math
import os
import torchvision.utils
from basicsr.data import build_dataloader, build_dataset
def main(mode='folder'):
"""Test vimeo90k dataset.
Args:
mode: There are two modes: 'lmdb', 'folder'.
"""
opt = {}
opt['dist'] = False
opt['phase'] = 'train'
opt['name'] = 'Vimeo90K'
opt['type'] = 'Vimeo90KDataset'
if mode == 'folder':
opt['dataroot_gt'] = 'datasets/vimeo90k/vimeo_septuplet/sequences'
opt['dataroot_lq'] = 'datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' # noqa E501
opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt' # noqa E501
opt['io_backend'] = dict(type='disk')
elif mode == 'lmdb':
opt['dataroot_gt'] = 'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb'
opt['dataroot_lq'] = 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt' # noqa E501
opt['io_backend'] = dict(type='lmdb')
opt['num_frame'] = 7
opt['gt_size'] = 256
opt['random_reverse'] = True
opt['use_hflip'] = True
opt['use_rot'] = True
opt['num_worker_per_gpu'] = 1
opt['batch_size_per_gpu'] = 16
opt['scale'] = 4
opt['dataset_enlarge_ratio'] = 1
os.makedirs('tmp', exist_ok=True)
dataset = build_dataset(opt)
data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)
nrow = int(math.sqrt(opt['batch_size_per_gpu']))
padding = 2 if opt['phase'] == 'train' else 0
print('start...')
for i, data in enumerate(data_loader):
if i > 5:
break
print(i)
lq = data['lq']
gt = data['gt']
key = data['key']
print(key)
for j in range(opt['num_frame']):
torchvision.utils.save_image(
lq[:, j, :, :, :], f'tmp/lq_{i:03d}_frame{j}.png', nrow=nrow, padding=padding, normalize=False)
torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
if __name__ == '__main__':
main()
# UnitTest
- It requires a GPU (CUDA) environment.
baboon.png (480,492,3) 1
comic.png (360,240,3) 1