Commit a8ada82f authored by chenych's avatar chenych
Browse files

First commit

parent 537691da
# 模型唯一标识
modelCode=xxx
# 模型名称
modelName=maskeddenoising_pytorch
# 模型描述
modelDescription=maskeddenoising_pytorch在训练过程中对输入图像的随机像素进行掩蔽,并在训练过程中重建缺失的信息。同时,还在自注意力层中掩蔽特征,以避免训练和测试不一致性的影响。
# 应用场景
appScenario=推理,训练,图像降噪,教育,交通,公安
# 框架类型
frameType=PyTorch
{
"task": "baseline" // real-world image sr. root/task/images-models-options
, "model": "plain" // "plain" | "plain2" if two inputs
, "gpu_ids": [0]
, "dist": false
, "scale": 1 // broadcast to "datasets"
, "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color
, "path": {
"root": "masked_denoising" // "denoising" | "superresolution" | "dejpeg"
, "pretrained_netG": null // path of pretrained model
, "pretrained_netE": null // path of pretrained model
}
, "datasets": {
"train": {
"name": "train_dataset" // just name
, "dataset_type": "masked_denoising" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg" | "masked_denoising"
, "dataroot_H": "trainsets/trainH" // path of H training dataset. DIV2K (800 training images) + Flickr2K (2650 images) + OST (10324 images)
, "dataroot_L": null // path of L training dataset
, "H_size": 64 // patch_size 256 | 288 | 320 (256)
, "lq_patchsize": 64 // (64)
, "dataloader_shuffle": true
, "dataloader_num_workers": 2
, "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128. Total batch size =4x8=32 in SwinIR (32)
, "noise_level": 15
, "if_mask": false
, "mask1": 75
, "mask2": 75
}
, "test": {
"name": "test_dataset" // just name
, "dataset_type": "plain" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg"
, "dataroot_H": "testset/McM/HR" // path of H testing dataset
, "dataroot_L": "testset/McM/McM_poisson_20" // path of L testing dataset
}
}
, "netG": {
"net_type": "swinir"
, "upscale": 1
, "in_chans": 3
, "img_size": 64
, "window_size": 8 // 8 !!!!!!!!!!!!!!!
, "img_range": 1.0
, "depths": [6, 6, 6, 6]
, "embed_dim": 60
, "num_heads": [6, 6, 6, 6]
, "mlp_ratio": 2
, "upsampler": null // "pixelshuffle" | "pixelshuffledirect" | "nearest+conv" | null
, "resi_connection": "3conv" // "1conv" | "3conv"
, "init_type": "default"
, "talking_heads": false
, "attn_fn": "softmax" // null | "softmax" | "entmax15" |
, "head_scale": false
, "on_attn": false
, "use_mask": false // if use attention mask
, "mask_ratio1": 75 // attention mask ratio,
, "mask_ratio2": 75 // randomly sampling from [mask_ratio1, mask_ratio2]
, "mask_is_diff": false
, "type": "stand"
}
, "train": {
"manual_seed": 1
, "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" | "charbonnier"
, "G_lossfn_weight": 1.0 // default
, "E_decay": 0.999 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999
, "G_optimizer_type": "adam" // fixed, adam is enough
, "G_optimizer_lr": 1e-4 // 2e-4 // learning rate
, "G_optimizer_wd": 0 // weight decay, default 0
, "G_optimizer_clipgrad": null // unused
, "G_optimizer_reuse": true //
, "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough
, "G_scheduler_milestones": [] // [250000, 400000, 450000, 475000, 500000]
, "G_scheduler_gamma": 0.5
, "G_regularizer_orthstep": null // unused
, "G_regularizer_clipstep": null // unused
, "G_param_strict": true
, "E_param_strict": true
, "checkpoint_test": 5000 // for testing (5000)
, "checkpoint_save": 5000 // for saving model
, "checkpoint_print": 100 // for print
, "save_image": ["img_043_x1", "img_021_x1", "img_024_x1", "img_031_x1", "img_041_x1", "img_032_x1"] // [250000, 400000, 450000, 475000, 500000]
}
}
{
"task": "80_90" // real-world image sr. root/task/images-models-options
, "model": "plain" // "plain" | "plain2" if two inputs
// , "gpu_ids": [0,1,2,3]
, "gpu_ids": [0]
, "dist": false
, "scale": 1 // broadcast to "datasets"
, "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color
, "path": {
"root": "masked_denoising" // "denoising" | "superresolution" | "dejpeg"
, "pretrained_netG": null // path of pretrained model
, "pretrained_netE": null // path of pretrained model
},
"datasets": {
"train": {
"name": "train_dataset" // just name
, "dataset_type": "masked_denoising" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg" | "masked_denoising"
, "dataroot_H": "trainsets/trainH" // path of H training dataset. DIV2K (800 training images) + Flickr2K (2650 images) + OST (10324 images)
, "dataroot_L": null // path of L training dataset
, "H_size": 64 // patch_size 256 | 288 | 320 (256)
, "lq_patchsize": 64 // (64)
, "dataloader_shuffle": true
, "dataloader_num_workers": 16
, "dataloader_batch_size": 64 // batch size 1 | 16 | 32 | 48 | 64 | 128. Total batch size =4x8=32 in SwinIR (32)
, "noise_level": 15
, "if_mask": true
, "mask1": 80
, "mask2": 90
}
, "test": {
"name": "test_dataset" // just name
, "dataset_type": "plain" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg"
, "dataroot_H": "testset/McM/HR" // path of H testing dataset
, "dataroot_L": "testset/McM/McM_poisson_20" // path of L testing dataset
}
},
"netG": {
"net_type": "swinir"
, "upscale": 1
, "in_chans": 3
, "img_size": 64
, "window_size": 8 // 8 !!!!!!!!!!!!!!!
, "img_range": 1.0
, "depths": [6, 6, 6, 6]
, "embed_dim": 60
, "num_heads": [6, 6, 6, 6]
, "mlp_ratio": 2
, "upsampler": null // "pixelshuffle" | "pixelshuffledirect" | "nearest+conv" | null
, "resi_connection": "3conv" // "1conv" | "3conv"
, "init_type": "default"
, "talking_heads": false
, "attn_fn": "softmax" // null | "softmax" | "entmax15" |
, "head_scale": false
, "on_attn": false
, "use_mask": true // if use attention mask
, "mask_ratio1": 75 // attention mask ratio,
, "mask_ratio2": 75 // randomly sampling from [mask_ratio1, mask_ratio2]
, "mask_is_diff": false
, "type": "stand"
},
"train": {
"manual_seed": 1
, "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" | "charbonnier"
, "G_lossfn_weight": 1.0 // default
, "E_decay": 0.999 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999
, "G_optimizer_type": "adam" // fixed, adam is enough
, "G_optimizer_lr": 1e-4 // 2e-4 // learning rate
, "G_optimizer_wd": 0 // weight decay, default 0
, "G_optimizer_clipgrad": null // unused
, "G_optimizer_reuse": true //
, "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough
, "G_scheduler_milestones": [150000, 200000, 300000, 350000, 400000] // [250000, 400000, 450000, 475000, 500000]
, "G_scheduler_gamma": 0.5
, "G_regularizer_orthstep": null // unused
, "G_regularizer_clipstep": null // unused
, "G_param_strict": true
, "E_param_strict": true
, "checkpoint_test": 5000 // for testing (5000)
, "checkpoint_save": 5000 // for saving model
, "checkpoint_print": 100 // for print
, "save_image": ["img_043_x1", "img_021_x1", "img_024_x1", "img_031_x1", "img_041_x1", "img_032_x1"] // [250000, 400000, 450000, 475000, 500000]
}
}
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def sequential(*args):
    """Flatten modules into a single nn.Sequential.

    Any nn.Sequential argument is unpacked into its children; plain
    nn.Module arguments are kept as-is.

    Args:
        *args: nn.Sequential and/or nn.Module instances.

    Returns:
        nn.Sequential: the flattened container — or, when exactly one
        argument is given, that argument itself unwrapped.

    Raises:
        NotImplementedError: if the single argument is an OrderedDict.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        # A single module needs no wrapping.
        return only
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
'''
# --------------------------------------------
# Useful blocks
# https://github.com/xinntao/BasicSR
# --------------------------------
# conv + normalization + relu (conv)
# (PixelUnShuffle)
# (ConditionalBatchNorm2d)
# concat (ConcatBlock)
# sum (ShortcutBlock)
# resblock (ResBlock)
# Channel Attention (CA) Layer (CALayer)
# Residual Channel Attention Block (RCABlock)
# Residual Channel Attention Group (RCAGroup)
# Residual Dense Block (ResidualDenseBlock_5C)
# Residual in Residual Dense Block (RRDB)
# --------------------------------------------
'''
# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
    """Build an nn.Sequential of layers described by a compact mode string.

    Each character of `mode` maps to one layer:
        C/T     Conv2d / ConvTranspose2d
        B/I     BatchNorm2d / InstanceNorm2d (on out_channels)
        R/r     ReLU inplace / not inplace
        L/l     LeakyReLU(negative_slope) inplace / not inplace
        2/3/4   PixelShuffle with that upscale factor
        U/u/v   nearest Upsample x2 / x3 / x4
        M/A     MaxPool2d / AvgPool2d (kernel_size, stride, padding=0)

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded to the conv/pool layers as appropriate.
        mode: layer-code string, e.g. 'CBR' = Conv + BN + ReLU.
        negative_slope: slope for the LeakyReLU variants.

    Returns:
        nn.Sequential (or the single module, via `sequential`).

    Raises:
        NotImplementedError: on an unknown mode character.
    """
    L = []
    for t in mode:
        if t == 'C':
            L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'T':
            L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'B':
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == 'I':
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'R':
            L.append(nn.ReLU(inplace=True))
        elif t == 'r':
            L.append(nn.ReLU(inplace=False))
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # Bug fix: the original message was 'Undefined type: '.format(t),
            # which has no placeholder and silently dropped the bad character.
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*L)
# --------------------------------------------
# inverse of pixel_shuffle
# --------------------------------------------
def pixel_unshuffle(input, upscale_factor):
    r"""Inverse of pixel_shuffle: rearrange a tensor of shape
    :math:`(N, C, rH, rW)` into one of shape :math:`(N, r^2C, H, W)`.

    Authors:
        Zhaoyi Yan, https://github.com/Zhaoyi-Yan
        Kai Zhang, https://github.com/cszn/FFDNet
    Date:
        01/Jan/2019
    """
    b, c, h, w = input.size()
    r = upscale_factor
    h_out = h // r
    w_out = w // r
    # Split each spatial axis into (output, factor), move the two factor axes
    # next to the channel axis, then fold them into the channel dimension.
    view = input.contiguous().view(b, c, h_out, r, w_out, r)
    shuffled = view.permute(0, 1, 3, 5, 2, 4).contiguous()
    return shuffled.view(b, c * r * r, h_out, w_out)
class PixelUnShuffle(nn.Module):
    r"""Module wrapper around :func:`pixel_unshuffle`.

    Rearranges a tensor of shape :math:`(N, C, rH, rW)` into
    :math:`(N, r^2C, H, W)` where ``r`` is ``upscale_factor``.

    Authors:
        Zhaoyi Yan, https://github.com/Zhaoyi-Yan
        Kai Zhang, https://github.com/cszn/FFDNet
    Date:
        01/Jan/2019
    """

    def __init__(self, upscale_factor):
        super(PixelUnShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input):
        # Delegate to the module-level functional form.
        return pixel_unshuffle(input, self.upscale_factor)

    def extra_repr(self):
        return 'upscale_factor={}'.format(self.upscale_factor)
# --------------------------------------------
# conditional batch norm
# https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
# --------------------------------------------
class ConditionalBatchNorm2d(nn.Module):
    """BatchNorm2d whose affine parameters are looked up per class label.

    A non-affine BatchNorm2d is followed by a per-class (gamma, beta) pair
    stored in an embedding table; `forward(x, y)` normalizes `x` and applies
    the affine transform selected by the integer labels `y`.
    https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One row per class: first half is gamma (scale), second half beta (bias).
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # scale ~ N(1, 0.02)
        self.embed.weight.data[:, num_features:].zero_()           # bias starts at 0

    def forward(self, x, y):
        normed = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        gamma = gamma.view(-1, self.num_features, 1, 1)
        beta = beta.view(-1, self.num_features, 1, 1)
        return gamma * normed + beta
# --------------------------------------------
# Concat the output of a submodule to its input
# --------------------------------------------
class ConcatBlock(nn.Module):
    """Concatenate the output of a submodule to its input along channels."""

    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        # (x, sub(x)) stacked on the channel dimension.
        return torch.cat((x, self.sub(x)), dim=1)

    def __repr__(self):
        return self.sub.__repr__() + 'concat'
# --------------------------------------------
# sum the output of a submodule to its input
# --------------------------------------------
class ShortcutBlock(nn.Module):
    """Identity shortcut: returns x + submodule(x)."""

    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x + self.sub(x)

    def __repr__(self):
        # Prefix every line of the submodule repr with '|' under the identity.
        body = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity + \n|' + body
# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class ResBlock(nn.Module):
    """Residual block: x + conv(relu(conv(x))), layers built by `conv`."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super(ResBlock, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        # A leading activation is switched to its non-inplace variant so the
        # residual input tensor is not overwritten in place.
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        return x + self.res(x)
# --------------------------------------------
# simplified information multi-distillation block (IMDB)
# x + conv1(concat(split(relu(conv(x)))x3))
# --------------------------------------------
class IMDBlock(nn.Module):
    """Simplified information multi-distillation block (IMDB).

    x + conv1x1(concat(d1, d2, d3, d4)): each of conv1..conv3 produces
    in_channels outputs that are split into a distilled part (d_nc channels,
    retained) and a remaining part (r_nc channels, fed to the next stage);
    conv4 (conv only, no activation) distills the final r_nc channels down
    to d_nc. The four distilled parts are fused by a 1x1 conv and added to x.

    @inproceedings{hui2019lightweight,
      title={Lightweight Image Super-Resolution with Information Multi-distillation Network},
      author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei},
      booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)},
      pages={2024--2032},
      year={2019}
    }
    @inproceedings{zhang2019aim,
      title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results},
      author={Kai Zhang and Shuhang Gu and Radu Timofte and others},
      booktitle={IEEE International Conference on Computer Vision Workshops},
      year={2019}
    }
    """
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.25, negative_slope=0.05):
        super(IMDBlock, self).__init__()
        # Channel split: d_nc distilled + r_nc remaining == in_channels.
        self.d_nc = int(in_channels * d_rate)
        self.r_nc = int(in_channels - self.d_nc)
        assert mode[0] == 'C', 'convolutional layer first'
        self.conv1 = conv(in_channels, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv2 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv3 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        # conv4 uses mode[0] only ('C'): the last distillation has no activation.
        self.conv4 = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias, mode[0], negative_slope)
        # 1x1 fusion of the four distilled feature groups back to out_channels.
        self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0], negative_slope=negative_slope)
    def forward(self, x):
        # Split each stage's output into (distilled d_i, remaining r_i).
        d1, r1 = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1)
        d2, r2 = torch.split(self.conv2(r1), (self.d_nc, self.r_nc), dim=1)
        d3, r3 = torch.split(self.conv3(r2), (self.d_nc, self.r_nc), dim=1)
        d4 = self.conv4(r3)
        res = self.conv1x1(torch.cat((d1, d2, d3, d4), dim=1))
        return x + res
# --------------------------------------------
# Enhanced Spatial Attention (ESA)
# --------------------------------------------
class ESA(nn.Module):
    """Enhanced Spatial Attention.

    A strided + pooled bottleneck produces a sigmoid attention map that
    rescales the input feature map:
        conv1x1 -> conv3x3(stride 2) -> maxpool(7,3) -> 2x(conv3x3+relu)
        -> conv3x3 -> bilinear upsample -> + conv1x1 skip -> conv1x1 -> sigmoid
    """

    def __init__(self, channel=64, reduction=4, bias=True):
        super(ESA, self).__init__()
        self.r_nc = channel // reduction
        self.conv1 = nn.Conv2d(channel, self.r_nc, kernel_size=1)
        self.conv21 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=1)
        self.conv2 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(self.r_nc, channel, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.conv1(x)
        # Aggressive spatial reduction: stride-2 conv then 7/3 max pooling.
        pooled = F.max_pool2d(self.conv2(squeezed), kernel_size=7, stride=3)
        pooled = self.relu(self.conv3(pooled))
        pooled = self.relu(self.conv4(pooled))
        restored = F.interpolate(self.conv5(pooled), (x.size(2), x.size(3)),
                                 mode='bilinear', align_corners=False)
        gate = self.conv6(restored + self.conv21(squeezed))
        return x.mul(self.sigmoid(gate))
class CFRB(nn.Module):
    """Cascaded feature-refinement block (RFDN-style) with ESA attention.

    Three residual conv stages each emit a 1x1 "distilled" branch (d_nc
    channels); a final conv distills the running features, the four distilled
    maps are concatenated, activated, fused by a 1x1 conv, and gated by ESA.
    """
    def __init__(self, in_channels=50, out_channels=50, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.5, negative_slope=0.05):
        super(CFRB, self).__init__()
        self.d_nc = int(in_channels * d_rate)
        # Unlike IMDBlock, the running branch keeps the full channel width.
        self.r_nc = in_channels  # int(in_channels - self.d_nc)
        assert mode[0] == 'C', 'convolutional layer first'
        # All stage convs use mode[0] ('C') only — activations are applied
        # separately through self.act after each residual add.
        self.conv1_d = conv(in_channels, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv1_r = conv(in_channels, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv2_d = conv(self.r_nc, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv2_r = conv(self.r_nc, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv3_d = conv(self.r_nc, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv3_r = conv(self.r_nc, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        # NOTE(review): conv4_d uses the full kernel_size while the other
        # distillation convs are 1x1 — looks intentional (RFDN does the same),
        # but confirm against the reference implementation.
        self.conv4_d = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        # Standalone activation built from the last mode char (e.g. 'L').
        self.act = conv(mode=mode[-1], negative_slope=negative_slope)
        self.esa = ESA(in_channels, reduction=4, bias=True)
    def forward(self, x):
        # Each stage: distill from current features, then refine them with a
        # residual conv + activation; `x` is rebound to the refined features.
        d1 = self.conv1_d(x)
        x = self.act(self.conv1_r(x)+x)
        d2 = self.conv2_d(x)
        x = self.act(self.conv2_r(x)+x)
        d3 = self.conv3_d(x)
        x = self.act(self.conv3_r(x)+x)
        x = self.conv4_d(x)
        # Fuse all distilled branches, then apply spatial attention.
        x = self.act(torch.cat([d1, d2, d3, x], dim=1))
        x = self.esa(self.conv1x1(x))
        return x
# --------------------------------------------
# Channel Attention (CA) Layer
# --------------------------------------------
class CALayer(nn.Module):
    """Channel Attention layer.

    Global average pooling followed by a two-layer 1x1-conv bottleneck and a
    sigmoid; the resulting per-channel weights rescale the input.
    """

    def __init__(self, channel=64, reduction=16):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.conv_fc(self.avg_pool(x))
        return x * weights
# --------------------------------------------
# Residual Channel Attention Block (RCAB)
# --------------------------------------------
class RCABlock(nn.Module):
    """Residual Channel Attention Block: x + CA(conv(relu(conv(x))))."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, negative_slope=0.2):
        super(RCABlock, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        # A leading activation is made non-inplace to protect the residual input.
        if mode[0] in ['R','L']:
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.ca = CALayer(out_channels, reduction)

    def forward(self, x):
        return self.ca(self.res(x)) + x
# --------------------------------------------
# Residual Channel Attention Group (RG)
# --------------------------------------------
class RCAGroup(nn.Module):
    """Residual Channel Attention Group: nb RCABlocks + a conv, with a group skip."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, nb=12, negative_slope=0.2):
        super(RCAGroup, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R','L']:
            mode = mode[0].lower() + mode[1:]
        blocks = [
            RCABlock(in_channels, out_channels, kernel_size, stride, padding, bias, mode, reduction, negative_slope)
            for _ in range(nb)
        ]
        # Trailing plain conv before the group-level skip connection.
        blocks.append(conv(out_channels, out_channels, mode='C'))
        self.rg = nn.Sequential(*blocks)  # self.rg = ShortcutBlock(nn.Sequential(*RG))

    def forward(self, x):
        return self.rg(x) + x
# --------------------------------------------
# Residual Dense Block
# style: 5 convs
# --------------------------------------------
class ResidualDenseBlock_5C(nn.Module):
    """Residual Dense Block, 5-conv style.

    Each conv consumes the concatenation of the block input and every earlier
    conv output; the last conv (mode[:-1], so no trailing activation) maps back
    to nc channels, is scaled by 0.2, and added to the input.
    """

    def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channels appended to the concatenation at each stage.
        self.conv1 = conv(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv2 = conv(nc+gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv3 = conv(nc+2*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv4 = conv(nc+3*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv5 = conv(nc+4*gc, nc, kernel_size, stride, padding, bias, mode[:-1], negative_slope)

    def forward(self, x):
        feats = [x]
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            inp = feats[0] if len(feats) == 1 else torch.cat(feats, 1)
            feats.append(stage(inp))
        out = self.conv5(torch.cat(feats, 1))
        return out.mul_(0.2) + x
# --------------------------------------------
# Residual in Residual Dense Block
# 3x5c
# --------------------------------------------
class RRDB(nn.Module):
    """Residual in Residual Dense Block: three chained RDBs with a scaled skip."""

    def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
        super(RRDB, self).__init__()
        args = (nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.RDB1 = ResidualDenseBlock_5C(*args)
        self.RDB2 = ResidualDenseBlock_5C(*args)
        self.RDB3 = ResidualDenseBlock_5C(*args)

    def forward(self, x):
        out = self.RDB3(self.RDB2(self.RDB1(x)))
        # Same 0.2 residual scaling as the inner dense blocks.
        return out.mul_(0.2) + x
"""
# --------------------------------------------
# Upsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# upsample_pixelshuffle
# upsample_upconv
# upsample_convtranspose
# --------------------------------------------
"""
# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Conv expanding channels by scale^2, then PixelShuffle (+ optional BN/act).

    mode[0] in {'2','3','4'} selects the upscale factor; the remaining
    characters configure extra layers exactly as in `conv`.
    """
    assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    scale = int(mode[0])
    return conv(in_channels, out_channels * (scale ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope)
# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Nearest-neighbor Upsample followed by a conv (+ optional BN/act)."""
    assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
    # Map the scale digit to the matching Upsample+Conv code pair.
    uc = {'2': 'UC', '3': 'uC', '4': 'vC'}[mode[0]]
    layer_mode = mode.replace(mode[0], uc)
    return conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=layer_mode, negative_slope=negative_slope)
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Transposed-conv upsampler (+ optional act).

    The kernel_size/stride parameters are overridden by the scale digit in
    mode[0], matching the original behavior.
    """
    assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    scale = int(mode[0])
    return conv(in_channels, out_channels, scale, scale, padding, bias, mode.replace(mode[0], 'T'), negative_slope)
'''
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
'''
# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Strided-conv downsampler (+ optional act).

    The kernel_size/stride parameters are overridden by the scale digit in
    mode[0], matching the original behavior.
    """
    assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    scale = int(mode[0])
    return conv(in_channels, out_channels, scale, scale, padding, bias, mode.replace(mode[0], 'C'), negative_slope)
# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """MaxPool (kernel=stride=scale digit) followed by a conv tail (+ act)."""
    assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    pool_size = int(mode[0])
    # After replacement, tail_mode[0] is 'M' (the pool) and tail_mode[1:]
    # configures the conv tail.
    tail_mode = mode.replace(mode[0], 'MC')
    pool = conv(kernel_size=pool_size, stride=pool_size, mode=tail_mode[0], negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=tail_mode[1:], negative_slope=negative_slope)
    return sequential(pool, tail)
# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """AvgPool (kernel=stride=scale digit) followed by a conv tail (+ act)."""
    assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    pool_size = int(mode[0])
    # After replacement, tail_mode[0] is 'A' (the pool) and tail_mode[1:]
    # configures the conv tail.
    tail_mode = mode.replace(mode[0], 'AC')
    pool = conv(kernel_size=pool_size, stride=pool_size, mode=tail_mode[0], negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=tail_mode[1:], negative_slope=negative_slope)
    return sequential(pool, tail)
'''
# --------------------------------------------
# NonLocalBlock2D:
# embedded_gaussian
# +W(softmax(thetaXphi)Xg)
# --------------------------------------------
'''
# --------------------------------------------
# non-local block with embedded_gaussian
# https://github.com/AlexHex7/Non-local_pytorch
# --------------------------------------------
class NonLocalBlock2D(nn.Module):
    """Non-local block with embedded-Gaussian pairwise function.

    z = x + W(softmax(theta(x) @ phi(x)) @ g(x)), with optional spatial
    downsampling of the phi/g branches to reduce the attention cost.
    https://github.com/AlexHex7/Non-local_pytorch
    """
    def __init__(self, nc=64, kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='maxpool', negative_slope=0.2):
        super(NonLocalBlock2D, self).__init__()
        # Bottleneck width: pairwise attention runs at half the channels.
        inter_nc = nc // 2
        self.inter_nc = inter_nc
        # W projects back to nc channels, optionally followed by BN (act_mode).
        self.W = conv(inter_nc, nc, kernel_size, stride, padding, bias, mode='C'+act_mode)
        self.theta = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
        if downsample:
            if downsample_mode == 'avgpool':
                downsample_block = downsample_avgpool
            elif downsample_mode == 'maxpool':
                downsample_block = downsample_maxpool
            elif downsample_mode == 'strideconv':
                downsample_block = downsample_strideconv
            else:
                raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
            # phi/g operate on 2x-downsampled features (mode='2').
            self.phi = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
            self.g = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
        else:
            self.phi = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
            self.g = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
    def forward(self, x):
        '''
        :param x: (b, c, h, w) feature map (2D convs throughout)
        :return: same shape as x
        '''
        batch_size = x.size(0)
        # Flatten spatial dims: g_x, theta_x -> (b, N, inter_nc); phi_x -> (b, inter_nc, N').
        g_x = self.g(x).view(batch_size, self.inter_nc, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_nc, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_nc, -1)
        # Pairwise affinity (b, N, N'), normalized over the key dimension.
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)
        # Restore (b, inter_nc, h, w) before the output projection.
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_nc, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """'Same'-padding Conv2d: padding = kernel_size // 2."""
    padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=padding, bias=bias)
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that shifts per-channel RGB statistics.

    Computes (sign * rgb_range * mean + x) / std channel-wise:
    sign=-1 subtracts the dataset mean (normalize), sign=+1 adds it back.

    Args:
        rgb_range: value range of the input (e.g. 1 or 255).
        rgb_mean: per-channel mean, length 3.
        rgb_std: per-channel std, length 3.
        sign: -1 to subtract the mean, +1 to restore it.
    """
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Diagonal weight = 1/std; bias = sign * range * mean / std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Bug fix: the original `self.requires_grad = False` only created a
        # plain module attribute and left weight/bias trainable. Freeze the
        # parameters themselves so the shift stays fixed during training.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
    """Conv2d with 'same' padding, optionally followed by BatchNorm and an act.

    NOTE(review): the default `act=nn.ReLU(True)` is a single shared module
    instance across all default-constructed blocks — harmless for a stateless
    ReLU, but kept as-is to preserve the original interface.
    """
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):
        layers = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size//2), stride=stride, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            layers.append(act)
        super(BasicBlock, self).__init__(*layers)
class ResBlock(nn.Module):
    """EDSR-style residual block: conv -> act -> conv, scaled and added to x."""
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        layers = []
        for i in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Activation only after the first conv.
            if i == 0:
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        return self.body(x).mul(self.res_scale) + x
class Upsampler(nn.Sequential):
    """PixelShuffle upsampler for scale 2^n or 3.

    Powers of two are built as repeated x2 stages (conv to 4*n_feat +
    PixelShuffle(2)); scale 3 is a single x3 stage. Other scales raise
    NotImplementedError.
    """
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        layers = []
        if (scale & (scale - 1)) == 0:  # scale is a power of two
            for _ in range(int(math.log(scale, 2))):
                layers.append(conv(n_feat, 4 * n_feat, 3, bias))
                layers.append(nn.PixelShuffle(2))
                if bn:
                    layers.append(nn.BatchNorm2d(n_feat))
                if act:
                    layers.append(act())
        elif scale == 3:
            layers.append(conv(n_feat, 9 * n_feat, 3, bias))
            layers.append(nn.PixelShuffle(3))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            if act:
                layers.append(act())
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*layers)
# add NonLocalBlock2D
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
class NonLocalBlock2D(nn.Module):
    """Non-local block (simple version) with 1x1-conv theta/phi/g branches.

    The output projection W is zero-initialised, so the block is an exact
    identity mapping (z = x) at initialisation.
    """
    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)
        # Bug fix: nn.init.constant was deprecated and later removed from
        # PyTorch; use the in-place variant constant_.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten spatial dims; g_x and theta_x become (b, N, inter_channels).
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # NOTE(review): the referenced simple_version normalises over the last
        # dim (dim=-1); dim=1 is kept here to preserve existing behavior —
        # confirm which is intended before changing.
        f_div_C = F.softmax(f, dim=1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z
## define trunk branch
class TrunkBranch(nn.Module):
    """Trunk branch: two stacked ResBlocks applied in sequence.

    NOTE(review): the inner ResBlocks use fixed settings (bias=True, bn=False,
    ReLU, res_scale=1) regardless of the bias/bn/act/res_scale arguments,
    matching the original implementation.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        blocks = [
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(2)
        ]
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x)
## define mask branch
class MaskBranchDownUp(nn.Module):
    """Mask branch: downsample -> residual blocks -> upsample, with a skip
    connection, producing a sigmoid attention mask in [0, 1].

    Submodule names (MB_*) and nn.Sequential wrappers are kept so state_dict
    keys match existing checkpoints.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        def rb():
            # inner residual blocks use hard-coded defaults, as in the original
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)

        self.MB_RB1 = nn.Sequential(rb())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(rb(), rb())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(rb())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        skip = self.MB_RB1(x)
        feat = self.MB_RB2(self.MB_Down(skip))
        feat = self.MB_Up(feat)
        feat = self.MB_RB3(skip + feat)
        return self.MB_sigmoid(self.MB_1x1conv(feat))
## define nonlocal mask branch
class NLMaskBranchDownUp(nn.Module):
    """Non-local mask branch: like MaskBranchDownUp but the head stage adds a
    NonLocalBlock2D before the first residual block.

    Submodule names (MB_*) and nn.Sequential wrappers are kept so state_dict
    keys match existing checkpoints.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        def rb():
            # inner residual blocks use hard-coded defaults, as in the original
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)

        self.MB_RB1 = nn.Sequential(NonLocalBlock2D(n_feat, n_feat // 2), rb())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(rb(), rb())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(rb())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        skip = self.MB_RB1(x)
        feat = self.MB_RB2(self.MB_Down(skip))
        feat = self.MB_Up(feat)
        feat = self.MB_RB3(skip + feat)
        return self.MB_sigmoid(self.MB_1x1conv(feat))
## define residual attention module
class ResAttModuleDownUpPlus(nn.Module):
    """Residual attention module: trunk features gated by a mask branch,
    a residual connection around the gate, then a two-block tail.

    Submodule names (RA_*) are kept so state_dict keys match checkpoints.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()
        self.RA_RB1 = nn.Sequential(
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_MB = nn.Sequential(
            MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_tail = nn.Sequential(*[
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(2)])

    def forward(self, input):
        head = self.RA_RB1(input)
        # trunk output modulated by the [0,1] mask, plus residual shortcut
        gated = self.RA_TB(head) * self.RA_MB(head)
        return self.RA_tail(gated + head)
## define nonlocal residual attention module
class NLResAttModuleDownUpPlus(nn.Module):
    """Non-local residual attention module: identical to
    ResAttModuleDownUpPlus except the mask branch is NLMaskBranchDownUp.

    Submodule names (RA_*) are kept so state_dict keys match checkpoints.
    """
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()
        self.RA_RB1 = nn.Sequential(
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_MB = nn.Sequential(
            NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_tail = nn.Sequential(*[
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(2)])

    def forward(self, input):
        head = self.RA_RB1(input)
        # trunk output modulated by the [0,1] mask, plus residual shortcut
        gated = self.RA_TB(head) * self.RA_MB(head)
        return self.RA_tail(gated + head)
\ No newline at end of file
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F
from torch import autograd as autograd
"""
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2*): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7*): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16*): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace)
(25*): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace)
(34*): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
"""
# --------------------------------------------
# Perceptual loss
# --------------------------------------------
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG19 feature extractor used by the perceptual loss.

    use_input_norm: if True, input in [0, 1] is normalized with ImageNet mean/std.
    use_range_norm: if True, input in [-1, 1] is first mapped back to [0, 1].
    feature_layer: an int selects a single output layer; a list splits the VGG
    features into consecutive stages and forward() returns one tensor per stage.
    """
    def __init__(self, feature_layer=[2,7,16,25,34], use_input_norm=True, use_range_norm=False):
        super(VGGFeatureExtractor, self).__init__()
        model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        self.use_range_norm = use_range_norm
        if self.use_input_norm:
            # ImageNet statistics, registered as buffers so they follow .to(device)
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.list_outputs = isinstance(feature_layer, list)
        children = list(model.features.children())
        if self.list_outputs:
            self.features = nn.Sequential()
            bounds = [-1] + feature_layer
            for i in range(len(bounds) - 1):
                stage = nn.Sequential(*children[(bounds[i] + 1):(bounds[i + 1] + 1)])
                self.features.add_module('child' + str(i), stage)
        else:
            self.features = nn.Sequential(*children[:(feature_layer + 1)])
        print(self.features)
        # freeze: the extractor is never trained
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        if self.use_range_norm:
            x = (x + 1.0) / 2.0
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        if not self.list_outputs:
            return self.features(x)
        outputs = []
        for stage in self.features.children():
            x = stage(x)
            outputs.append(x.clone())
        return outputs
class PerceptualLoss(nn.Module):
    """VGG perceptual loss: weighted L1/L2 distance between VGG feature maps."""
    def __init__(self, feature_layer=[2,7,16,25,34], weights=[0.1,0.1,1.0,1.0,1.0], lossfn_type='l1', use_input_norm=True, use_range_norm=False):
        super(PerceptualLoss, self).__init__()
        self.vgg = VGGFeatureExtractor(feature_layer=feature_layer, use_input_norm=use_input_norm, use_range_norm=use_range_norm)
        self.lossfn_type = lossfn_type
        self.weights = weights
        # 'l1' -> L1Loss; anything else -> MSELoss
        self.lossfn = nn.L1Loss() if self.lossfn_type == 'l1' else nn.MSELoss()
        print(f'feature_layer: {feature_layer} with weights: {weights}')

    def forward(self, x, gt):
        """Forward function.
        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
        Returns:
            Tensor: Forward results.
        """
        x_feats = self.vgg(x)
        gt_feats = self.vgg(gt.detach())
        if not isinstance(x_feats, list):
            return 0.0 + self.lossfn(x_feats, gt_feats.detach())
        loss = 0.0
        for i, (xf, gf) in enumerate(zip(x_feats, gt_feats)):
            loss += self.weights[i] * self.lossfn(xf, gf)
        return loss
# --------------------------------------------
# GAN loss: gan, ragan
# --------------------------------------------
class GANLoss(nn.Module):
    """GAN loss selector.

    Supported gan_type values: 'gan'/'ragan' (BCE-with-logits), 'lsgan'
    (MSE), 'wgan' (mean critic score), 'softplusgan' (softplus of logits).
    For 'wgan'/'softplusgan' the target is the boolean itself rather than a
    label tensor.
    """
    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        if self.gan_type in ('gan', 'ragan'):
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            def wgan_loss(input, target):
                # target is a boolean: True = real, False = fake
                return -input.mean() if target else input.mean()
            self.loss = wgan_loss
        elif self.gan_type == 'softplusgan':
            def softplusgan_loss(input, target):
                # target is a boolean: True = real, False = fake
                return F.softplus(-input).mean() if target else F.softplus(input).mean()
            self.loss = softplusgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        if self.gan_type in ['wgan', 'softplusgan']:
            return target_is_real
        label_val = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(label_val)

    def forward(self, input, target_is_real):
        return self.loss(input, self.get_target_label(input, target_is_real))
# --------------------------------------------
# TV loss
# --------------------------------------------
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        """
        Total variation loss
        https://github.com/jxgu1016/Total_Variation_Loss.pytorch
        Args:
            tv_loss_weight (int): scalar multiplier applied to the loss
        """
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        # x: (n, c, h, w); squared differences of neighbouring pixels,
        # each direction normalized by its own element count.
        n = x.size()[0]
        h = x.size()[2]
        w = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = (x[:, :, 1:, :] - x[:, :, :h - 1, :]).pow(2).sum()
        w_tv = (x[:, :, :, 1:] - x[:, :, :, :w - 1]).pow(2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / n

    @staticmethod
    def tensor_size(t):
        # number of elements per batch item
        return t.size()[1] * t.size()[2] * t.size()[3]
# --------------------------------------------
# Charbonnier loss
# --------------------------------------------
class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1): a smooth, everywhere-differentiable L1 variant."""
    def __init__(self, eps=1e-9):
        super(CharbonnierLoss, self).__init__()
        # eps keeps the square root differentiable at zero difference
        self.eps = eps

    def forward(self, x, y):
        return torch.sqrt((x - y) ** 2 + self.eps).mean()
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.
    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.
    """
    grad = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    # per-sample squared gradient norm, averaged over the batch
    per_sample = grad.pow(2).view(grad.shape[0], -1).sum(1)
    return per_sample.mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization (StyleGAN2 style).

    Returns (penalty, mean path length of this batch, updated EMA path mean).
    Assumes latents has shape (batch, n_latents, dim) — sum(2)/mean(1) below
    reduce the last two dims.
    """
    # scale the probe noise so its expected contribution is image-size invariant
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()
    # exponential moving average of the path length
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.
    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.
    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # random per-sample interpolation coefficient in [0, 1)
    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # fix: torch.autograd.Variable is deprecated since PyTorch 0.4; make the
    # interpolates a fresh leaf requiring grad so autograd.grad can target it.
    interpolates = interpolates.detach().requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]
    if weight is not None:
        gradients = gradients * weight
    # penalize deviation of the per-position gradient norm from 1
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)
    return gradients_penalty
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
"""
# ============================================
# SSIM loss
# https://github.com/Po-Hsun-Su/pytorch-ssim
# ============================================
"""
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length window_size, normalized to sum to 1."""
    center = window_size // 2
    values = [exp(-(i - center) ** 2 / float(2 * sigma ** 2)) for i in range(window_size)]
    kernel = torch.Tensor(values)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian window for
    grouped-conv SSIM (sigma fixed at 1.5).

    The 2-D kernel is the outer product of the 1-D Gaussian, replicated
    across `channel` depthwise groups.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # fix: torch.autograd.Variable has been a deprecated no-op since
    # PyTorch 0.4 — return the expanded tensor directly.
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIMLoss(torch.nn.Module):
    """SSIM as an nn.Module; caches the Gaussian window and rebuilds it when
    the input's channel count or dtype changes."""
    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        channel = img1.size(1)
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            # cache miss: rebuild to match channels/device/dtype of the input
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: builds a fresh window matching img1's channels/device/dtype."""
    channel = img1.size(1)
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
if __name__ == '__main__':
    # Demo: optimize a random image towards "einstein.png" by maximizing SSIM.
    # Requires cv2, skimage and the local file "einstein.png"; not used by training.
    import cv2
    from torch import optim
    from skimage import io
    npImg1 = cv2.imread("einstein.png")
    # HWC uint8 -> 1xCxHxW float in [0, 1]
    img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
    img2 = torch.rand(img1.size())
    if torch.cuda.is_available():
        img1 = img1.cuda()
        img2 = img2.cuda()
    img1 = Variable(img1, requires_grad=False)
    img2 = Variable(img2, requires_grad=True)  # img2 is the tensor being optimized
    ssim_value = ssim(img1, img2).item()
    print("Initial ssim:", ssim_value)
    ssim_loss = SSIMLoss()
    optimizer = optim.Adam([img2], lr=0.01)
    # gradient-ascend SSIM by minimizing its negative until it reaches 0.99
    while ssim_value < 0.99:
        optimizer.zero_grad()
        ssim_out = -ssim_loss(img1, img2)
        ssim_value = -ssim_out.item()
        print('{:<4.4f}'.format(ssim_value))
        ssim_out.backward()
        optimizer.step()
    # 1xCxHxW float -> HWC uint8 for display
    img = np.transpose(img2.detach().cpu().squeeze().float().numpy(), (1,2,0))
    io.imshow(np.uint8(np.clip(img*255, 0, 255)))
import os
import torch
import torch.nn as nn
from utils.utils_bnorm import merge_bn, tidy_sequential
from torch.nn.parallel import DataParallel, DistributedDataParallel
class ModelBase():
    """Base class for training/testing models.

    Provides device setup, DataParallel/DistributedDataParallel wrapping,
    save/load of network and optimizer state_dicts, EMA update of netE from
    netG, and batch-norm merging helpers. Subclasses override the no-op
    hooks (init_train, load, save, feed_data, optimize_parameters, ...).
    """
    def __init__(self, opt):
        self.opt = opt                         # full option dict
        self.save_dir = opt['path']['models']  # directory for saved models
        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']        # training or not
        self.schedulers = []                   # LR schedulers, filled by define_scheduler()
    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """
    def init_train(self):
        """Hook: prepare training state; overridden by subclasses."""
        pass
    def load(self):
        """Hook: load pretrained networks; overridden by subclasses."""
        pass
    def save(self, label):
        """Hook: save networks/optimizers under `label`; overridden by subclasses."""
        pass
    def define_loss(self):
        """Hook: build loss functions; overridden by subclasses."""
        pass
    def define_optimizer(self):
        """Hook: build optimizers; overridden by subclasses."""
        pass
    def define_scheduler(self):
        """Hook: build LR schedulers; overridden by subclasses."""
        pass
    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """
    def feed_data(self, data):
        """Hook: move a data batch onto the device; overridden by subclasses."""
        pass
    def optimize_parameters(self):
        """Hook: one optimization step; overridden by subclasses."""
        pass
    def current_visuals(self):
        """Hook: return current images for visualization; overridden by subclasses."""
        pass
    def current_losses(self):
        """Hook: return current loss log; overridden by subclasses."""
        pass
    def update_learning_rate(self, n):
        # step every scheduler to iteration/epoch n
        for scheduler in self.schedulers:
            scheduler.step(n)
    def current_learning_rate(self):
        # NOTE(review): get_lr() is deprecated in newer torch in favour of
        # get_last_lr(); kept as-is for compatibility with the pinned version.
        return self.schedulers[0].get_lr()[0]
    def requires_grad(self, model, flag=True):
        # enable/disable gradients for every parameter of `model`
        for p in model.parameters():
            p.requires_grad = flag
    """
    # ----------------------------------------
    # Information of net
    # ----------------------------------------
    """
    def print_network(self):
        """Hook: print network description; overridden by subclasses."""
        pass
    def info_network(self):
        """Hook: return network description; overridden by subclasses."""
        pass
    def print_params(self):
        """Hook: print parameter statistics; overridden by subclasses."""
        pass
    def info_params(self):
        """Hook: return parameter statistics; overridden by subclasses."""
        pass
    def get_bare_model(self, network):
        """Get bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(network, (DataParallel, DistributedDataParallel)):
            network = network.module
        return network
    def model_to_device(self, network):
        """Model to device. It also warps models with DistributedDataParallel
        or DataParallel.
        Args:
            network (nn.Module)
        """
        network = network.to(self.device)
        if self.opt['dist']:
            # distributed training: wrap with DDP on the current CUDA device
            find_unused_parameters = self.opt.get('find_unused_parameters', True)
            use_static_graph = self.opt.get('use_static_graph', False)
            network = DistributedDataParallel(network, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
            if use_static_graph:
                print('Using static graph. Make sure that "unused parameters" will not change during training loop.')
                network._set_static_graph()
        else:
            network = DataParallel(network)
        return network
    # ----------------------------------------
    # network name and number of parameters
    # ----------------------------------------
    def describe_network(self, network):
        """Return a string with the network class name, parameter count and structure."""
        network = self.get_bare_model(network)
        msg = '\n'
        msg += 'Networks name: {}'.format(network.__class__.__name__) + '\n'
        msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), network.parameters()))) + '\n'
        msg += 'Net structure:\n{}'.format(str(network)) + '\n'
        return msg
    # ----------------------------------------
    # parameters description
    # ----------------------------------------
    def describe_params(self, network):
        """Return a per-parameter table (mean/min/max/std/shape) of the state_dict."""
        network = self.get_bare_model(network)
        msg = '\n'
        msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
        for name, param in network.state_dict().items():
            # skip BN bookkeeping counters
            if not 'num_batches_tracked' in name:
                v = param.data.clone().float()
                msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
        return msg
    """
    # ----------------------------------------
    # Save parameters
    # Load parameters
    # ----------------------------------------
    """
    # ----------------------------------------
    # save the state_dict of the network
    # ----------------------------------------
    def save_network(self, save_dir, network, network_label, iter_label):
        """Save the bare network's state_dict to {save_dir}/{iter_label}_{network_label}.pth (CPU tensors)."""
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(save_dir, save_filename)
        network = self.get_bare_model(network)
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)
    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load a checkpoint into the bare network.

        If the checkpoint dict contains `param_key`, that sub-dict is used.
        With strict=False, parameters are paired by ITERATION ORDER of the two
        state_dicts (not by name) — this only works when both networks define
        parameters in the same order.
        """
        network = self.get_bare_model(network)
        if strict:
            state_dict = torch.load(load_path)
            if param_key in state_dict.keys():
                state_dict = state_dict[param_key]
            network.load_state_dict(state_dict, strict=strict)
        else:
            state_dict_old = torch.load(load_path)
            if param_key in state_dict_old.keys():
                state_dict_old = state_dict_old[param_key]
            state_dict = network.state_dict()
            # pair old and new entries positionally, ignoring key names
            for ((key_old, param_old),(key, param)) in zip(state_dict_old.items(), state_dict.items()):
                state_dict[key] = param_old
            network.load_state_dict(state_dict, strict=True)
            del state_dict_old, state_dict
    # ----------------------------------------
    # save the state_dict of the optimizer
    # ----------------------------------------
    def save_optimizer(self, save_dir, optimizer, optimizer_label, iter_label):
        """Save the optimizer state_dict to {save_dir}/{iter_label}_{optimizer_label}.pth."""
        save_filename = '{}_{}.pth'.format(iter_label, optimizer_label)
        save_path = os.path.join(save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)
    # ----------------------------------------
    # load the state_dict of the optimizer
    # ----------------------------------------
    def load_optimizer(self, load_path, optimizer):
        """Load an optimizer state_dict, mapping its tensors onto the current CUDA device."""
        optimizer.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())))
    def update_E(self, decay=0.999):
        """EMA update of netE from netG: E <- decay*E + (1-decay)*G.

        Assumes self.netG and self.netE exist (set by a subclass) and share
        parameter names; decay=0 copies G into E.
        """
        netG = self.get_bare_model(self.netG)
        netG_params = dict(netG.named_parameters())
        netE_params = dict(self.netE.named_parameters())
        for k in netG_params.keys():
            netE_params[k].data.mul_(decay).add_(netG_params[k].data, alpha=1-decay)
    """
    # ----------------------------------------
    # Merge Batch Normalization for training
    # Merge Batch Normalization for testing
    # ----------------------------------------
    """
    # ----------------------------------------
    # merge bn during training
    # ----------------------------------------
    def merge_bnorm_train(self):
        """Fold BN layers of netG into preceding convs, then rebuild optimizer/scheduler."""
        merge_bn(self.netG)
        tidy_sequential(self.netG)
        self.define_optimizer()
        self.define_scheduler()
    # ----------------------------------------
    # merge bn before testing
    # ----------------------------------------
    def merge_bnorm_test(self):
        """Fold BN layers of netG into preceding convs for inference."""
        merge_bn(self.netG)
        tidy_sequential(self.netG)
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G, define_D
from models.model_base import ModelBase
from models.loss import GANLoss, PerceptualLoss
from models.loss_ssim import SSIMLoss
class ModelGAN(ModelBase):
"""Train with pixel-VGG-GAN loss"""
def __init__(self, opt):
super(ModelGAN, self).__init__(opt)
# ------------------------------------
# define network
# ------------------------------------
self.opt_train = self.opt['train'] # training option
self.netG = define_G(opt)
self.netG = self.model_to_device(self.netG)
if self.is_train:
self.netD = define_D(opt)
self.netD = self.model_to_device(self.netD)
if self.opt_train['E_decay'] > 0:
self.netE = define_G(opt).to(self.device).eval()
"""
# ----------------------------------------
# Preparation before training with data
# Save model during training
# ----------------------------------------
"""
# ----------------------------------------
# initialize training
# ----------------------------------------
def init_train(self):
self.load() # load model
self.netG.train() # set training mode,for BN
self.netD.train() # set training mode,for BN
self.define_loss() # define loss
self.define_optimizer() # define optimizer
self.load_optimizers() # load optimizer
self.define_scheduler() # define scheduler
self.log_dict = OrderedDict() # log
# ----------------------------------------
# load pre-trained G and D model
# ----------------------------------------
def load(self):
load_path_G = self.opt['path']['pretrained_netG']
if load_path_G is not None:
print('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG, strict=self.opt_train['G_param_strict'])
load_path_E = self.opt['path']['pretrained_netE']
if self.opt_train['E_decay'] > 0:
if load_path_E is not None:
print('Loading model for E [{:s}] ...'.format(load_path_E))
self.load_network(load_path_E, self.netE, strict=self.opt_train['E_param_strict'])
else:
print('Copying model for E')
self.update_E(0)
self.netE.eval()
load_path_D = self.opt['path']['pretrained_netD']
if self.opt['is_train'] and load_path_D is not None:
print('Loading model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD, strict=self.opt_train['D_param_strict'])
# ----------------------------------------
# load optimizerG and optimizerD
# ----------------------------------------
def load_optimizers(self):
load_path_optimizerG = self.opt['path']['pretrained_optimizerG']
if load_path_optimizerG is not None and self.opt_train['G_optimizer_reuse']:
print('Loading optimizerG [{:s}] ...'.format(load_path_optimizerG))
self.load_optimizer(load_path_optimizerG, self.G_optimizer)
load_path_optimizerD = self.opt['path']['pretrained_optimizerD']
if load_path_optimizerD is not None and self.opt_train['D_optimizer_reuse']:
print('Loading optimizerD [{:s}] ...'.format(load_path_optimizerD))
self.load_optimizer(load_path_optimizerD, self.D_optimizer)
# ----------------------------------------
# save model / optimizer(optional)
# ----------------------------------------
def save(self, iter_label):
self.save_network(self.save_dir, self.netG, 'G', iter_label)
self.save_network(self.save_dir, self.netD, 'D', iter_label)
if self.opt_train['E_decay'] > 0:
self.save_network(self.save_dir, self.netE, 'E', iter_label)
if self.opt_train['G_optimizer_reuse']:
self.save_optimizer(self.save_dir, self.G_optimizer, 'optimizerG', iter_label)
if self.opt_train['D_optimizer_reuse']:
self.save_optimizer(self.save_dir, self.D_optimizer, 'optimizerD', iter_label)
# ----------------------------------------
# define loss
# ----------------------------------------
def define_loss(self):
# ------------------------------------
# 1) G_loss
# ------------------------------------
if self.opt_train['G_lossfn_weight'] > 0:
G_lossfn_type = self.opt_train['G_lossfn_type']
if G_lossfn_type == 'l1':
self.G_lossfn = nn.L1Loss().to(self.device)
elif G_lossfn_type == 'l2':
self.G_lossfn = nn.MSELoss().to(self.device)
elif G_lossfn_type == 'l2sum':
self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
elif G_lossfn_type == 'ssim':
self.G_lossfn = SSIMLoss().to(self.device)
else:
raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
else:
print('Do not use pixel loss.')
self.G_lossfn = None
# ------------------------------------
# 2) F_loss
# ------------------------------------
if self.opt_train['F_lossfn_weight'] > 0:
F_feature_layer = self.opt_train['F_feature_layer']
F_weights = self.opt_train['F_weights']
F_lossfn_type = self.opt_train['F_lossfn_type']
F_use_input_norm = self.opt_train['F_use_input_norm']
F_use_range_norm = self.opt_train['F_use_range_norm']
if self.opt['dist']:
self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer, weights=F_weights, lossfn_type=F_lossfn_type, use_input_norm=F_use_input_norm, use_range_norm=F_use_range_norm).to(self.device)
else:
self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer, weights=F_weights, lossfn_type=F_lossfn_type, use_input_norm=F_use_input_norm, use_range_norm=F_use_range_norm)
self.F_lossfn.vgg = self.model_to_device(self.F_lossfn.vgg)
self.F_lossfn.lossfn = self.F_lossfn.lossfn.to(self.device)
self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
else:
print('Do not use feature loss.')
self.F_lossfn = None
# ------------------------------------
# 3) D_loss
# ------------------------------------
self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0, 0.0).to(self.device)
self.D_lossfn_weight = self.opt_train['D_lossfn_weight']
self.D_update_ratio = self.opt_train['D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train['D_init_iters'] else 0
# ----------------------------------------
# define optimizer, G and D
# ----------------------------------------
def define_optimizer(self):
G_optim_params = []
for k, v in self.netG.named_parameters():
if v.requires_grad:
G_optim_params.append(v)
else:
print('Params [{:s}] will not optimize.'.format(k))
self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)
self.D_optimizer = Adam(self.netD.parameters(), lr=self.opt_train['D_optimizer_lr'], weight_decay=0)
# ----------------------------------------
# define scheduler, only "MultiStepLR"
# ----------------------------------------
def define_scheduler(self):
self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
self.opt_train['G_scheduler_milestones'],
self.opt_train['G_scheduler_gamma']
))
self.schedulers.append(lr_scheduler.MultiStepLR(self.D_optimizer,
self.opt_train['D_scheduler_milestones'],
self.opt_train['D_scheduler_gamma']
))
"""
# ----------------------------------------
# Optimization during training with data
# Testing/evaluation
# ----------------------------------------
"""
# ----------------------------------------
# feed L/H data
# ----------------------------------------
def feed_data(self, data, need_H=True):
self.L = data['L'].to(self.device)
if need_H:
self.H = data['H'].to(self.device)
# ----------------------------------------
# feed L to netG and get E
# ----------------------------------------
def netG_forward(self):
self.E = self.netG(self.L)
# ----------------------------------------
# update parameters and get loss
# ----------------------------------------
def optimize_parameters(self, current_step):
# ------------------------------------
# optimize G
# ------------------------------------
for p in self.netD.parameters():
p.requires_grad = False
self.G_optimizer.zero_grad()
self.netG_forward()
loss_G_total = 0
if current_step % self.D_update_ratio == 0 and current_step > self.D_init_iters: # updata D first
if self.opt_train['G_lossfn_weight'] > 0:
G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
loss_G_total += G_loss # 1) pixel loss
if self.opt_train['F_lossfn_weight'] > 0:
F_loss = self.F_lossfn_weight * self.F_lossfn(self.E.contiguous(), self.H.contiguous())
loss_G_total += F_loss # 2) VGG feature loss
if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:
pred_g_fake = self.netD(self.E)
D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(self.H).detach()
pred_g_fake = self.netD(self.E)
D_loss = self.D_lossfn_weight * (
self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) +
self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2
loss_G_total += D_loss # 3) GAN loss
loss_G_total.backward()
self.G_optimizer.step()
# ------------------------------------
# optimize D
# ------------------------------------
for p in self.netD.parameters():
p.requires_grad = True
self.D_optimizer.zero_grad()
# In order to avoid the error in distributed training:
# "Error detected in CudnnBatchNormBackward: RuntimeError: one of
# the variables needed for gradient computation has been modified by
# an inplace operation",
# we separate the backwards for real and fake, and also detach the
# tensor for calculating mean.
if self.opt_train['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:
# real
pred_d_real = self.netD(self.H) # 1) real data
l_d_real = self.D_lossfn(pred_d_real, True)
l_d_real.backward()
# fake
pred_d_fake = self.netD(self.E.detach().clone()) # 2) fake data, detach to avoid BP to G
l_d_fake = self.D_lossfn(pred_d_fake, False)
l_d_fake.backward()
elif self.opt_train['gan_type'] == 'ragan':
# real
pred_d_fake = self.netD(self.E).detach() # 1) fake data, detach to avoid BP to G
pred_d_real = self.netD(self.H) # 2) real data
l_d_real = 0.5 * self.D_lossfn(pred_d_real - torch.mean(pred_d_fake, 0, True), True)
l_d_real.backward()
# fake
pred_d_fake = self.netD(self.E.detach())
l_d_fake = 0.5 * self.D_lossfn(pred_d_fake - torch.mean(pred_d_real.detach(), 0, True), False)
l_d_fake.backward()
self.D_optimizer.step()
# ------------------------------------
# record log
# ------------------------------------
if current_step % self.D_update_ratio == 0 and current_step > self.D_init_iters:
if self.opt_train['G_lossfn_weight'] > 0:
self.log_dict['G_loss'] = G_loss.item()
if self.opt_train['F_lossfn_weight'] > 0:
self.log_dict['F_loss'] = F_loss.item()
self.log_dict['D_loss'] = D_loss.item()
#self.log_dict['l_d_real'] = l_d_real.item()
#self.log_dict['l_d_fake'] = l_d_fake.item()
self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
if self.opt_train['E_decay'] > 0:
self.update_E(self.opt_train['E_decay'])
# ----------------------------------------
# test and inference
# ----------------------------------------
def test(self):
self.netG.eval()
with torch.no_grad():
self.netG_forward()
self.netG.train()
# ----------------------------------------
# get log_dict
# ----------------------------------------
def current_log(self):
return self.log_dict
# ----------------------------------------
# get L, E, H images
# ----------------------------------------
def current_visuals(self, need_H=True):
out_dict = OrderedDict()
out_dict['L'] = self.L.detach()[0].float().cpu()
out_dict['E'] = self.E.detach()[0].float().cpu()
if need_H:
out_dict['H'] = self.H.detach()[0].float().cpu()
return out_dict
"""
# ----------------------------------------
# Information of netG, netD and netF
# ----------------------------------------
"""
# ----------------------------------------
# print network
# ----------------------------------------
def print_network(self):
msg = self.describe_network(self.netG)
print(msg)
if self.is_train:
msg = self.describe_network(self.netD)
print(msg)
# ----------------------------------------
# print params
# ----------------------------------------
def print_params(self):
msg = self.describe_params(self.netG)
print(msg)
# ----------------------------------------
# network information
# ----------------------------------------
def info_network(self):
msg = self.describe_network(self.netG)
if self.is_train:
msg += self.describe_network(self.netD)
return msg
# ----------------------------------------
# params information
# ----------------------------------------
def info_params(self):
msg = self.describe_params(self.netG)
return msg
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G
from models.model_base import ModelBase
from models.loss import CharbonnierLoss
from models.loss_ssim import SSIMLoss
from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip
class ModelPlain(ModelBase):
    """Train a single generator netG with a pixel loss (optionally masked).

    Supports an EMA copy netE (when opt['train']['E_decay'] > 0) and an
    optional per-pixel output mask supplied through feed_data.
    """
    def __init__(self, opt):
        super(ModelPlain, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.opt_train = self.opt['train']    # training option
        self.netG = define_G(opt)
        self.netG = self.model_to_device(self.netG)
        if self.opt_train['E_decay'] > 0:
            # exponential-moving-average copy of netG; never trained directly
            self.netE = define_G(opt).to(self.device).eval()

    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Load weights/optimizer state and set up loss, optimizer and scheduler."""
        self.load()                           # load model
        self.netG.train()                     # set training mode, for BN
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.load_optimizers()                # load optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        """Load pretrained G; when EMA is enabled, load E or copy G into it."""
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, strict=self.opt_train['G_param_strict'], param_key='params')
        load_path_E = self.opt['path']['pretrained_netE']
        if self.opt_train['E_decay'] > 0:
            if load_path_E is not None:
                print('Loading model for E [{:s}] ...'.format(load_path_E))
                self.load_network(load_path_E, self.netE, strict=self.opt_train['E_param_strict'], param_key='params_ema')
            else:
                print('Copying model for E ...')
                self.update_E(0)
            self.netE.eval()

    # ----------------------------------------
    # load optimizer
    # ----------------------------------------
    def load_optimizers(self):
        """Restore the G optimizer state when reuse is configured."""
        load_path_optimizerG = self.opt['path']['pretrained_optimizerG']
        if load_path_optimizerG is not None and self.opt_train['G_optimizer_reuse']:
            print('Loading optimizerG [{:s}] ...'.format(load_path_optimizerG))
            self.load_optimizer(load_path_optimizerG, self.G_optimizer)

    # ----------------------------------------
    # save model / optimizer(optional)
    # ----------------------------------------
    def save(self, iter_label):
        """Save netG (and netE / optimizer state when enabled), tagged with iter_label."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        if self.opt_train['E_decay'] > 0:
            self.save_network(self.save_dir, self.netE, 'E', iter_label)
        if self.opt_train['G_optimizer_reuse']:
            self.save_optimizer(self.save_dir, self.G_optimizer, 'optimizerG', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Instantiate the pixel loss selected by opt['train']['G_lossfn_type'].

        Raises:
            NotImplementedError: if the configured loss type is unknown.
        """
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        elif G_lossfn_type == 'charbonnier':
            self.G_lossfn = CharbonnierLoss(self.opt_train['G_charbonnier_eps']).to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create the Adam optimizer over the trainable parameters of netG."""
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        if self.opt_train['G_optimizer_type'] == 'adam':
            self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                    betas=self.opt_train['G_optimizer_betas'],
                                    weight_decay=self.opt_train['G_optimizer_wd'])
        else:
            raise NotImplementedError

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Attach the LR scheduler configured in opt['train']['G_scheduler_type']."""
        if self.opt_train['G_scheduler_type'] == 'MultiStepLR':
            self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                            self.opt_train['G_scheduler_milestones'],
                                                            self.opt_train['G_scheduler_gamma']
                                                            ))
        elif self.opt_train['G_scheduler_type'] == 'CosineAnnealingWarmRestarts':
            self.schedulers.append(lr_scheduler.CosineAnnealingWarmRestarts(self.G_optimizer,
                                                            self.opt_train['G_scheduler_periods'],
                                                            self.opt_train['G_scheduler_restart_weights'],
                                                            self.opt_train['G_scheduler_eta_min']
                                                            ))
        else:
            raise NotImplementedError

    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move L/H (and an optional per-pixel mask) onto the device."""
        self.L = data['L'].to(self.device)
        if need_H:
            self.H = data['H'].to(self.device)
        if 'mask' in data:
            self.mask = data['mask'].to(self.device)
            self.is_mask = True
        else:
            self.is_mask = False

    # ----------------------------------------
    # feed L to netG
    # ----------------------------------------
    def netG_forward(self):
        """Run netG on L; when a mask was fed, apply it to the estimate E."""
        self.E = self.netG(self.L)
        # getattr guard keeps test() usable even if feed_data was never called
        if getattr(self, 'is_mask', False):
            self.E = torch.mul(self.E, self.mask)

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """One training step: forward, pixel loss, backward, optional clip/regularize."""
        self.G_optimizer.zero_grad()
        self.netG_forward()
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()
        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            # BUG FIX: clip the generator's gradients. The original called
            # self.parameters(), but `self` is the model wrapper, not an
            # nn.Module; the intent is clearly netG's parameters.
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2)
        self.G_optimizer.step()
        # ------------------------------------
        # regularizer (skipped on checkpoint-save steps)
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)
        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()
        if self.opt_train['E_decay'] > 0:
            self.update_E(self.opt_train['E_decay'])

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run one gradient-free forward pass, restoring train mode after."""
        self.netG.eval()
        with torch.no_grad():
            self.netG_forward()
        self.netG.train()

    # ----------------------------------------
    # test / inference x8
    # ----------------------------------------
    def testx8(self):
        """Self-ensemble (x8) inference via utils_model.test_mode."""
        self.netG.eval()
        with torch.no_grad():
            self.E = test_mode(self.netG, self.L, mode=3, sf=self.opt['scale'], modulo=1)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the dict of most recently recorded losses."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        """Return the first L/E (and optionally H) image as detached float CPU tensors."""
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        """Return the full L/E (and optionally H) batches as detached float CPU tensors."""
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict

    """
    # ----------------------------------------
    # Information of netG
    # ----------------------------------------
    """

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print a textual description of netG."""
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the parameter statistics of netG."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the description of netG as a string."""
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the parameter statistics of netG as a string."""
        msg = self.describe_params(self.netG)
        return msg
from models.model_plain import ModelPlain
class ModelPlain2(ModelPlain):
    """Pixel-loss training for generators that take two inputs (L, C)."""
    # ----------------------------------------
    # feed L/C/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move the low-quality image L, the condition C and (optionally) H to the device."""
        device = self.device
        self.L = data['L'].to(device)
        self.C = data['C'].to(device)
        if need_H:
            self.H = data['H'].to(device)
    # ----------------------------------------
    # feed (L, C) to netG and get E
    # ----------------------------------------
    def netG_forward(self):
        """Run the generator on the (L, C) pair and store the estimate E."""
        self.E = self.netG(self.L, self.C)
from models.model_plain import ModelPlain
import numpy as np
class ModelPlain4(ModelPlain):
    """Train with four inputs (L, k, sf, sigma) and with pixel loss for USRNet"""
    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move L, blur kernel k and noise level sigma to the device; extract scale factor sf."""
        self.L = data['L'].to(self.device)                         # low-quality image
        self.k = data['k'].to(self.device)                         # blur kernel
        # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the drop-in replacement.
        self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy())  # scale factor
        self.sigma = data['sigma'].to(self.device)                 # noise level
        if need_H:
            self.H = data['H'].to(self.device)                     # H
    # ----------------------------------------
    # feed (L, k, sf, sigma) to netG and get E
    # ----------------------------------------
    def netG_forward(self):
        """Run the USRNet-style generator on (L, k, sf, sigma) and store E."""
        self.E = self.netG(self.L, self.k, self.sf, self.sigma)
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim import Adam
from models.select_network import define_G
from models.model_plain import ModelPlain
from models.loss import CharbonnierLoss
from models.loss_ssim import SSIMLoss
from utils.utils_model import test_mode
from utils.utils_regularizers import regularizer_orth, regularizer_clip
class ModelVRT(ModelPlain):
    """Train video restoration with pixel loss"""
    def __init__(self, opt):
        super(ModelVRT, self).__init__(opt)
        # fix_iter: freeze parameters whose names contain any of fix_keys for
        # the first fix_iter iterations (typically pretrained sub-modules).
        self.fix_iter = self.opt_train.get('fix_iter', 0)
        self.fix_keys = self.opt_train.get('fix_keys', [])
        self.fix_unflagged = True  # one-shot flag: print the freeze notice only once
    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Build the G optimizer; params matching fix_keys may use a scaled lr (fix_lr_mul)."""
        self.fix_keys = self.opt_train.get('fix_keys', [])
        if self.opt_train.get('fix_iter', 0) and len(self.fix_keys) > 0:
            fix_lr_mul = self.opt_train['fix_lr_mul']
            print(f'Multiple the learning rate for keys: {self.fix_keys} with {fix_lr_mul}.')
            if fix_lr_mul == 1:
                G_optim_params = self.netG.parameters()
            else: # separate flow params and normal params for different lr
                normal_params = []
                flow_params = []
                for name, param in self.netG.named_parameters():
                    if any([key in name for key in self.fix_keys]):
                        flow_params.append(param)
                    else:
                        normal_params.append(param)
                G_optim_params = [
                    { # add normal params first
                        'params': normal_params,
                        'lr': self.opt_train['G_optimizer_lr']
                    },
                    {
                        'params': flow_params,
                        'lr': self.opt_train['G_optimizer_lr'] * fix_lr_mul
                    },
                ]
            if self.opt_train['G_optimizer_type'] == 'adam':
                self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'],
                                        betas=self.opt_train['G_optimizer_betas'],
                                        weight_decay=self.opt_train['G_optimizer_wd'])
            else:
                raise NotImplementedError
        else:
            # no parameter fixing configured: fall back to the plain optimizer
            super(ModelVRT, self).define_optimizer()
    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Freeze fix_keys params before fix_iter, unfreeze all at fix_iter, then run the parent step."""
        if self.fix_iter:
            if self.fix_unflagged and current_step < self.fix_iter:
                print(f'Fix keys: {self.fix_keys} for the first {self.fix_iter} iters.')
                self.fix_unflagged = False
                for name, param in self.netG.named_parameters():
                    if any([key in name for key in self.fix_keys]):
                        param.requires_grad_(False)
            elif current_step == self.fix_iter:
                print(f'Train all the parameters from {self.fix_iter} iters.')
                self.netG.requires_grad_(True)
        super(ModelVRT, self).optimize_parameters(current_step)
    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Whole-sequence inference with optional temporal padding and flip self-ensembling."""
        n = self.L.size(1)
        self.netG.eval()
        pad_seq = self.opt_train.get('pad_seq', False)
        flip_seq = self.opt_train.get('flip_seq', False)
        self.center_frame_only = self.opt_train.get('center_frame_only', False)
        if pad_seq:
            # repeat the last frame so the sequence length fits the network
            n = n + 1
            self.L = torch.cat([self.L, self.L[:, -1:, :, :, :]], dim=1)
        if flip_seq:
            # temporal self-ensemble: append the time-reversed clip
            self.L = torch.cat([self.L, self.L.flip(1)], dim=1)
        with torch.no_grad():
            self.E = self._test_video(self.L)
        if flip_seq:
            # average the forward pass and the re-reversed backward pass
            output_1 = self.E[:, :n, :, :, :]
            output_2 = self.E[:, n:, :, :, :].flip(1)
            self.E = 0.5 * (output_1 + output_2)
        if pad_seq:
            n = n - 1
            self.E = self.E[:, :n, :, :, :]
        if self.center_frame_only:
            self.E = self.E[:, n // 2, :, :, :]
        self.netG.train()
    def _test_video(self, lq):
        '''test the video as a whole or as clips (divided temporally). '''
        num_frame_testing = self.opt['val'].get('num_frame_testing', 0)
        if num_frame_testing:
            # test as multiple clips if out-of-memory
            sf = self.opt['scale']
            num_frame_overlapping = self.opt['val'].get('num_frame_overlapping', 2)
            not_overlap_border = False
            b, d, c, h, w = lq.size()
            # nonblind denoising carries the noise map as an extra input channel
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = num_frame_testing - num_frame_overlapping
            d_idx_list = list(range(0, d-num_frame_testing, stride)) + [max(0, d-num_frame_testing)]
            # E accumulates weighted clip outputs; W the per-frame weights for averaging
            E = torch.zeros(b, d, c, h*sf, w*sf)
            W = torch.zeros(b, d, 1, 1, 1)
            for d_idx in d_idx_list:
                lq_clip = lq[:, d_idx:d_idx+num_frame_testing, ...]
                out_clip = self._test_clip(lq_clip)
                out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))
                if not_overlap_border:
                    # zero out half the overlap on interior clip borders
                    if d_idx < d_idx_list[-1]:
                        out_clip[:, -num_frame_overlapping//2:, ...] *= 0
                        out_clip_mask[:, -num_frame_overlapping//2:, ...] *= 0
                    if d_idx > d_idx_list[0]:
                        out_clip[:, :num_frame_overlapping//2, ...] *= 0
                        out_clip_mask[:, :num_frame_overlapping//2, ...] *= 0
                E[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip)
                W[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip_mask)
            output = E.div_(W)
        else:
            # test as one clip (the whole video) if you have enough memory
            window_size = self.opt['netG'].get('window_size', [6,8,8])
            d_old = lq.size(1)
            # reflection-pad temporally to a multiple of the temporal window
            d_pad = (d_old// window_size[0]+1)*window_size[0] - d_old
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
            output = self._test_clip(lq)
            output = output[:, :d_old, :, :, :]
        return output
    def _test_clip(self, lq):
        ''' test the clip as a whole or as patches. '''
        sf = self.opt['scale']
        window_size = self.opt['netG'].get('window_size', [6,8,8])
        size_patch_testing = self.opt['val'].get('size_patch_testing', 0)
        assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'
        if size_patch_testing:
            # divide the clip to patches (spatially only, tested patch by patch)
            overlap_size = 20
            not_overlap_border = True
            # test patch by patch
            b, d, c, h, w = lq.size()
            c = c - 1 if self.opt['netG'].get('nonblind_denoising', False) else c
            stride = size_patch_testing - overlap_size
            h_idx_list = list(range(0, h-size_patch_testing, stride)) + [max(0, h-size_patch_testing)]
            w_idx_list = list(range(0, w-size_patch_testing, stride)) + [max(0, w-size_patch_testing)]
            # weighted accumulation buffers, as in _test_video but spatial
            E = torch.zeros(b, d, c, h*sf, w*sf)
            W = torch.zeros_like(E)
            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = lq[..., h_idx:h_idx+size_patch_testing, w_idx:w_idx+size_patch_testing]
                    # prefer the EMA network when it exists
                    if hasattr(self, 'netE'):
                        out_patch = self.netE(in_patch).detach().cpu()
                    else:
                        out_patch = self.netG(in_patch).detach().cpu()
                    out_patch_mask = torch.ones_like(out_patch)
                    if not_overlap_border:
                        # zero half the overlap band on interior patch borders
                        if h_idx < h_idx_list[-1]:
                            out_patch[..., -overlap_size//2:, :] *= 0
                            out_patch_mask[..., -overlap_size//2:, :] *= 0
                        if w_idx < w_idx_list[-1]:
                            out_patch[..., :, -overlap_size//2:] *= 0
                            out_patch_mask[..., :, -overlap_size//2:] *= 0
                        if h_idx > h_idx_list[0]:
                            out_patch[..., :overlap_size//2, :] *= 0
                            out_patch_mask[..., :overlap_size//2, :] *= 0
                        if w_idx > w_idx_list[0]:
                            out_patch[..., :, :overlap_size//2] *= 0
                            out_patch_mask[..., :, :overlap_size//2] *= 0
                    E[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch)
                    W[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch_mask)
            output = E.div_(W)
        else:
            # whole-clip inference: reflection-pad H/W to window multiples
            _, _, _, h_old, w_old = lq.size()
            h_pad = (h_old// window_size[1]+1)*window_size[1] - h_old
            w_pad = (w_old// window_size[2]+1)*window_size[2] - w_old
            lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3)
            lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4)
            if hasattr(self, 'netE'):
                output = self.netE(lq).detach().cpu()
            else:
                output = self.netG(lq).detach().cpu()
            output = output[:, :, :, :h_old*sf, :w_old*sf]
        return output
    # ----------------------------------------
    # load the state_dict of the network
    # ----------------------------------------
    def load_network(self, load_path, network, strict=True, param_key='params'):
        """Load weights from load_path, optionally nested under param_key, and report mismatches."""
        network = self.get_bare_model(network)
        state_dict = torch.load(load_path)
        if param_key in state_dict.keys():
            state_dict = state_dict[param_key]
        self._print_different_keys_loading(network, state_dict, strict)
        network.load_state_dict(state_dict, strict=strict)
    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print key differences between current and loaded nets; with strict=False, rename size-mismatched keys so they are skipped."""
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())
        if crt_net_keys != load_net_keys:
            print('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                print(f'  {v}')
            print('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                print(f'  {v}')
        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    print(f'Size different, ignore [{k}]: crt_net: '
                          f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    # '.ignore' suffix makes the key unmatched, so load skips it
                    load_net[k + '.ignore'] = load_net.pop(k)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that shifts the RGB mean: subtracts it for sign=-1 (default), adds it for sign=+1.

    Implemented as a frozen 3->3 pointwise convolution:
    out = (x * I/std) + sign * rgb_range * mean / std.
    """
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # identity kernel scaled by 1/std -> per-channel division by std
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # BUG FIX: the original `self.requires_grad = False` only set a plain
        # attribute on the Module and left weight/bias trainable; freeze the
        # actual parameters instead.
        for p in self.parameters():
            p.requires_grad = False
def init_weights(modules):
    # No-op placeholder: PyTorch's default layer initialization is kept.
    # Retained so BasicBlock/ResidualBlock have a single init hook point.
    pass
class BasicBlock(nn.Module):
    """A Conv2d followed by an in-place ReLU."""
    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad)
        activation = nn.ReLU(inplace=True)
        self.body = nn.Sequential(conv, activation)
        init_weights(self.modules)
    def forward(self, x):
        return self.body(x)
class ResidualBlock(nn.Module):
    """Two 3x3 convs with a ReLU in between; output is relu(body(x) + x)."""
    def __init__(self,
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        ]
        self.body = nn.Sequential(*layers)
        init_weights(self.modules)
    def forward(self, x):
        residual = self.body(x)
        return F.relu(residual + x)
class CNN5Layer(nn.Module):
    """Small residual CNN: mean shift, head + 4 conv blocks, 3x3 tail, global skip."""
    def __init__(self):
        super(CNN5Layer, self).__init__()
        n_feats = 32
        kernel_size = 3
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(1, rgb_mean, rgb_std)
        self.add_mean = MeanShift(1, rgb_mean, rgb_std, 1)
        self.head = BasicBlock(3, n_feats, kernel_size, 1, 1)
        self.b1 = BasicBlock(n_feats, n_feats, kernel_size, 1, 1)
        self.b2 = BasicBlock(n_feats, n_feats, kernel_size, 1, 1)
        self.b3 = BasicBlock(n_feats, n_feats, kernel_size, 1, 1)
        self.b4 = BasicBlock(n_feats, n_feats, kernel_size, 1, 1)
        self.tail = nn.Conv2d(n_feats, 3, kernel_size, 1, 1, 1)
    def forward(self, x):
        features = self.head(self.sub_mean(x))
        for block in (self.b1, self.b2, self.b3, self.b4):
            features = block(features)
        restored = self.add_mean(self.tail(features))
        # global residual connection
        return restored + x
\ No newline at end of file
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
import models.basicblock as B
import functools
import numpy as np
"""
# --------------------------------------------
# Discriminator_PatchGAN
# Discriminator_UNet
# --------------------------------------------
"""
# --------------------------------------------
# PatchGAN discriminator
# If n_layers = 3, then the receptive field is 70x70
# --------------------------------------------
class Discriminator_PatchGAN(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_type='spectral'):
        '''PatchGAN discriminator, receptive field = 70x70 if n_layers = 3
        Args:
            input_nc: number of input channels
            ndf: base channel number
            n_layers: number of conv layer with stride 2
            norm_type: 'batch', 'instance', 'spectral', 'batchspectral', instancespectral'
        Returns:
            tensor: score
        '''
        super(Discriminator_PatchGAN, self).__init__()
        self.n_layers = n_layers
        norm_layer = self.get_norm_layer(norm_type=norm_type)
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        # first stage: conv + LeakyReLU, no norm
        stages = [[self.use_spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), norm_type), nn.LeakyReLU(0.2, True)]]
        nf = ndf
        # n_layers - 1 stride-2 stages, doubling channels up to 512
        for _ in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            stages.append([self.use_spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), norm_type),
                           norm_layer(nf),
                           nn.LeakyReLU(0.2, True)])
        # one stride-1 stage, then the 1-channel score head
        nf_prev, nf = nf, min(nf * 2, 512)
        stages.append([self.use_spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), norm_type),
                       norm_layer(nf),
                       nn.LeakyReLU(0.2, True)])
        stages.append([self.use_spectral_norm(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), norm_type)])
        self.model = nn.Sequential()
        for idx, stage in enumerate(stages):
            self.model.add_module('child' + str(idx), nn.Sequential(*stage))
        self.model.apply(self.weights_init)
    def use_spectral_norm(self, module, norm_type='spectral'):
        """Wrap `module` in spectral norm when `norm_type` requests it."""
        return spectral_norm(module) if 'spectral' in norm_type else module
    def get_norm_layer(self, norm_type='instance'):
        """Return the norm-layer factory matching `norm_type` (identity otherwise)."""
        if 'batch' in norm_type:
            return functools.partial(nn.BatchNorm2d, affine=True)
        if 'instance' in norm_type:
            return functools.partial(nn.InstanceNorm2d, affine=False)
        return functools.partial(nn.Identity)
    def weights_init(self, m):
        """DCGAN-style Gaussian init for Conv / BatchNorm2d layers."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
    def forward(self, x):
        return self.model(x)
class Discriminator_UNet(nn.Module):
    """U-Net shaped discriminator with spectral normalization (SN) on all inner convs."""
    def __init__(self, input_nc=3, ndf=64):
        super(Discriminator_UNet, self).__init__()
        norm = spectral_norm
        # downsampling path
        self.conv0 = nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)
        self.conv1 = norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(ndf * 8, ndf * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(ndf * 4, ndf * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(ndf * 2, ndf, 3, 1, 1, bias=False))
        # extra
        self.conv7 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(ndf, 1, 3, 1, 1)
        print('using the UNet discriminator')
    def forward(self, x):
        def act(t):
            return F.leaky_relu(t, negative_slope=0.2, inplace=True)
        # encoder
        d0 = act(self.conv0(x))
        d1 = act(self.conv1(d0))
        d2 = act(self.conv2(d1))
        d3 = act(self.conv3(d2))
        # decoder with additive skip connections
        u = F.interpolate(d3, scale_factor=2, mode='bilinear', align_corners=False)
        u = act(self.conv4(u)) + d2
        u = F.interpolate(u, scale_factor=2, mode='bilinear', align_corners=False)
        u = act(self.conv5(u)) + d1
        u = F.interpolate(u, scale_factor=2, mode='bilinear', align_corners=False)
        u = act(self.conv6(u)) + d0
        # extra refinement convs and 1-channel score map
        out = act(self.conv7(u))
        out = act(self.conv8(out))
        return self.conv9(out)
# --------------------------------------------
# VGG style Discriminator with 96x96 input
# --------------------------------------------
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator expecting 96x96 inputs."""
    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_96, self).__init__()
        # Each stride-2 conv halves the spatial size: 96 -> 48 -> 24 -> 12 -> 6 -> 3,
        # while channels grow base_nc -> base_nc*8.
        body = [B.conv(in_nc, base_nc, kernel_size=3, mode='C'),
                B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C'+ac_type),
                B.conv(base_nc, base_nc*2, kernel_size=3, stride=1, mode='C'+ac_type),
                B.conv(base_nc*2, base_nc*2, kernel_size=4, stride=2, mode='C'+ac_type),
                B.conv(base_nc*2, base_nc*4, kernel_size=3, stride=1, mode='C'+ac_type),
                B.conv(base_nc*4, base_nc*4, kernel_size=4, stride=2, mode='C'+ac_type),
                B.conv(base_nc*4, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type),
                B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type),
                B.conv(base_nc*8, base_nc*8, kernel_size=3, stride=1, mode='C'+ac_type),
                B.conv(base_nc*8, base_nc*8, kernel_size=4, stride=2, mode='C'+ac_type)]
        self.features = B.sequential(*body)
        # classifier over the flattened 3x3 map
        # NOTE(review): the 512 literal assumes base_nc == 64 (512 = 64*8)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
    def forward(self, x):
        feat = self.features(x)
        flat = feat.view(feat.size(0), -1)
        return self.classifier(flat)
# --------------------------------------------
# VGG style Discriminator with 128x128 input
# --------------------------------------------
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator for 128x128 inputs.

    Five stride-2 convolutions reduce 128 -> 4 spatially while channels grow
    from base_nc to base_nc*8; an MLP head maps the flattened 4x4 feature
    map to a single real/fake score.
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_128, self).__init__()
        # First conv is bare ('C'); all later convs append ac_type
        # (e.g. 'CBL' = conv + batch norm + leaky ReLU).
        layers = [B.conv(in_nc, base_nc, kernel_size=3, mode='C'),
                  B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C' + ac_type)]
        # Each stage: a 3x3 channel-widening conv then a 4x4 stride-2 conv.
        stages = [(base_nc, base_nc * 2),
                  (base_nc * 2, base_nc * 4),
                  (base_nc * 4, base_nc * 8),
                  (base_nc * 8, base_nc * 8)]
        for c_in, c_out in stages:
            layers.append(B.conv(c_in, c_out, kernel_size=3, stride=1, mode='C' + ac_type))
            layers.append(B.conv(c_out, c_out, kernel_size=4, stride=2, mode='C' + ac_type))
        self.features = B.sequential(*layers)

        # classifier on the flattened 4x4x512 feature map (assumes base_nc=64)
        self.classifier = nn.Sequential(nn.Linear(512 * 4 * 4, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        feat = self.features(x)
        return self.classifier(feat.view(feat.size(0), -1))
# --------------------------------------------
# VGG style Discriminator with 192x192 input
# --------------------------------------------
class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator for 192x192 inputs.

    Six stride-2 convolutions reduce 192 -> 3 spatially while channels grow
    from base_nc to base_nc*8; an MLP head maps the flattened 3x3 feature
    map to a single real/fake score.
    """

    def __init__(self, in_nc=3, base_nc=64, ac_type='BL'):
        super(Discriminator_VGG_192, self).__init__()
        # First conv is bare ('C'); all later convs append ac_type
        # (e.g. 'CBL' = conv + batch norm + leaky ReLU).
        layers = [B.conv(in_nc, base_nc, kernel_size=3, mode='C'),
                  B.conv(base_nc, base_nc, kernel_size=4, stride=2, mode='C' + ac_type)]
        # Each stage: a 3x3 channel-widening conv then a 4x4 stride-2 conv.
        # One more 8x-channel stage than the 96/128 variants to handle the
        # larger input resolution.
        stages = [(base_nc, base_nc * 2),
                  (base_nc * 2, base_nc * 4),
                  (base_nc * 4, base_nc * 8),
                  (base_nc * 8, base_nc * 8),
                  (base_nc * 8, base_nc * 8)]
        for c_in, c_out in stages:
            layers.append(B.conv(c_in, c_out, kernel_size=3, stride=1, mode='C' + ac_type))
            layers.append(B.conv(c_out, c_out, kernel_size=4, stride=2, mode='C' + ac_type))
        self.features = B.sequential(*layers)

        # classifier on the flattened 3x3x512 feature map (assumes base_nc=64)
        self.classifier = nn.Sequential(nn.Linear(512 * 3 * 3, 100),
                                        nn.LeakyReLU(0.2, True),
                                        nn.Linear(100, 1))

    def forward(self, x):
        feat = self.features(x)
        return self.classifier(feat.view(feat.size(0), -1))
# --------------------------------------------
# SN-VGG style Discriminator with 128x128 input
# --------------------------------------------
class Discriminator_VGG_128_SN(nn.Module):
    """VGG-style discriminator for 128x128 inputs with spectral normalization.

    Same topology as Discriminator_VGG_128 (ten convs, five stride-2
    downsamplings, 128 -> 4 spatially) but every conv/linear weight is
    wrapped in spectral_norm and a plain LeakyReLU replaces the
    norm + activation pairs.
    """

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2, True)
        # (in_ch, out_ch) for conv0..conv9; even indices are 3x3/stride-1,
        # odd indices are 4x4/stride-2 downsamplers. Attribute names are
        # kept as conv0..conv9 so checkpoint keys stay unchanged.
        channel_plan = [(3, 64), (64, 64),
                        (64, 128), (128, 128),
                        (128, 256), (256, 256),
                        (256, 512), (512, 512),
                        (512, 512), (512, 512)]
        for idx, (c_in, c_out) in enumerate(channel_plan):
            kernel, stride = (3, 1) if idx % 2 == 0 else (4, 2)
            setattr(self, 'conv{:d}'.format(idx),
                    spectral_norm(nn.Conv2d(c_in, c_out, kernel, stride, 1)))
        # classifier on the flattened 4x4x512 feature map
        self.linear0 = spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        for idx in range(10):
            x = self.lrelu(getattr(self, 'conv{:d}'.format(idx))(x))
        x = x.view(x.size(0), -1)
        return self.linear1(self.lrelu(self.linear0(x)))
if __name__ == '__main__':
    # Smoke-test every discriminator at its expected input resolution;
    # each run prints the (batch, 1) score tensor size.
    configs = ((Discriminator_VGG_96, 96),
               (Discriminator_VGG_128, 128),
               (Discriminator_VGG_192, 192),
               (Discriminator_VGG_128_SN, 128))
    for net_cls, size in configs:
        x = torch.rand(1, 3, size, size)
        net = net_cls()
        net.eval()
        with torch.no_grad():
            y = net(x)
        print(y.size())

    # run models/network_discriminator.py
import torch.nn as nn
import models.basicblock as B
"""
# --------------------------------------------
# DnCNN (20 conv layers)
# FDnCNN (20 conv layers)
# IRCNN (7 conv layers)
# --------------------------------------------
# References:
@article{zhang2017beyond,
title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={26},
number={7},
pages={3142--3155},
year={2017},
publisher={IEEE}
}
@article{zhang2018ffdnet,
title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={27},
number={9},
pages={4608--4622},
year={2018},
publisher={IEEE}
}
# --------------------------------------------
"""
# --------------------------------------------
# DnCNN
# --------------------------------------------
class DnCNN(nn.Module):
    """DnCNN Gaussian denoiser (Zhang et al., TIP 2017).

    A plain stack of conv layers predicts the noise residual; the forward
    pass returns the denoised image x - model(x). Batch normalization and
    residual learning are beneficial to Gaussian denoising (especially for
    a single noise level): the residual of an AWGN-corrupted image follows
    a constant Gaussian distribution, which stabilizes batch normalization
    during training.

    Args:
        in_nc: channel number of the input image.
        out_nc: channel number of the output image.
        nc: feature channel number of the intermediate layers.
        nb: total number of conv layers.
        act_mode: norm + activation spec; 'BR' means BN + ReLU.
    """

    def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
        super(DnCNN, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True
        # head: conv + activation only (act_mode[-1] drops the norm flag)
        layers = [B.conv(in_nc, nc, mode='C' + act_mode[-1], bias=bias)]
        # body: nb-2 conv (+ optional norm) + activation layers
        layers.extend(B.conv(nc, nc, mode='C' + act_mode, bias=bias) for _ in range(nb - 2))
        # tail: bare conv producing the residual estimate
        layers.append(B.conv(nc, out_nc, mode='C', bias=bias))
        self.model = B.sequential(*layers)

    def forward(self, x):
        # residual learning: the network predicts the noise, not the image
        noise = self.model(x)
        return x - noise
# --------------------------------------------
# IRCNN denoiser
# --------------------------------------------
class IRCNN(nn.Module):
    """IRCNN denoiser: 7 dilated conv layers with residual learning.

    The dilation (and matching padding) follows the 1-2-3-4-3-2-1 schedule,
    enlarging the receptive field without downsampling. The forward pass
    returns the denoised image x - model(x).

    Args:
        in_nc: channel number of the input image.
        out_nc: channel number of the output image.
        nc: feature channel number of the intermediate layers.
    """

    def __init__(self, in_nc=1, out_nc=1, nc=64):
        super(IRCNN, self).__init__()
        dilations = (1, 2, 3, 4, 3, 2, 1)
        last = len(dilations) - 1
        layers = []
        for i, d in enumerate(dilations):
            # padding == dilation keeps the spatial size constant for 3x3 convs
            layers.append(nn.Conv2d(in_channels=in_nc if i == 0 else nc,
                                    out_channels=out_nc if i == last else nc,
                                    kernel_size=3, stride=1, padding=d,
                                    dilation=d, bias=True))
            # ReLU after every conv except the final output conv
            if i != last:
                layers.append(nn.ReLU(inplace=True))
        self.model = B.sequential(*layers)

    def forward(self, x):
        # residual learning: the network predicts the noise, not the image
        noise = self.model(x)
        return x - noise
# --------------------------------------------
# FDnCNN
# --------------------------------------------
# Compared with DnCNN, FDnCNN has three modifications:
# 1) add noise level map as input
# 2) remove residual learning and BN
# 3) train with L1 loss
# may need more training time, but will not reduce the final PSNR too much.
# --------------------------------------------
class FDnCNN(nn.Module):
    """Flexible DnCNN denoiser.

    Compared with DnCNN, FDnCNN has three modifications:
      1) a noise level map is added as an extra input channel (in_nc=2);
      2) residual learning and batch norm are removed;
      3) it is trained with L1 loss.
    May need more training time, but does not reduce the final PSNR much.

    Args:
        in_nc: channel number of the input (image + noise level map).
        out_nc: channel number of the output image.
        nc: feature channel number of the intermediate layers.
        nb: total number of conv layers.
        act_mode: norm + activation spec; 'BR' means BN + ReLU.
    """

    def __init__(self, in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R'):
        super(FDnCNN, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True
        # head: conv + activation (act_mode[-1] strips any norm flag)
        layers = [B.conv(in_nc, nc, mode='C' + act_mode[-1], bias=bias)]
        # body: nb-2 identical conv/activation layers
        layers.extend(B.conv(nc, nc, mode='C' + act_mode, bias=bias) for _ in range(nb - 2))
        # tail: plain conv; direct mapping, no residual connection
        layers.append(B.conv(nc, out_nc, mode='C', bias=bias))
        self.model = B.sequential(*layers)

    def forward(self, x):
        return self.model(x)
if __name__ == '__main__':
    from utils import utils_model
    import torch

    # Describe both denoisers first, then run a forward pass through each.
    # The print order (both descriptions, then both output shapes) matches
    # the original script.
    nets = [DnCNN(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='BR'),
            FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')]
    for net in nets:
        print(utils_model.describe_model(net))
    for net, channels in zip(nets, (1, 2)):
        out = net(torch.randn((1, channels, 240, 240)))
        print(out.shape)

    # run models/network_dncnn.py
import math
import torch.nn as nn
import models.basicblock as B
"""
# --------------------------------------------
# modified SRResNet
# -- MSRResNet_prior (for DPSR)
# --------------------------------------------
References:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
@inproceedings{wang2018esrgan,
title={Esrgan: Enhanced super-resolution generative adversarial networks},
author={Wang, Xintao and Yu, Ke and Wu, Shixiang and Gu, Jinjin and Liu, Yihao and Dong, Chao and Qiao, Yu and Change Loy, Chen},
booktitle={European Conference on Computer Vision (ECCV)},
pages={0--0},
year={2018}
}
@inproceedings{ledig2017photo,
title={Photo-realistic single image super-resolution using a generative adversarial network},
author={Ledig, Christian and Theis, Lucas and Husz{\'a}r, Ferenc and Caballero, Jose and Cunningham, Andrew and Acosta, Alejandro and Aitken, Andrew and Tejani, Alykhan and Totz, Johannes and Wang, Zehan and others},
booktitle={IEEE conference on computer vision and pattern recognition},
pages={4681--4690},
year={2017}
}
# --------------------------------------------
"""
# --------------------------------------------
# MSRResNet super-resolver prior for DPSR
# https://github.com/cszn/DPSR
# https://github.com/cszn/DPSR/blob/master/models/network_srresnet.py
# --------------------------------------------
class MSRResNet_prior(nn.Module):
    """Modified SRResNet super-resolver prior for DPSR.

    Head conv -> nb residual blocks inside a long skip connection ->
    upsampling stage(s) -> two-conv tail. The default in_nc=4 accounts for
    the extra prior channel concatenated to the 3-channel image.

    Args:
        in_nc: input channels (image + prior map).
        out_nc: output channels.
        nc: feature channels.
        nb: number of residual blocks.
        upscale: upscaling factor; powers of two use repeated x2 stages,
            3 uses a single x3 stage.
        act_mode: activation spec passed to B.conv / B.ResBlock modes.
        upsample_mode: 'upconv' | 'pixelshuffle' | 'convtranspose'.
    """

    def __init__(self, in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='upconv'):
        super(MSRResNet_prior, self).__init__()
        # round() guards against float error in two-argument log():
        # int(math.log(8, 2)) can truncate to 2 because log(8, 2) may
        # evaluate to 2.999... in floating point.
        n_upscale = int(round(math.log(upscale, 2)))
        if upscale == 3:
            n_upscale = 1

        m_head = B.conv(in_nc, nc, mode='C')

        m_body = [B.ResBlock(nc, nc, mode='C'+act_mode+'C') for _ in range(nb)]
        m_body.append(B.conv(nc, nc, mode='C'))

        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        if upscale == 3:
            # Wrap in a list so the *m_uper unpacking below is uniform with
            # the power-of-two branch (B.sequential flattens nested
            # sequentials, so the resulting module is unchanged).
            m_uper = [upsample_block(nc, nc, mode='3'+act_mode)]
        else:
            m_uper = [upsample_block(nc, nc, mode='2'+act_mode) for _ in range(n_upscale)]

        H_conv0 = B.conv(nc, nc, mode='C'+act_mode)
        H_conv1 = B.conv(nc, out_nc, bias=False, mode='C')
        m_tail = B.sequential(H_conv0, H_conv1)

        self.model = B.sequential(m_head, B.ShortcutBlock(B.sequential(*m_body)), *m_uper, m_tail)

    def forward(self, x):
        """Super-resolve x (N, in_nc, H, W) -> (N, out_nc, H*upscale, W*upscale)."""
        x = self.model(x)
        return x
class SRResNet(nn.Module):
    """SRResNet super-resolver (Ledig et al., CVPR 2017, as modified here).

    Head conv -> nb residual blocks inside a long skip connection ->
    upsampling stage(s) -> two-conv tail.

    Args:
        in_nc: input channels.
        out_nc: output channels.
        nc: feature channels.
        nb: number of residual blocks.
        upscale: upscaling factor; powers of two use repeated x2 stages,
            3 uses a single x3 stage.
        act_mode: activation spec passed to B.conv / B.ResBlock modes.
        upsample_mode: 'upconv' | 'pixelshuffle' | 'convtranspose'.
    """

    def __init__(self, in_nc=3, out_nc=3, nc=64, nb=16, upscale=4, act_mode='R', upsample_mode='upconv'):
        super(SRResNet, self).__init__()
        # round() guards against float error in two-argument log():
        # int(math.log(8, 2)) can truncate to 2 because log(8, 2) may
        # evaluate to 2.999... in floating point.
        n_upscale = int(round(math.log(upscale, 2)))
        if upscale == 3:
            n_upscale = 1

        m_head = B.conv(in_nc, nc, mode='C')

        m_body = [B.ResBlock(nc, nc, mode='C'+act_mode+'C') for _ in range(nb)]
        m_body.append(B.conv(nc, nc, mode='C'))

        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        if upscale == 3:
            # Wrap in a list so the *m_uper unpacking below is uniform with
            # the power-of-two branch (B.sequential flattens nested
            # sequentials, so the resulting module is unchanged).
            m_uper = [upsample_block(nc, nc, mode='3'+act_mode)]
        else:
            m_uper = [upsample_block(nc, nc, mode='2'+act_mode) for _ in range(n_upscale)]

        H_conv0 = B.conv(nc, nc, mode='C'+act_mode)
        H_conv1 = B.conv(nc, out_nc, bias=False, mode='C')
        m_tail = B.sequential(H_conv0, H_conv1)

        self.model = B.sequential(m_head, B.ShortcutBlock(B.sequential(*m_body)), *m_uper, m_tail)

    def forward(self, x):
        """Super-resolve x (N, in_nc, H, W) -> (N, out_nc, H*upscale, W*upscale)."""
        x = self.model(x)
        return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment