Commit a8ada82f authored by chenych's avatar chenych
Browse files

First commit

parent 537691da
# -*- coding: utf-8 -*-
import numpy as np
import torch
from utils import utils_image as util
import re
import glob
import os
'''
# --------------------------------------------
# Model
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
'''
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Locate the checkpoint with the highest iteration number in `save_dir`.

    Args:
        save_dir: model folder containing files named '<iter>_<net_type>.pth'
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when falling back to pretrained_path)
        init_path: model path (pretrained_path when no checkpoint was found)
    # ---------------------------------------
    """
    file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
    iter_exist = []
    # Escape net_type and the '.' so only the exact '<digits>_<net_type>.pth'
    # naming scheme matches; files such as 'best_G.pth' are skipped instead
    # of raising IndexError on an empty findall() result.
    pattern = re.compile(r"(\d+)_{}\.pth".format(re.escape(net_type)))
    for file_ in file_list:
        found = pattern.findall(file_)
        if found:
            iter_exist.append(int(found[0]))
    if iter_exist:
        init_iter = max(iter_exist)
        init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
    else:
        init_iter = 0
        init_path = pretrained_path
    return init_iter, init_path
def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
    '''
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Dispatch to one of the inference strategies below.

    Args:
        model: trained model
        L: input Low-quality image
        mode:
            (0) normal: test(model, L)
            (1) pad: test_pad(model, L, modulo=16)
            (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
            (3) x8: test_x8(model, L, modulo=1) ^_^
            (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
        refield: effective receptive field of the network, 32 is enough
            useful when split, i.e., mode=2, 4
        min_size: min_sizeXmin_size image, e.g., 256X256 image
            useful when split, i.e., mode=2, 4
        sf: scale factor for super-resolution, otherwise 1
        modulo: 1 if split
            useful when pad, i.e., mode=1

    Returns:
        E: estimated image

    Raises:
        ValueError: if mode is not one of 0..4 (previously an unknown mode
        fell through and raised UnboundLocalError on `return E`).
    # ---------------------------------------
    '''
    if mode == 0:
        E = test(model, L)
    elif mode == 1:
        E = test_pad(model, L, modulo, sf)
    elif mode == 2:
        E = test_split(model, L, refield, min_size, sf, modulo)
    elif mode == 3:
        E = test_x8(model, L, modulo, sf)
    elif mode == 4:
        E = test_split_x8(model, L, refield, min_size, sf, modulo)
    else:
        raise ValueError('unknown test mode {}; expected 0-4'.format(mode))
    return E
'''
# --------------------------------------------
# normal (0)
# --------------------------------------------
'''
def test(model, L):
E = model(L)
return E
'''
# --------------------------------------------
# pad (1)
# --------------------------------------------
'''
def test_pad(model, L, modulo=16, sf=1):
h, w = L.size()[-2:]
paddingBottom = int(np.ceil(h/modulo)*modulo-h)
paddingRight = int(np.ceil(w/modulo)*modulo-w)
L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
E = model(L)
E = E[..., :h*sf, :w*sf]
return E
'''
# --------------------------------------------
# split (function)
# --------------------------------------------
'''
def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """
    Recursively split `L` into four overlapping quadrants until each piece is
    at most min_size**2 pixels, run the model on the pieces, and stitch the
    results back together.

    Args:
        model: trained model
        L: input Low-quality image
        refield: effective receptive field of the network, 32 is enough
        min_size: min_sizeXmin_size image, e.g., 256X256 image
        sf: scale factor for super-resolution, otherwise 1
        modulo: 1 if split
    Returns:
        E: estimated result
    """
    h, w = L.size()[-2:]
    if h*w <= min_size**2:
        # Small enough: pad to a multiple of `modulo`, run directly, crop back.
        L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
        E = model(L)
        E = E[..., :h*sf, :w*sf]
    else:
        # Quadrant extents are rounded up to a multiple of `refield`, so each
        # piece overlaps the half-way line; the overlap absorbs boundary effects.
        top = slice(0, (h//2//refield+1)*refield)
        bottom = slice(h - (h//2//refield+1)*refield, h)
        left = slice(0, (w//2//refield+1)*refield)
        right = slice(w - (w//2//refield+1)*refield, w)
        Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
        if h * w <= 4*(min_size**2):
            # One split suffices: run the model directly on the four pieces.
            Es = [model(Ls[i]) for i in range(4)]
        else:
            # Still too large: recurse into each quadrant.
            Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
        b, c = Es[0].size()[:2]
        E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
        # Keep only the non-overlapping portion of each quadrant's output.
        E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
        E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
        E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
        E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
    return E
'''
# --------------------------------------------
# split (2)
# --------------------------------------------
'''
def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Public entry point for mode 2: delegate to the recursive split routine."""
    return test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
'''
# --------------------------------------------
# x8 (3)
# --------------------------------------------
'''
def test_x8(model, L, modulo=1, sf=1):
    """Geometric self-ensemble (x8): run the model on the 8 flip/rotation
    augmentations of `L`, undo each augmentation, and average the results."""
    outputs = [test_pad(model, util.augment_img_tensor4(L, mode=k), modulo=modulo, sf=sf)
               for k in range(8)]
    for k in range(len(outputs)):
        # Modes 3 and 5 are undone with mode 8-k; the others undo themselves
        # (per the util.augment_img_tensor4 convention used throughout this file).
        inverse_mode = 8 - k if k in (3, 5) else k
        outputs[k] = util.augment_img_tensor4(outputs[k], mode=inverse_mode)
    return torch.stack(outputs, dim=0).mean(dim=0, keepdim=False)
'''
# --------------------------------------------
# split and x8 (4)
# --------------------------------------------
'''
def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Mode 4: combine the split strategy with the x8 geometric self-ensemble.

    Args:
        model: trained model
        L: input Low-quality image
        refield: effective receptive field of the network
        min_size: maximum tile side used by the split routine
        sf: scale factor for super-resolution, otherwise 1
        modulo: pad multiple passed through to the split routine

    Returns:
        E: estimated image, averaged over the 8 augmentations.
    """
    E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
    # Single index loop (the original iterated `enumerate(range(...))`,
    # yielding two always-equal indices); now consistent with test_x8.
    for i in range(len(E_list)):
        if i == 3 or i == 5:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
        else:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
    output_cat = torch.stack(E_list, dim=0)
    E = output_cat.mean(dim=0, keepdim=False)
    return E
'''
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
'''
'''
# --------------------------------------------
# print
# --------------------------------------------
'''
# --------------------------------------------
# print model
# --------------------------------------------
def print_model(model):
    """Print the model summary produced by describe_model."""
    print(describe_model(model))
# --------------------------------------------
# print params
# --------------------------------------------
def print_params(model):
    """Print the per-parameter statistics table from describe_params."""
    print(describe_params(model))
'''
# --------------------------------------------
# information
# --------------------------------------------
'''
# --------------------------------------------
# model information
# --------------------------------------------
def info_model(model):
    """Return the model summary string (see describe_model)."""
    return describe_model(model)
# --------------------------------------------
# params information
# --------------------------------------------
def info_params(model):
    """Return the per-parameter statistics string (see describe_params)."""
    return describe_params(model)
'''
# --------------------------------------------
# description
# --------------------------------------------
'''
# --------------------------------------------
# model name and total number of parameters
# --------------------------------------------
def describe_model(model):
    """Return a multi-line summary of `model`: class name, total parameter
    count, and the module structure (DataParallel wrappers are unwrapped)."""
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    num_params = sum(p.numel() for p in model.parameters())
    lines = [
        '',
        'models name: {}'.format(model.__class__.__name__),
        'Params number: {}'.format(num_params),
        'Net structure:\n{}'.format(str(model)),
        '',
    ]
    return '\n'.join(lines)
# --------------------------------------------
# parameters description
# --------------------------------------------
def describe_params(model):
    """Return a table with one row per state_dict entry: mean/min/max/std,
    shape and name of each tensor; 'num_batches_tracked' entries are skipped.

    Args:
        model: torch module (DataParallel wrappers are unwrapped first)

    Returns:
        Multi-line string table with a header row.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    msg = '\n'
    # Header now has 6 slots to match the 6 fields formatted per data row;
    # the original format string had only 5 slots for 6 arguments, so the
    # 'param_name' column header was silently dropped.
    msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^10s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
    for name, param in model.state_dict().items():
        if not 'num_batches_tracked' in name:
            v = param.data.clone().float()
            msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
    return msg
if __name__ == '__main__':
    # Smoke test: push a toy one-conv network through all five test modes.
    class Net(torch.nn.Module):
        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            return x

    # NOTE(review): these CUDA events are created but never recorded or read
    # below, and creating them requires a CUDA-enabled build — TODO confirm
    # they are needed or remove them.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    model = Net()
    model = model.eval()
    print_model(model)
    print_params(model)
    # 401x401 is deliberately not a multiple of common pad modulos, which
    # exercises the padding/cropping paths of the test modes.
    x = torch.randn((2,3,401,401))
    torch.cuda.empty_cache()
    with torch.no_grad():
        for mode in range(5):
            y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
            print(y.shape)
    # run utils/utils_model.py
# run utils/utils_model.py
import torch.nn as nn
import torch
import numpy as np
'''
---- 1) FLOPs: floating point operations
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
---- 3) #Conv2d: the number of ‘Conv2d’ layers
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 21/July/2020
# --------------------------------------------
# Reference
https://github.com/sovrasov/flops-counter.pytorch.git
# If you use this code, please consider the following citation:
@inproceedings{zhang2020aim, %
title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
booktitle={European Conference on Computer Vision Workshops},
year={2020}
}
# --------------------------------------------
'''
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Count FLOPs of `model` for one forward pass at resolution `input_res`.

    Args:
        model: network to profile (hooks are removed again before returning)
        input_res: tuple (C, H, W, ...) describing one input sample
        print_per_layer_stat: if True, print a per-layer FLOPs breakdown
        input_constructor: optional callable mapping input_res to a kwargs
            dict, for models that need structured inputs

    Returns:
        Total FLOPs accumulated over the forward pass.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = flops_model(**kwargs)
    else:
        # Build a dummy batch of one on the same device as the model weights.
        device = list(flops_model.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(device)
        _ = flops_model(dummy)
    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    total = flops_model.compute_average_flops_cost()
    flops_model.stop_flops_count()
    return total
def get_model_activation(model, input_res, input_constructor=None):
    """Count conv-output activations and conv layers for one forward pass.

    Args:
        model: network to profile (hooks are removed again before returning)
        input_res: tuple (C, H, W, ...) describing one input sample
        input_constructor: optional callable mapping input_res to a kwargs
            dict, for models that need structured inputs

    Returns:
        (total activation element count, number of Conv2d/ConvTranspose2d layers)
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = activation_model(**kwargs)
    else:
        # Build a dummy batch of one on the same device as the model weights.
        device = list(activation_model.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(device)
        _ = activation_model(dummy)
    activation_count, num_conv = activation_model.compute_average_activation_cost()
    activation_model.stop_activation_count()
    return activation_count, num_conv
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Return (FLOPs, parameter count) for one forward pass at `input_res`.

    When `as_strings` is True the two values are pre-formatted via
    flops_to_string / params_to_string; otherwise raw numbers are returned.
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = flops_model(**kwargs)
    else:
        dummy = torch.FloatTensor(1, *input_res)
        _ = flops_model(dummy)
    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    params_count = get_model_parameters_number(flops_model)
    flops_model.stop_flops_count()
    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)
    return flops_count, params_count
def flops_to_string(flops, units='GMac', precision=2):
    """Format a raw MAC count as a human-readable string.

    Args:
        flops: raw count
        units: 'GMac', 'MMac', 'KMac', or None to auto-select the largest
            unit the value reaches; any other value falls back to ' Mac'
        precision: decimals kept after scaling
    """
    scale = {'GMac': 10.**9, 'MMac': 10.**6, 'KMac': 10.**3}
    if units is None:
        # Auto-select: largest unit for which the scaled value is >= 1.
        for name in ('GMac', 'MMac', 'KMac'):
            if flops // scale[name] > 0:
                return str(round(flops / scale[name], precision)) + ' ' + name
        return str(flops) + ' Mac'
    if units in scale:
        return str(round(flops / scale[units], precision)) + ' ' + units
    return str(flops) + ' Mac'
def params_to_string(params_num):
    """Format a parameter count as 'x.y M', 'x.y k', or the plain number."""
    if params_num // 10 ** 6 > 0:
        return str(round(params_num / 10 ** 6, 2)) + ' M'
    if params_num // 10 ** 3:
        return str(round(params_num / 10 ** 3, 2)) + ' k'
    return str(params_num)
def print_model_with_flops(model, units='GMac', precision=3):
    """Print `model` with each submodule's extra_repr temporarily augmented
    by its accumulated FLOPs (absolute and as a share of the total), then
    restore the original extra_repr methods.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Supported leaves report their own counter; containers sum children.
        if is_supported_instance(self):
            # NOTE(review): __batch_counter__ is never assigned anywhere in
            # this file — presumably a batch-counting hook from the upstream
            # flops-counter.pytorch project is missing; TODO confirm.
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        # Bind the helper to each module and swap in the augmented repr,
        # keeping the original so del_extra_repr can restore it.
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        # Undo everything add_extra_repr attached.
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)
def get_model_parameters_number(model):
    """Number of trainable (requires_grad) parameters in `model`."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)
def add_flops_counting_methods(net_main_module):
    """Monkey-patch FLOPs-counting control methods onto `net_main_module`.

    Binds start/stop/reset_flops_count and compute_average_flops_cost as
    bound methods, resets the counters, and returns the same module.
    """
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    # embed()
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
    # Start from zeroed counters so stale __flops__ values never leak in.
    net_main_module.reset_flops_count()
    return net_main_module
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.
    Returns the FLOPs accumulated by the counter hooks, summed over all
    supported submodules.
    """
    return sum(m.__flops__ for m in self.modules() if is_supported_instance(m))
def start_flops_count(self):
    """
    Available after add_flops_counting_methods() has been called on the net.
    Installs the FLOPs-counter forward hooks on every supported submodule;
    call it before running the network.
    """
    self.apply(add_flops_counter_hook_function)
def stop_flops_count(self):
    """
    Available after add_flops_counting_methods() has been called on the net.
    Removes the FLOPs-counter forward hooks; call it whenever you want to
    pause the computation.
    """
    self.apply(remove_flops_counter_hook_function)
def reset_flops_count(self):
    """
    Available after add_flops_counting_methods() has been called on the net.
    Zeroes the per-module FLOPs accumulators.
    """
    self.apply(add_flops_counter_variable_or_reset)
def add_flops_counter_hook_function(module):
    """Attach the appropriate FLOPs-counting forward hook to `module`
    (at most once); unsupported modules are left untouched."""
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return  # already instrumented
    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        hook = conv_flops_counter_hook
    elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
        hook = relu_flops_counter_hook
    elif isinstance(module, nn.Linear):
        hook = linear_flops_counter_hook
    elif isinstance(module, nn.BatchNorm2d):
        hook = bn_flops_counter_hook
    else:
        hook = empty_flops_counter_hook
    module.__flops_handle__ = module.register_forward_hook(hook)
def remove_flops_counter_hook_function(module):
    """Detach and forget the hook installed by add_flops_counter_hook_function."""
    if is_supported_instance(module) and hasattr(module, '__flops_handle__'):
        module.__flops_handle__.remove()
        del module.__flops_handle__
def add_flops_counter_variable_or_reset(module):
    """Create or zero the per-module FLOPs accumulator on supported modules."""
    if is_supported_instance(module):
        module.__flops__ = 0
# ---- Internal functions
def is_supported_instance(module):
    """True when FLOPs counting is implemented for this module type."""
    supported_types = (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    )
    return isinstance(module, supported_types)
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: add this convolution's MACs for the current batch.

    MACs = (kernel volume * in_channels * out_channels / groups)
           * (batch size * spatial output elements)
    """
    batch = output.shape[0]
    spatial = np.prod(list(output.shape[2:]))
    kernel_volume = np.prod(list(conv_module.kernel_size))
    filters_per_channel = conv_module.out_channels // conv_module.groups
    per_position = kernel_volume * conv_module.in_channels * filters_per_channel
    conv_module.__flops__ += int(per_position) * int(batch * spatial)
def relu_flops_counter_hook(module, input, output):
    """Forward hook: one FLOP per output element for activation layers."""
    module.__flops__ += int(output.numel())
def linear_flops_counter_hook(module, input, output):
    """Forward hook: add batch * in_features * out_features to __flops__.

    Handles both unbatched (1-D) and batched inputs; for batched inputs
    dim 0 is treated as the batch dimension.
    """
    x = input[0]
    if x.dim() == 1:
        # Unbatched vector input: implicit batch size of one.
        module.__flops__ += int(x.shape[0] * output.shape[0])
    else:
        module.__flops__ += int(x.shape[0] * x.shape[1] * output.shape[1])
def bn_flops_counter_hook(module, input, output):
    """Forward hook: batch norm costs ~1 FLOP per output element
    (normalization), doubled when an affine scale/shift is applied."""
    batch = output.shape[0]
    elements = batch * module.num_features * np.prod(output.shape[2:])
    if module.affine:
        elements *= 2
    module.__flops__ += int(elements)
# ---- Count the number of convolutional layers and the activation
def add_activation_counting_methods(net_main_module):
    """Monkey-patch activation-counting control methods onto `net_main_module`.

    Binds start/stop/reset_activation_count and
    compute_average_activation_cost as bound methods, resets the counters,
    and returns the same module. Mirrors add_flops_counting_methods.
    """
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    # embed()
    net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
    net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
    net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
    net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
    # Start from zeroed counters so stale values never leak in.
    net_main_module.reset_activation_count()
    return net_main_module
def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods()
    is called on a desired net object.
    Returns (total activation element count, number of counted conv layers),
    summed over all supported submodules.
    """
    activation_total = 0
    conv_total = 0
    for m in self.modules():
        if is_supported_instance_for_activation(m):
            activation_total += m.__activation__
            conv_total += m.__num_conv__
    return activation_total, conv_total
def start_activation_count(self):
    """
    Available after add_activation_counting_methods() has been called.
    Installs the activation-counter forward hooks; call before running
    the network.
    """
    self.apply(add_activation_counter_hook_function)
def stop_activation_count(self):
    """
    Available after add_activation_counting_methods() has been called.
    Removes the activation-counter forward hooks; call whenever you want
    to pause the computation.
    """
    self.apply(remove_activation_counter_hook_function)
def reset_activation_count(self):
    """
    Available after add_activation_counting_methods() has been called.
    Zeroes the per-module activation and conv-layer counters.
    """
    self.apply(add_activation_counter_variable_or_reset)
def add_activation_counter_hook_function(module):
    """Attach the activation-counting forward hook to conv layers (once)."""
    if not is_supported_instance_for_activation(module):
        return
    if not hasattr(module, '__activation_handle__'):
        module.__activation_handle__ = module.register_forward_hook(conv_activation_counter_hook)
def remove_activation_counter_hook_function(module):
    """Detach and forget the activation-counting hook, if installed."""
    if is_supported_instance_for_activation(module) and hasattr(module, '__activation_handle__'):
        module.__activation_handle__.remove()
        del module.__activation_handle__
def add_activation_counter_variable_or_reset(module):
    """Create or zero the activation and conv-layer counters on supported modules."""
    if is_supported_instance_for_activation(module):
        module.__activation__ = 0
        module.__num_conv__ = 0
def is_supported_instance_for_activation(module):
    """True for the conv layer types whose output activations are counted."""
    return isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))
def conv_activation_counter_hook(module, input, output):
    """
    Forward hook counting the activations (output elements) of a conv layer.

    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick,
    Kaiming He, Piotr Dollár, Designing Network Design Spaces.
    Also increments __num_conv__ so callers can report the conv-layer count.

    :param module: hooked conv module
    :param input: forward inputs (unused)
    :param output: conv output tensor
    :return: None (accumulates on the module)
    """
    module.__activation__ += output.numel()
    module.__num_conv__ += 1
def empty_flops_counter_hook(module, input, output):
    """Fallback hook for supported-but-uncounted layers: contributes zero FLOPs."""
    module.__flops__ += 0
def upsample_flops_counter_hook(module, input, output):
    """Forward hook: one FLOP per element of the first output tensor."""
    first_output = output[0]
    count = first_output.shape[0]
    for dim in first_output.shape[1:]:
        count *= dim
    module.__flops__ += int(count)
def pool_flops_counter_hook(module, input, output):
    """Forward hook: one FLOP per input element for pooling layers."""
    module.__flops__ += int(np.prod(input[0].shape))
def dconv_flops_counter_hook(dconv_module, input, output):
    """Forward hook for a factorized conv module exposing `weight`
    (m_channels x in_channels x k1 x k1) and `projection`
    (out_channels x ? x k2 x k2); counts the MACs of both stages.

    NOTE(review): relies on the custom `projection` attribute — this is not
    a standard nn.Module layer; verify against the module that registers it.
    """
    input = input[0]
    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])
    m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
    out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
    # groups = dconv_module.groups
    # filters_per_channel = out_channels // groups
    conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
    conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
    active_elements_count = batch_size * np.prod(output_dims)
    # Both stages run once per output position.
    overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
    overall_flops = overall_conv_flops
    dconv_module.__flops__ += int(overall_flops)
    # dconv_module.__output_dims__ = output_dims
import os
from collections import OrderedDict
from datetime import datetime
import json
import re
import glob
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def get_timestamp():
    """Current local time as '_yymmdd_HHMMSS', suitable for unique filenames."""
    return datetime.now().strftime('_%y%m%d_%H%M%S')
def parse(opt_path, is_train=True):
    """Load a training/testing option file (JSON with '//' comments) and
    fill in defaults.

    Args:
        opt_path: path to the JSON option file
        is_train: training mode (adds 'models'/'images' dirs) vs test mode

    Returns:
        opt: OrderedDict of options with defaults applied. Side effect:
        sets the CUDA_VISIBLE_DEVICES environment variable from
        opt['gpu_ids'] and prints the GPU configuration.
    """
    # ----------------------------------------
    # remove comments starting with '//'
    # ----------------------------------------
    json_str = ''
    with open(opt_path, 'r') as f:
        for line in f:
            line = line.split('//')[0] + '\n'
            json_str += line
    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)
    opt['opt_path'] = opt_path
    opt['is_train'] = is_train
    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1
    if 'scale' not in opt:
        opt['scale'] = 1
    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
            dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
        if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
            dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
    # ----------------------------------------
    # path
    # ----------------------------------------
    for key, path in opt['path'].items():
        # (the original also tested `key in opt['path']`, which is always
        # true while iterating that very dict)
        if path:
            opt['path'][key] = os.path.expanduser(path)
    path_task = os.path.join(opt['path']['root'], opt['task'])
    opt['path']['task'] = path_task
    opt['path']['log'] = path_task
    opt['path']['options'] = os.path.join(path_task, 'options')
    if is_train:
        opt['path']['models'] = os.path.join(path_task, 'models')
        opt['path']['images'] = os.path.join(path_task, 'images')
    else:  # test
        opt['path']['images'] = os.path.join(path_task, 'test_images')
    # ----------------------------------------
    # network
    # ----------------------------------------
    opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
    # ----------------------------------------
    # GPU devices
    # ----------------------------------------
    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    # ----------------------------------------
    # default setting for distributeddataparallel
    # ----------------------------------------
    if 'find_unused_parameters' not in opt:
        opt['find_unused_parameters'] = True
    if 'use_static_graph' not in opt:
        opt['use_static_graph'] = False
    if 'dist' not in opt:
        opt['dist'] = False
    opt['num_gpu'] = len(opt['gpu_ids'])
    print('number of GPUs is: ' + str(opt['num_gpu']))
    # ----------------------------------------
    # default setting for perceptual loss
    # ----------------------------------------
    if 'F_feature_layer' not in opt['train']:
        opt['train']['F_feature_layer'] = 34  # 25; [2,7,16,25,34]
    if 'F_weights' not in opt['train']:
        opt['train']['F_weights'] = 1.0  # 1.0; [0.1,0.1,1.0,1.0,1.0]
    if 'F_lossfn_type' not in opt['train']:
        opt['train']['F_lossfn_type'] = 'l1'
    if 'F_use_input_norm' not in opt['train']:
        opt['train']['F_use_input_norm'] = True
    if 'F_use_range_norm' not in opt['train']:
        opt['train']['F_use_range_norm'] = False
    # ----------------------------------------
    # default setting for optimizer
    # ----------------------------------------
    if 'G_optimizer_type' not in opt['train']:
        opt['train']['G_optimizer_type'] = "adam"
    if 'G_optimizer_betas' not in opt['train']:
        opt['train']['G_optimizer_betas'] = [0.9,0.999]
    if 'G_scheduler_restart_weights' not in opt['train']:
        opt['train']['G_scheduler_restart_weights'] = 1
    if 'G_optimizer_wd' not in opt['train']:
        opt['train']['G_optimizer_wd'] = 0
    if 'G_optimizer_reuse' not in opt['train']:
        opt['train']['G_optimizer_reuse'] = False
    if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
        opt['train']['D_optimizer_reuse'] = False
    # ----------------------------------------
    # default setting of strict for model loading
    # ----------------------------------------
    if 'G_param_strict' not in opt['train']:
        opt['train']['G_param_strict'] = True
    # BUG FIX: the two checks below previously probed opt['path'] instead of
    # opt['train'], so user-supplied D/E_param_strict values were overwritten.
    if 'netD' in opt and 'D_param_strict' not in opt['train']:
        opt['train']['D_param_strict'] = True
    if 'E_param_strict' not in opt['train']:
        opt['train']['E_param_strict'] = True
    # ----------------------------------------
    # Exponential Moving Average
    # ----------------------------------------
    if 'E_decay' not in opt['train']:
        opt['train']['E_decay'] = 0
    # ----------------------------------------
    # default setting for discriminator
    # ----------------------------------------
    if 'netD' in opt:
        if 'net_type' not in opt['netD']:
            opt['netD']['net_type'] = 'discriminator_patchgan'  # discriminator_unet
        if 'in_nc' not in opt['netD']:
            opt['netD']['in_nc'] = 3
        if 'base_nc' not in opt['netD']:
            opt['netD']['base_nc'] = 64
        if 'n_layers' not in opt['netD']:
            opt['netD']['n_layers'] = 3
        if 'norm_type' not in opt['netD']:
            opt['netD']['norm_type'] = 'spectral'
    return opt
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    Locate the checkpoint with the highest iteration number in `save_dir`.

    Args:
        save_dir: model folder containing files named '<iter>_<net_type>.pth'
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number (0 when falling back to pretrained_path)
        init_path: model path (pretrained_path when no checkpoint was found)
    """
    file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
    iter_exist = []
    # Escape net_type and the '.' so only the exact '<digits>_<net_type>.pth'
    # naming scheme matches; files such as 'best_G.pth' are skipped instead
    # of raising IndexError on an empty findall() result.
    pattern = re.compile(r"(\d+)_{}\.pth".format(re.escape(net_type)))
    for file_ in file_list:
        found = pattern.findall(file_)
        if found:
            iter_exist.append(int(found[0]))
    if iter_exist:
        init_iter = max(iter_exist)
        init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
    else:
        init_iter = 0
        init_path = pretrained_path
    return init_iter, init_path
'''
# --------------------------------------------
# convert the opt into json file
# --------------------------------------------
'''
def save(opt):
    """Dump the options dict as a timestamped JSON copy into
    opt['path']['options'] (filename derived from opt['opt_path'])."""
    source_path = opt['opt_path']
    copy_dir = opt['path']['options']
    filename, ext = os.path.splitext(os.path.basename(source_path))
    dump_path = os.path.join(copy_dir, filename + get_timestamp() + ext)
    with open(dump_path, 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)
'''
# --------------------------------------------
# dict to string for logger
# --------------------------------------------
'''
def dict2str(opt, indent_l=1):
    """Render a (possibly nested) options dict as an indented multi-line
    string for logging; nesting increases the indent by two spaces."""
    pad = ' ' * (indent_l * 2)
    parts = []
    for key, val in opt.items():
        if isinstance(val, dict):
            parts.append(pad + key + ':[\n')
            parts.append(dict2str(val, indent_l + 1))
            parts.append(pad + ']\n')
        else:
            parts.append(pad + key + ': ' + str(val) + '\n')
    return ''.join(parts)
'''
# --------------------------------------------
# convert OrderedDict to NoneDict,
# return None for missing key
# --------------------------------------------
'''
def dict_to_nonedict(opt):
    """Recursively convert dict/list structures into NoneDict, which yields
    None for missing keys instead of raising KeyError; leaves scalars as-is."""
    if isinstance(opt, dict):
        return NoneDict(**{key: dict_to_nonedict(value) for key, value in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    return opt
class NoneDict(dict):
    """dict that returns None (instead of raising KeyError) for absent keys."""

    def __missing__(self, key):
        # dict.__getitem__ calls this on a miss; returning None makes
        # unset options read as None rather than erroring.
        return None
import torch
import torchvision
from models import basicblock as B
def show_kv(net):
    """Print every key of a state-dict-like mapping (values are ignored)."""
    for key in net:
        print(key)
# should run train debug mode first to get an initial model
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
#
#for k, v in crt_net.items():
# print(k)
#for k, v in crt_net.items():
# if k in pretrained_net:
# crt_net[k] = pretrained_net[k]
# print('replace ... ', k)
# x2 -> x4
#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
#torch.save(crt_net, '../pretrained_tmp.pth')
# x2 -> x3
'''
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
new_filter = torch.Tensor(576, 64, 3, 3)
new_filter[0:256, :, :, :] = in_filter
new_filter[256:512, :, :, :] = in_filter
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
crt_net['model.2.weight'] = new_filter
in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
new_bias = torch.Tensor(576)
new_bias[0:256] = in_bias
new_bias[256:512] = in_bias
new_bias[512:] = in_bias[0:576 - 512]
crt_net['model.2.bias'] = new_bias
torch.save(crt_net, '../pretrained_tmp.pth')
'''
# x2 -> x8
'''
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
torch.save(crt_net, '../pretrained_tmp.pth')
'''
# x3/4/8 RGB -> Y
def rgb2gray_net(net, only_input=True):
    """Convert the first conv layer of a state dict from RGB to grayscale input.

    Collapses the 3 input channels of net['0.weight'] into one using the
    ITU-R BT.601 luma weights (0.2989 R + 0.587 G + 0.114 B); mutates and
    returns the same dict. Only the input layer is handled.
    """
    if only_input:
        w = net['0.weight']
        gray = w[:, 0, :, :] * 0.2989 + w[:, 1, :, :] * 0.587 + w[:, 2, :, :] * 0.114
        gray.unsqueeze_(1)
        net['0.weight'] = gray
    return net
if __name__ == '__main__':
    # Scratch driver: turn VGG19's RGB input layer into a grayscale one and
    # copy parameters from an old checkpoint into a new model's state dict.
    # NOTE(review): downloads VGG19 weights on first run.
    net = torchvision.models.vgg19(pretrained=True)
    for k,v in net.features.named_parameters():
        if k=='0.weight':
            # BT.601 luma combination of the three input channels
            in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
            in_new_filter.unsqueeze_(1)
            v = in_new_filter
            print(v.shape)
            print(v[0,0,0,0])
        if k=='0.bias':
            in_new_bias = v
            print(v[0])
    print(net.features[0])
    # replace the 3-channel input conv with a 1-channel one
    net.features[0] = B.conv(1, 64, mode='C')
    print(net.features[0])
    net.features[0].weight.data=in_new_filter
    net.features[0].bias.data=in_new_bias
    for k,v in net.features.named_parameters():
        if k=='0.weight':
            print(v[0,0,0,0])
        if k=='0.bias':
            print(v[0])
    # transfer parameters of old model to new one
    # NOTE(review): `model_path` and `model` are undefined in this file --
    # running this branch as-is raises NameError; confirm where they come from.
    model_old = torch.load(model_path)
    state_dict = model.state_dict()
    for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
        state_dict[key2] = param
        print([key, key2])
        # print([param.size(), param2.size()])
    torch.save(state_dict, 'model_new.pth')
    # rgb2gray_net(net)
# -*- coding: utf-8 -*-
# online calculation: https://fomoro.com/research/article/receptive-field-calculator#
# [filter size, stride, padding]
#Assume the two dimensions are the same
#Each kernel requires the following parameters:
# - k_i: kernel size
# - s_i: stride
# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow)
#
#Each layer i requires the following parameters to be fully represented:
# - n_i: number of feature (data layer has n_1 = imagesize )
# - j_i: distance (projected to image pixel distance) between center of two adjacent features
# - r_i: receptive field of a feature in layer i
# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
import math
def outFromIn(conv, layerIn):
    """Propagate receptive-field statistics through one conv layer.

    Args:
        conv: [kernel_size, stride, padding] of the layer.
        layerIn: [n, j, r, start] -- feature count, jump (image-pixel
                 distance between adjacent features), receptive-field
                 size, and center position of the first feature.
    Returns:
        (n_out, j_out, r_out, start_out) for the layer's output.
    """
    n_in, j_in, r_in, start_in = layerIn[0], layerIn[1], layerIn[2], layerIn[3]
    k, s, p = conv[0], conv[1], conv[2]
    n_out = math.floor((n_in - k + 2 * p) / s) + 1
    # actual padding consumed by this layer; the left side gets the floor half
    # (removed unused `pR = math.ceil(actualP/2)` from the original)
    actualP = (n_out - 1) * s - n_in + k
    pL = math.floor(actualP / 2)
    j_out = j_in * s
    r_out = r_in + (k - 1) * j_in
    start_out = start_in + ((k - 1) / 2 - pL) * j_in
    return n_out, j_out, r_out, start_out
def printLayer(layer, layer_name):
    """Pretty-print one layer's [n, jump, receptive size, start] record."""
    print(layer_name + ":")
    print(f" n features: {layer[0]} jump: {layer[1]} receptive size: {layer[2]} start: {layer[3]} ")
# accumulates per-layer [n, jump, receptive size, start] records for the demo below
layerInfos = []
if __name__ == '__main__':
    # Demo: walk a small convnet and print receptive-field statistics per layer.
    # NOTE(review): layer_names has more entries than convnet; only the first
    # len(convnet) names are used.
    convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
    layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
    imsize = 128
    print ("-------Net summary------")
    # input image: n = imsize, jump 1, receptive field 1, first center at 0.5
    currentLayer = [imsize, 1, 1, 0.5]
    printLayer(currentLayer, "input image")
    for i in range(len(convnet)):
        currentLayer = outFromIn(convnet[i], currentLayer)
        layerInfos.append(currentLayer)
        printLayer(currentLayer, layer_names[i])
# run utils/utils_receptivefield.py
\ No newline at end of file
import torch
import torch.nn as nn
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
'''
# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization
    # ----------------------------------------
    # Nudges the singular values of every Conv layer's weight matrix toward
    # the band [0.5, 1.5] by +/- 1e-4 per call (orthogonalization technique).
    # To be used with torch.nn.Module.apply(): net.apply(regularizer_orth)
    # ----------------------------------------
    """
    if m.__class__.__name__.find('Conv') == -1:
        return
    weight = m.weight.data.clone()
    c_out, c_in, f1, f2 = weight.size()
    # flatten to a (f1*f2*c_in) x c_out matrix for the SVD
    mat = weight.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)
    # NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd;
    # kept here for identical behavior.
    u, s, v = torch.svd(mat)
    s[s > 1.5] -= 1e-4
    s[s < 0.5] += 1e-4
    mat = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    m.weight.data = mat.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth2(m):
    """
    # ----------------------------------------
    # Relative SVD Orthogonal Regularization
    # ----------------------------------------
    # Like regularizer_orth, but the clipping band is relative to the mean
    # singular value: values outside [0.5*mean, 1.5*mean] are nudged by 1e-4.
    # usage: net.apply(regularizer_orth2)
    # ----------------------------------------
    """
    if m.__class__.__name__.find('Conv') == -1:
        return
    weight = m.weight.data.clone()
    c_out, c_in, f1, f2 = weight.size()
    mat = weight.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)
    # NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd;
    # kept here for identical behavior.
    u, s, v = torch.svd(mat)
    s_mean = s.mean()
    s[s > 1.5 * s_mean] -= 1e-4
    s[s < 0.5 * s_mean] += 1e-4
    mat = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    m.weight.data = mat.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
def regularizer_clip(m):
"""
# ----------------------------------------
# usage: net.apply(regularizer_clip)
# ----------------------------------------
"""
eps = 1e-4
c_min = -1.5
c_max = 1.5
classname = m.__class__.__name__
if classname.find('Conv') != -1 or classname.find('Linear') != -1:
w = m.weight.data.clone()
w[w > c_max] -= eps
w[w < c_min] += eps
m.weight.data = w
if m.bias is not None:
b = m.bias.data.clone()
b[b > c_max] -= eps
b[b < c_min] += eps
m.bias.data = b
# elif classname.find('BatchNorm2d') != -1:
#
# rv = m.running_var.data.clone()
# rm = m.running_mean.data.clone()
#
# if m.affine:
# m.weight.data
# m.bias.data
# -*- coding: utf-8 -*-
from utils import utils_image as util
import random
import scipy
import scipy.stats as ss
import scipy.io as io
from scipy import ndimage
from scipy.interpolate import interp2d
import numpy as np
import torch
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# modified by Kai Zhang (github: https://github.com/cszn)
# 03/03/2020
# --------------------------------------------
"""
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    # rotate the x-axis by theta to get the kernel's major direction
    direction = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    basis = np.array([[direction[0], direction[1]], [direction[1], -direction[0]]])
    scales = np.array([[l1, 0], [0, l2]])
    # covariance = V D V^-1 (eigendecomposition built from direction + scales)
    cov = np.dot(np.dot(basis, scales), np.linalg.inv(basis))
    return gm_blur_kernel(mean=[0, 0], cov=cov, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Evaluate a 2-D Gaussian pdf on a size x size grid and normalize to sum 1."""
    center = size / 2.0 + 0.5
    # grid coordinate of cell (row, col) is (col - center + 1, row - center + 1)
    k = np.array([[ss.multivariate_normal.pdf([col - center + 1, row - center + 1], mean=mean, cov=cov)
                   for col in range(size)]
                  for row in range(size)])
    return k / np.sum(k)
"""
# --------------------------------------------
# calculate PCA projection matrix
# --------------------------------------------
"""
def get_pca_matrix(x, dim_pca=15):
    """
    Args:
        x: 225x10000 matrix (features x samples)
        dim_pca: 15, number of principal components to keep
    Returns:
        pca_matrix: 15x225 projection matrix (top eigenvectors as rows)
    """
    covariance = np.dot(x, x.T)
    # eigh returns eigenvalues ascending; the last dim_pca columns are the top components
    _, eigvecs = scipy.linalg.eigh(covariance)
    return eigvecs[:, -dim_pca:].T
def show_pca(x):
    """
    Visualize each row of a PCA projection matrix (e.g., 15x225) as a
    square surface plot via util.surf.
    """
    side = int(np.sqrt(x.shape[1]))
    for row in x:
        util.surf(np.reshape(row, (side, -1), order="F"))
def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
    """Sample random anisotropic Gaussian kernels, compute their PCA
    projection matrix, save it to `path` (key 'p'), and return it."""
    kernels = np.zeros([ksize * ksize, num_samples], dtype=np.float32)
    for col in range(num_samples):
        # random rotation and eigenvalue scalings (l2 <= l1)
        angle = np.pi * np.random.rand(1)
        l1 = 0.1 + l_max * np.random.rand(1)
        l2 = 0.1 + (l1 - 0.1) * np.random.rand(1)
        kern = anisotropic_Gaussian(ksize=ksize, theta=angle[0], l1=l1[0], l2=l2[0])
        kernels[:, col] = np.reshape(kern, (-1), order="F")
    pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
    io.savemat(path, {'p': pca_matrix})
    return pca_matrix
"""
# --------------------------------------------
# shifted anisotropic Gaussian kernels
# --------------------------------------------
"""
def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # Random anisotropic Gaussian kernel, shifted so the downsampled image
    # stays aligned; optional multiplicative noise. Returns a normalized kernel.
    # min_var = 0.175 * sf, max_var = 2.5 * sf (typical choices)
    """
    # random eigenvalues and rotation angle for the covariance
    # (RNG call order matters and matches the original)
    eig1 = min_var + np.random.rand() * (max_var - min_var)
    eig2 = min_var + np.random.rand() * (max_var - min_var)
    angle = np.random.rand() * np.pi
    multiplicative_noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
    # covariance SIGMA = Q diag(eig1, eig2) Q^T, then its inverse for the pdf
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    inv_sigma = np.linalg.inv(rot @ np.diag([eig1, eig2]) @ rot.T)[None, None, :, :]
    # expectation position (shifting kernel for aligned image)
    mu = (k_size // 2 - 0.5 * (scale_factor - 1))[None, None, :, None]
    # Gaussian evaluated on the pixel grid
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    zz = np.stack([grid_x, grid_y], 2)[:, :, :, None] - mu
    zz_t = zz.transpose(0, 1, 3, 2)
    raw = np.exp(-0.5 * np.squeeze(zz_t @ inv_sigma @ zz)) * (1 + multiplicative_noise)
    # normalize to sum 1
    return raw / np.sum(raw)
def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    # NOTE(review): the `scale_factor` argument is ignored -- it is overwritten
    # below by a randomly drawn sf in {1,2,3,4}; `noise_level` is also ignored
    # (noise is hard-coded to 0). Confirm both are intentional.
    """
    sf = random.choice([1, 2, 3, 4])
    scale_factor = np.array([sf, sf])
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi # random theta
    noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]
    # Create meshgrid for Gaussian
    [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]
    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z-MU
    ZZ_t = ZZ.transpose(0,1,3,2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
    # shift the kernel so it will be centered
    #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
    # Normalize the kernel and return
    #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''
    Bicubically downsample an image by the given scale factor.
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # scipy.ndimage.filters namespace was deprecated and later removed;
    # ndimage.convolve is the identical function.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # scipy.ndimage.filters namespace was deprecated and later removed;
    # ndimage.convolve is the identical function.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # scipy.ndimage.filters namespace was deprecated and later removed;
    # ndimage.convolve is the identical function.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0  # sampling offset: keep the first pixel of each sf x sf cell
    return x[st::sf, st::sf, ...]
def modcrop_np(img, sf):
    '''
    Crop an image so both leading spatial dims are divisible by sf.
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped copy of the image
    '''
    w, h = img.shape[:2]
    return np.copy(img)[:w - w % sf, :h - h % sf, ...]
'''
# =================
# Numpy
# =================
'''
def _interp2_linear(xv, yv, z, xq, yq):
    """Bilinear interpolation of grid z (len(yv) x len(xv)) on the separable
    query grid (xq, yq). Replaces scipy.interpolate.interp2d (removed in
    SciPy 1.14); queries are expected inside the grid (they are clipped by
    the caller), where linear interp2d and this are equivalent."""
    # interpolate along x for every row, then along y for every column
    tmp = np.stack([np.interp(xq, xv, row) for row in z], axis=0)
    return np.stack([np.interp(yq, yv, tmp[:, c]) for c in range(tmp.shape[1])], axis=1)


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH, image or kernel (float; 3-D input is modified in place)
        sf: scale factor
        upper_left: shift direction (True shifts toward the upper-left corner)
    Returns:
        the shifted image/kernel, same shape as x
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1, y1 = xv + shift, yv + shift
    else:
        x1, y1 = xv - shift, yv - shift
    # clamp queries to the grid so interpolation never extrapolates
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)
    if x.ndim == 2:
        x = _interp2_linear(xv, yv, x, x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _interp2_linear(xv, yv, x[:, :, i], x1, y1)
    return x
'''
# =================
# pytorch
# =================
'''
def splits(a, sf):
    '''
    Split the spatial dims into sf x sf sub-grids stacked along a new axis.
    a: tensor NxCxWxHx2
    sf: scale factor
    out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
    '''
    out = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
    out = torch.cat(torch.chunk(out, sf, dim=3), dim=5)
    return out
def c2c(x):
    """Complex numpy array -> float32 torch tensor with trailing real/imag dim."""
    real_part = np.float32(x.real)
    imag_part = np.float32(x.imag)
    return torch.from_numpy(np.stack([real_part, imag_part], axis=-1))
def r2c(x):
    """Real tensor -> complex layout (..., 2) with zero imaginary part."""
    return torch.stack((x, torch.zeros_like(x)), dim=-1)
def cdiv(x, y):
    """Complex division x / y for tensors in the (..., 2) real/imag layout."""
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    cd2 = c**2 + d**2
    real = (a*c + b*d) / cd2
    imag = (b*c - a*d) / cd2
    return torch.stack((real, imag), dim=-1)
def csum(x, y):
    """Add a real value y to the real part of complex tensor x (..., 2)."""
    real = x[..., 0] + y
    imag = x[..., 1]
    return torch.stack((real, imag), dim=-1)
def cabs(x):
    """Magnitude of a complex tensor stored in the (..., 2) layout."""
    squared = x[..., 0]**2 + x[..., 1]**2
    return torch.pow(squared, 0.5)
def cmul(t1, t2):
    '''
    complex multiplication for (..., 2) real/imag tensors
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    a, b = t1[..., 0], t1[..., 1]
    c, d = t2[..., 0], t2[..., 1]
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    return torch.stack([a * c - b * d, a * d + b * c], dim=-1)
def cconj(t, inplace=False):
    '''
    Complex conjugate of a (..., 2) real/imag tensor; in place when requested.
    t: NxCxHxWx2
    output: NxCxHxWx2
    '''
    out = t if inplace else t.clone()
    out[..., 1] *= -1
    return out
def rfft(t):
    # torch.rfft(t, 2, onesided=False) was removed in PyTorch 1.8;
    # view_as_real(fft2(t)) produces the identical (..., 2) real/imag layout.
    return torch.view_as_real(torch.fft.fft2(t))
def irfft(t):
    # torch.irfft(t, 2, onesided=False) was removed in PyTorch 1.8; interpret
    # the trailing dim as complex, inverse-FFT, and keep the real part.
    return torch.fft.ifft2(torch.view_as_complex(t.contiguous())).real
def fft(t):
    # torch.fft(t, 2) (complex-to-complex on a (..., 2) tensor) was removed
    # in PyTorch 1.8; this is the torch.fft-module equivalent.
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
def ifft(t):
    # torch.ifft(t, 2) was removed in PyTorch 1.8; this is the
    # torch.fft-module equivalent for the (..., 2) real/imag layout.
    return torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(t.contiguous())))
def p2o(psf, shape):
    '''
    Convert point-spread function to optical transfer function: zero-pad the
    PSF to `shape`, circularly shift its center to index (0, 0), and FFT.
    Args:
        psf: NxCxhxw
        shape: [H, W] output spatial size
    Returns:
        otf: NxCxHxWx2 (real/imag in the trailing dim)
    '''
    # tuple() so both list and tuple/torch.Size shapes concatenate correctly
    otf = torch.zeros(psf.shape[:-2] + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # roll so the kernel center lands at the origin (avoids phase offset)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    # torch.rfft was removed in PyTorch 1.8; this reproduces onesided=False
    otf = torch.view_as_real(torch.fft.fft2(otf))
    # zero imaginary parts that are within FFT roundoff of zero
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
'''
# =================
PyTorch
# =================
'''
def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
    '''
    Closed-form inverse least-squares step in the Fourier domain.
    FB: NxCxWxHx2
    F2B: NxCxWxHx2
    Matlab reference:
        x1 = FB.*FR;
        FBR = BlockMM(nr,nc,Nb,m,x1);
        invW = BlockMM(nr,nc,Nb,m,F2B);
        invWBR = FBR./(invW + tau*Nb);
        fun = @(block_struct) block_struct.data.*invWBR;
        FCBinvWBR = blockproc(FBC,[nr,nc],fun);
        FX = (FR-FCBinvWBR)/tau;
        Xest = real(ifft2(FX));
    '''
    x1 = cmul(FB, FR)
    # block means over the sf x sf sub-grids play the role of BlockMM sums
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, tau))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR - FCBinvWBR) / tau
    # torch.irfft(FX, 2, onesided=False) was removed in PyTorch 1.8:
    Xest = torch.fft.ifft2(torch.view_as_complex(FX.contiguous())).real
    return Xest
def real2complex(x):
    """Append a zero imaginary channel: real tensor -> (..., 2) complex layout."""
    return torch.stack((x, torch.zeros_like(x)), dim=-1)
def modcrop(img, sf):
    '''
    Crop the trailing two dims of a tensor image so each is divisible by sf.
    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor
    '''
    w, h = img.shape[-2:]
    cropped = img.clone()
    return cropped[..., :w - w % sf, :h - h % sf]
def upsample(x, sf=3, center=False):
    '''
    Zero-interleaved sf-fold upsampling of an NxCxWxH tensor: original
    pixels land on an sf-strided grid (offset (sf-1)//2 when center).
    '''
    offset = (sf - 1) // 2 if center else 0
    z = torch.zeros((x.shape[0], x.shape[1], x.shape[2] * sf, x.shape[3] * sf)).type_as(x)
    z[..., offset::sf, offset::sf].copy_(x)
    return z
def downsample(x, sf=3, center=False):
    '''Keep every sf-th pixel of NxCxWxH (offset (sf-1)//2 when center).'''
    offset = (sf - 1) // 2 if center else 0
    return x[..., offset::sf, offset::sf]
def circular_pad(x, pad):
    '''
    # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
    '''
    # append the leading rows/cols, then prepend the wrapped trailing ones
    out = torch.cat((x, x[:, :, 0:pad, :]), dim=2)
    out = torch.cat((out, out[:, :, :, 0:pad]), dim=3)
    out = torch.cat((out[:, :, -2 * pad:-pad, :], out), dim=2)
    out = torch.cat((out[:, :, :, -2 * pad:-pad], out), dim=3)
    return out
def pad_circular(input, padding):
    # type: (Tensor, List[int]) -> Tensor
    """
    Circularly pad every spatial dim of `input`; padding[i] pads dim i+2
    on both sides.
    :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
    :param padding: m-elem tuple where m is the degree of convolution
    :return: the circularly padded tensor
    """
    offset = 3
    num_spatial = input.dim() - offset + 1
    for d in range(num_spatial):
        input = dim_pad_circular(input, padding[d], d + offset)
    return input
def dim_pad_circular(input, padding, dimension):
    # type: (Tensor, int, int) -> Tensor
    """Wrap-pad `input` along `dimension` (1-based index here; the actual
    tensor dim is dimension - 1) by `padding` elements on both sides."""
    head = [slice(None)] * (dimension - 1)
    input = torch.cat([input, input[head + [slice(0, padding)]]], dim=dimension - 1)
    input = torch.cat([input[head + [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
    return input
def imfilter(x, k):
    '''
    Circular (periodic) filtering via grouped conv2d.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    '''
    pad_h = (k.shape[-2] - 1) // 2
    pad_w = (k.shape[-1] - 1) // 2
    padded = pad_circular(x, padding=(pad_h, pad_w))
    return torch.nn.functional.conv2d(padded, k, groups=x.shape[1])
def G(x, k, sf=3, center=False):
    '''
    Degradation operator: circular blur with k, then sf-fold downsampling.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: sample the middle of each cell instead of the first pixel
    Matlab function:
        tmp = imfilter(x,h,'circular');
        y = downsample2(tmp,K);
    '''
    blurred = imfilter(x, k)
    return downsample(blurred, sf=sf, center=center)
def Gt(x, k, sf=3, center=False):
    '''
    Adjoint of G: zero-interleaved upsampling, then circular blur with k.
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: place samples at the middle of each cell instead of the first pixel
    Matlab function:
        tmp = upsample2(x,K);
        y = imfilter(tmp,h,'circular');
    '''
    upsampled = upsample(x, sf=sf, center=center)
    return imfilter(upsampled, k)
def interpolation_down(x, sf, center=False):
    '''Sample x on an sf-strided grid.
    Returns (LR, y, mask): the sampled low-res tensor, the full-size image
    with non-sampled positions zeroed, and the 0/1 sampling mask.
    '''
    mask = torch.zeros_like(x)
    offset = (sf - 1) // 2 if center else 0
    mask[..., offset::sf, offset::sf] = torch.tensor(1).type_as(x)
    LR = x[..., offset::sf, offset::sf]
    return LR, x.mul(mask), mask
'''
# =================
Numpy
# =================
'''
def blockproc(im, blocksize, fun):
    '''Apply `fun` to each blocksize[0] x blocksize[1] tile of im and
    reassemble the results in the original tile layout (Matlab blockproc).'''
    row_edges = range(blocksize[0], im.shape[0], blocksize[0])
    col_edges = range(blocksize[1], im.shape[1], blocksize[1])
    bands = []
    for band in np.split(im, row_edges, axis=0):
        tiles = [fun(tile) for tile in np.split(band, col_edges, axis=1)]
        bands.append(np.concatenate(tiles, axis=1))
    return np.concatenate(bands, axis=0)
def fun_reshape(a):
    '''Column-major (Fortran-order) flatten of all but the last axis -> (-1, 1, last).'''
    return np.reshape(a, (-1, 1, a.shape[-1]), order='F')
def fun_mul(a, b):
    '''Elementwise product (helper for blockproc callbacks).'''
    return a * b
def BlockMM(nr, nc, Nb, m, x1):
    '''
    Sum the Nb tiles of x1 into one nr x nc block. Matlab equivalent:
        myfun = @(block_struct) reshape(block_struct.data,m,1);
        x1 = blockproc(x1,[nr nc],myfun);
        x1 = reshape(x1,m,Nb);
        x1 = sum(x1,2);
        x = reshape(x1,nr,nc);
    '''
    stacked = blockproc(x1, blocksize=(nr, nc), fun=fun_reshape)
    stacked = np.reshape(stacked, (m, Nb, stacked.shape[-1]), order='F')
    summed = np.sum(stacked, 1)
    return np.reshape(summed, (nr, nc, summed.shape[-1]), order='F')
def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
    '''
    Closed-form inverse least-squares step (numpy port of the Matlab code):
        x1 = FB.*FR;
        FBR = BlockMM(nr,nc,Nb,m,x1);
        invW = BlockMM(nr,nc,Nb,m,F2B);
        invWBR = FBR./(invW + tau*Nb);
        fun = @(block_struct) block_struct.data.*invWBR;
        FCBinvWBR = blockproc(FBC,[nr,nc],fun);
        FX = (FR-FCBinvWBR)/tau;
        Xest = real(ifft2(FX));
    '''
    FBR = BlockMM(nr, nc, Nb, m, FB * FR)
    invW = BlockMM(nr, nc, Nb, m, F2B)
    invWBR = FBR / (invW + tau * Nb)
    FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
    FX = (FR - FCBinvWBR) / tau
    return np.real(np.fft.ifft2(FX, axes=(0, 1)))
def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.
    Compute the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in OUTSIZE, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches (1,1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : int
        Output shape of the OTF array
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if shape is None:  # idiomatic None test (was: type(shape) == type(None))
        shape = psf.shape
    shape = np.array(shape)
    if np.all(psf == 0):
        # all-zero PSF transforms to an all-zero OTF
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    inshape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    # circularly shift the kernel center to index (0, 0)
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the PSF imaginary part if within roundoff error
    # (roundoff error = machine epsilon = sys.float_info.epsilon)
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros
    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image
    Raises
    ------
    ValueError
        If `shape` has a non-positive entry, is smaller than the image, or
        (for 'center') differs from the image shape by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    # np.alltrue was removed in NumPy 2.0; array_equal is the equivalent test
    if np.array_equal(imshape, shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    idx, idy = np.indices(imshape)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[idx + offx, idy + offy] = image
    return pad_img
def upsample_np(x, sf=3, center=False):
    '''Zero-interleaved sf-fold upsampling of an HxWxC numpy image;
    original pixels land on an sf-strided grid (offset (sf-1)//2 when center).'''
    offset = (sf - 1) // 2 if center else 0
    z = np.zeros((x.shape[0] * sf, x.shape[1] * sf, x.shape[2]))
    z[offset::sf, offset::sf, ...] = x
    return z
def downsample_np(x, sf=3, center=False):
    '''Keep every sf-th pixel of an HxWxC numpy image (offset (sf-1)//2 when center).'''
    offset = (sf - 1) // 2 if center else 0
    return x[offset::sf, offset::sf, ...]
def imfilter_np(x, k):
    '''
    Circularly ('wrap') convolve an image with a 2-D kernel.
    x: HxWxC image (NOTE(review): the docstring in the torch `imfilter`
       says NxcxHxW, but here the kernel is expanded on axis 2, so the
       convolution runs over the leading two axes -- confirm callers)
    k: hxw kernel
    '''
    # scipy.ndimage.filters namespace was deprecated and later removed;
    # ndimage.convolve is the identical function.
    return ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
def G_np(x, k, sf=3, center=False):
    '''
    Numpy degradation operator: circular blur with k, then sf-fold downsampling.
    x: image
    k: kernel
    Matlab function:
        tmp = imfilter(x,h,'circular');
        y = downsample2(tmp,K);
    '''
    blurred = imfilter_np(x, k)
    return downsample_np(blurred, sf=sf, center=center)
def Gt_np(x, k, sf=3, center=False):
    '''
    Numpy adjoint of G_np: zero-interleaved upsampling, then circular blur.
    x: image
    k: kernel
    Matlab function:
        tmp = upsample2(x,K);
        y = imfilter(tmp,h,'circular');
    '''
    upsampled = upsample_np(x, sf=sf, center=center)
    return imfilter_np(upsampled, k)
if __name__ == '__main__':
    # Smoke test: run the degradation models on a local test image and
    # visualize sample kernels. Requires 'test.bmp' next to the script.
    img = util.imread_uint('test.bmp', 3)
    img = util.uint2single(img)
    k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
    util.imshow(k*10)
    # NOTE(review): img is re-bound by modcrop_np inside the loop, so each
    # scale factor operates on the previous iteration's crop.
    for sf in [2, 3, 4]:
        # modcrop
        img = modcrop_np(img, sf=sf)
        # 1) bicubic degradation
        img_b = bicubic_degradation(img, sf=sf)
        print(img_b.shape)
        # 2) srmd degradation
        img_s = srmd_degradation(img, k, sf=sf)
        print(img_s.shape)
        # 3) dpsr degradation
        img_d = dpsr_degradation(img, k, sf=sf)
        print(img_d.shape)
        # 4) classical degradation
        img_d = classical_degradation(img, k, sf=sf)
        print(img_d.shape)
    k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
    #print(k)
    # util.imshow(k*10)
    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
    # util.imshow(k*10)
    # PCA
    # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
    # print(pca_matrix.shape)
    # show_pca(pca_matrix)
# run utils/utils_sisr.py
# run utils_sisr.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment