import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
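# Usage sketch (illustrative, not part of the original file): re-initialize a conv
# layer and scale its weights by 0.1, as is commonly done for residual branches.
# >>> conv = nn.Conv2d(64, 64, 3, 1, 1)
# >>> default_init_weights(conv, scale=0.1, bias_fill=0)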
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.Module): nn.Module class for the basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
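# Usage sketch (illustrative): stack five ResidualBlockNoBN blocks (defined below)
# into one trunk; keyword arguments are forwarded to the block constructor.
# >>> trunk = make_layer(ResidualBlockNoBN, 5, num_feat=64)
# >>> trunk(torch.rand(1, 64, 32, 32)).shape  # torch.Size([1, 64, 32, 32])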
class PixelShufflePack(nn.Module):
"""Pixel Shuffle upsample layer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
scale_factor (int): Upsample ratio.
upsample_kernel (int): Kernel size of Conv layer to expand channels.
Returns:
Upsampled feature map.
"""
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.scale_factor = scale_factor
self.upsample_kernel = upsample_kernel
self.upsample_conv = nn.Conv2d(
self.in_channels,
self.out_channels * scale_factor * scale_factor,
self.upsample_kernel,
padding=(self.upsample_kernel - 1) // 2)
self.init_weights()
def init_weights(self):
"""Initialize weights for PixelShufflePack."""
default_init_weights(self, 1)
def forward(self, x):
"""Forward function for PixelShufflePack.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
x = self.upsample_conv(x)
x = F.pixel_shuffle(x, self.scale_factor)
return x
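# Usage sketch (illustrative): the conv expands channels by scale_factor**2 and
# pixel_shuffle rearranges them into a larger spatial grid.
# >>> up = PixelShufflePack(64, 64, scale_factor=2, upsample_kernel=3)
# >>> up(torch.rand(1, 64, 16, 16)).shape  # torch.Size([1, 64, 32, 32])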
class ResidualBlockNoBN(nn.Module):
"""Residual block without BN.
Args:
num_feat (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
"""
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.relu = nn.ReLU()
if not pytorch_init:
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.relu(x)
out = self.conv2(x)
# out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
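# Usage sketch (illustrative): a x4 upsampler is built from two conv + PixelShuffle(2) stages.
# >>> up = Upsample(scale=4, num_feat=64)
# >>> up(torch.rand(1, 64, 16, 16)).shape  # torch.Size([1, 64, 64, 64])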
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
"""Warp an image or feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2); flow values are in pixels (not normalized).
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
padding_mode (str): 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Before pytorch 1.3, the default value is
align_corners=True. After pytorch 1.3, the default value is
align_corners=False. Here, we use the True as default.
Returns:
Tensor: Warped image or feature map.
"""
assert x.size()[-2:] == flow.size()[1:3]
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
# TODO, what if align_corners=False
return output
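# Usage sketch (illustrative): a zero flow leaves the features unchanged
# (up to interpolation); flow values are absolute pixel offsets.
# >>> feat = torch.rand(1, 16, 8, 8)
# >>> warped = flow_warp(feat, torch.zeros(1, 8, 8, 2))  # shape (1, 16, 8, 8)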
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
"""Resize a flow according to ratio or shape.
Args:
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
size_type (str): 'ratio' or 'shape'.
sizes (list[int | float]): the ratio for resizing or the final output
shape.
1) The order of ratio should be [ratio_h, ratio_w]. For
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
ratio > 1.0).
2) The order of output_size should be [out_h, out_w].
interp_mode (str): The mode of interpolation for resizing.
Default: 'bilinear'.
align_corners (bool): Whether align corners. Default: False.
Returns:
Tensor: Resized flow.
"""
_, _, flow_h, flow_w = flow.size()
if size_type == 'ratio':
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
elif size_type == 'shape':
output_h, output_w = sizes[0], sizes[1]
else:
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
input_flow = flow.clone()
ratio_h = output_h / flow_h
ratio_w = output_w / flow_w
input_flow[:, 0, :, :] *= ratio_w
input_flow[:, 1, :, :] *= ratio_h
resized_flow = F.interpolate(
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
return resized_flow
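# Usage sketch (illustrative): doubling the resolution also doubles the flow
# magnitudes, since flow is measured in pixels.
# >>> flow = torch.rand(1, 2, 32, 32)
# >>> resize_flow(flow, 'ratio', [2, 2]).shape  # torch.Size([1, 2, 64, 64])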
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
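# Usage sketch (illustrative): the inverse of pixel shuffle, trading spatial
# resolution for channels.
# >>> x = torch.rand(1, 3, 8, 8)
# >>> pixel_unshuffle(x, scale=2).shape  # torch.Size([1, 12, 4, 4])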
class DCNv2Pack(ModulatedDeformConvPack):
"""Modulated deformable conv for deformable alignment.
Different from the official DCNv2Pack, which generates offsets and masks
from the preceding features, this DCNv2Pack uses separate
features to generate offsets and masks.
``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
"""
def forward(self, x, feat):
out = self.conv_offset(feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_absmean = torch.mean(torch.abs(offset))
if offset_absmean > 50:
logger = get_root_logger()
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
self.dilation, mask)
else:
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
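# Usage sketch (illustrative): normalize scalar or iterable arguments to tuples,
# e.g. for kernel sizes.
# >>> to_2tuple(3)       # (3, 3)
# >>> to_2tuple((3, 5))  # (3, 5)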
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
"""A recurrent network for video SR. Now only x4 is supported.
Args:
num_feat (int): Number of channels. Default: 64.
num_block (int): Number of residual blocks for each branch. Default: 15
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
"""
def __init__(self, num_feat=64, num_block=15, spynet_path=None):
super().__init__()
self.num_feat = num_feat
# alignment
self.spynet = SpyNet(spynet_path)
# propagation
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
# reconstruction
self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
# activation functions
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def get_flow(self, x):
b, n, c, h, w = x.size()
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
return flows_forward, flows_backward
def forward(self, x):
"""Forward function of BasicVSR.
Args:
x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
"""
flows_forward, flows_backward = self.get_flow(x)
b, n, _, h, w = x.size()
# backward branch
out_l = []
feat_prop = x.new_zeros(b, self.num_feat, h, w)
for i in range(n - 1, -1, -1):
x_i = x[:, i, :, :, :]
if i < n - 1:
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.backward_trunk(feat_prop)
out_l.insert(0, feat_prop)
# forward branch
feat_prop = torch.zeros_like(feat_prop)
for i in range(0, n):
x_i = x[:, i, :, :, :]
if i > 0:
flow = flows_forward[:, i - 1, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.forward_trunk(feat_prop)
# upsample
out = torch.cat([out_l[i], feat_prop], dim=1)
out = self.lrelu(self.fusion(out))
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
out += base
out_l[i] = out
return torch.stack(out_l, dim=1)
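# Usage sketch (illustrative; spynet_path=None means no pretrained flow weights are loaded):
# >>> model = BasicVSR(num_feat=64, num_block=15, spynet_path=None)
# >>> lqs = torch.rand(1, 5, 3, 64, 64)   # (b, n, c, h, w)
# >>> model(lqs).shape                    # torch.Size([1, 5, 3, 256, 256])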
class ConvResidualBlocks(nn.Module):
"""Conv and residual block used in BasicVSR.
Args:
num_in_ch (int): Number of input channels. Default: 3.
num_out_ch (int): Number of output channels. Default: 64.
num_block (int): Number of residual blocks. Default: 15.
"""
def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
def forward(self, fea):
return self.main(fea)
@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
"""IconVSR, proposed also in the BasicVSR paper.
Args:
num_feat (int): Number of channels. Default: 64.
num_block (int): Number of residual blocks for each branch. Default: 15.
keyframe_stride (int): Keyframe stride. Default: 5.
temporal_padding (int): Temporal padding. Default: 2.
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
edvr_path (str): Path to the pretrained EDVR model. Default: None.
"""
def __init__(self,
num_feat=64,
num_block=15,
keyframe_stride=5,
temporal_padding=2,
spynet_path=None,
edvr_path=None):
super().__init__()
self.num_feat = num_feat
self.temporal_padding = temporal_padding
self.keyframe_stride = keyframe_stride
# keyframe_branch
self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
# alignment
self.spynet = SpyNet(spynet_path)
# propagation
self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
# reconstruction
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
# activation functions
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def pad_spatial(self, x):
"""Apply padding spatially.
Since the PCD module in EDVR requires that the resolution is a multiple
of 4, we apply padding to the input LR images if their resolution is
not divisible by 4.
Args:
x (Tensor): Input LR sequence with shape (n, t, c, h, w).
Returns:
Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
"""
n, t, c, h, w = x.size()
pad_h = (4 - h % 4) % 4
pad_w = (4 - w % 4) % 4
# padding
x = x.view(-1, c, h, w)
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
return x.view(n, t, c, h + pad_h, w + pad_w)
def get_flow(self, x):
b, n, c, h, w = x.size()
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
return flows_forward, flows_backward
def get_keyframe_feature(self, x, keyframe_idx):
if self.temporal_padding == 2:
x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
elif self.temporal_padding == 3:
x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
x = torch.cat(x, dim=1)
num_frames = 2 * self.temporal_padding + 1
feats_keyframe = {}
for i in keyframe_idx:
feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
return feats_keyframe
def forward(self, x):
b, n, _, h_input, w_input = x.size()
x = self.pad_spatial(x)
h, w = x.shape[3:]
keyframe_idx = list(range(0, n, self.keyframe_stride))
if keyframe_idx[-1] != n - 1:
keyframe_idx.append(n - 1) # last frame is a keyframe
# compute flow and keyframe features
flows_forward, flows_backward = self.get_flow(x)
feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
# backward branch
out_l = []
feat_prop = x.new_zeros(b, self.num_feat, h, w)
for i in range(n - 1, -1, -1):
x_i = x[:, i, :, :, :]
if i < n - 1:
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
if i in keyframe_idx:
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
feat_prop = self.backward_fusion(feat_prop)
feat_prop = torch.cat([x_i, feat_prop], dim=1)
feat_prop = self.backward_trunk(feat_prop)
out_l.insert(0, feat_prop)
# forward branch
feat_prop = torch.zeros_like(feat_prop)
for i in range(0, n):
x_i = x[:, i, :, :, :]
if i > 0:
flow = flows_forward[:, i - 1, :, :, :]
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
if i in keyframe_idx:
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
feat_prop = self.forward_fusion(feat_prop)
feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
feat_prop = self.forward_trunk(feat_prop)
# upsample
out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
out += base
out_l[i] = out
return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
class EDVRFeatureExtractor(nn.Module):
"""EDVR feature extractor used in IconVSR.
Args:
num_input_frame (int): Number of input frames.
num_feat (int): Number of feature channels
load_path (str): Path to the pretrained weights of EDVR. Default: None.
"""
def __init__(self, num_input_frame, num_feat, load_path):
super(EDVRFeatureExtractor, self).__init__()
self.center_frame_idx = num_input_frame // 2
# extract pyramid features
self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# pcd and tsa module
self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
if load_path:
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
def forward(self, x):
b, n, c, h, w = x.size()
# extract features for each frame
# L1
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
feat_l1 = self.feature_extraction(feat_l1)
# L2
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
# L3
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
feat_l1 = feat_l1.view(b, n, -1, h, w)
feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
# PCD alignment
ref_feat_l = [ # reference feature list
feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
feat_l3[:, self.center_frame_idx, :, :, :].clone()
]
aligned_feat = []
for i in range(n):
nbr_feat_l = [ # neighboring feature list
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
# TSA fusion
return self.fusion(aligned_feat)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import warnings
from basicsr.archs.arch_util import flow_warp
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
from basicsr.archs.spynet_arch import SpyNet
from basicsr.ops.dcn import ModulatedDeformConvPack
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class BasicVSRPlusPlus(nn.Module):
"""BasicVSR++ network structure.
Supports either x4 upsampling or same-size output. Since DCN is used in this
model, it can only be used with CUDA enabled. If CUDA is not enabled,
feature alignment will be skipped. Besides, we adopt the official DCN
implementation, and the torch version needs to be higher than 1.9.
``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
Args:
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_blocks (int, optional): The number of residual blocks in each
propagation branch. Default: 7.
max_residue_magnitude (int): The maximum magnitude of the offset
residue (Eq. 6 in paper). Default: 10.
is_low_res_input (bool, optional): Whether the input is low-resolution
or not. If False, the output resolution is equal to the input
resolution. Default: True.
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
cpu_cache_length (int, optional): When the length of sequence is larger
than this value, the intermediate features are sent to CPU. This
saves GPU memory, but slows down the inference speed. You can
increase this number if you have a GPU with large memory.
Default: 100.
"""
def __init__(self,
mid_channels=64,
num_blocks=7,
max_residue_magnitude=10,
is_low_res_input=True,
spynet_path=None,
cpu_cache_length=100):
super().__init__()
self.mid_channels = mid_channels
self.is_low_res_input = is_low_res_input
self.cpu_cache_length = cpu_cache_length
# optical flow
self.spynet = SpyNet(spynet_path)
# feature extraction module
if is_low_res_input:
self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
else:
self.feat_extract = nn.Sequential(
nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
ConvResidualBlocks(mid_channels, mid_channels, 5))
# propagation branches
self.deform_align = nn.ModuleDict()
self.backbone = nn.ModuleDict()
modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
for i, module in enumerate(modules):
if torch.cuda.is_available():
self.deform_align[module] = SecondOrderDeformableAlignment(
2 * mid_channels,
mid_channels,
3,
padding=1,
deformable_groups=16,
max_residue_magnitude=max_residue_magnitude)
self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
# upsampling module
self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# check if the sequence is augmented by flipping
self.is_mirror_extended = False
if len(self.deform_align) > 0:
self.is_with_alignment = True
else:
self.is_with_alignment = False
warnings.warn('Deformable alignment module is not added. '
'Probably your CUDA is not configured correctly. DCN can only '
'be used with CUDA enabled. Alignment is skipped now.')
def check_if_mirror_extended(self, lqs):
"""Check whether the input is a mirror-extended sequence.
If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
Args:
lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
"""
if lqs.size(1) % 2 == 0:
lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
self.is_mirror_extended = True
def compute_flow(self, lqs):
"""Compute optical flow using SPyNet for feature alignment.
Note that if the input is a mirror-extended sequence, 'flows_forward'
is not needed, since it is equal to 'flows_backward.flip(1)'.
Args:
lqs (tensor): Input low quality (LQ) sequence with
shape (n, t, c, h, w).
Return:
tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
(current to previous). 'flows_backward' corresponds to the flows used for backward-time \
propagation (current to next).
"""
n, t, c, h, w = lqs.size()
lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
flows_forward = flows_backward.flip(1)
else:
flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
if self.cpu_cache:
flows_backward = flows_backward.cpu()
flows_forward = flows_forward.cpu()
return flows_forward, flows_backward
def propagate(self, feats, flows, module_name):
"""Propagate the latent features throughout the sequence.
Args:
feats dict(list[tensor]): Features from previous branches. Each
component is a list of tensors with shape (n, c, h, w).
flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
module_name (str): The name of the propagation branches. Can either
be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
Return:
dict(list[tensor]): A dictionary containing all the propagated \
features. Each key in the dictionary corresponds to a \
propagation branch, which is represented by a list of tensors.
"""
n, t, _, h, w = flows.size()
frame_idx = range(0, t + 1)
flow_idx = range(-1, t)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
if 'backward' in module_name:
frame_idx = frame_idx[::-1]
flow_idx = frame_idx
feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if self.cpu_cache:
feat_current = feat_current.cuda()
feat_prop = feat_prop.cuda()
# second-order deformable alignment
if i > 0 and self.is_with_alignment:
flow_n1 = flows[:, flow_idx[i], :, :, :]
if self.cpu_cache:
flow_n1 = flow_n1.cuda()
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
# initialize second-order features
feat_n2 = torch.zeros_like(feat_prop)
flow_n2 = torch.zeros_like(flow_n1)
cond_n2 = torch.zeros_like(cond_n1)
if i > 1: # second-order features
feat_n2 = feats[module_name][-2]
if self.cpu_cache:
feat_n2 = feat_n2.cuda()
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
if self.cpu_cache:
flow_n2 = flow_n2.cuda()
flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
# flow-guided deformable convolution
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
# concatenate and residual blocks
feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
if self.cpu_cache:
feat = [f.cuda() for f in feat]
feat = torch.cat(feat, dim=1)
feat_prop = feat_prop + self.backbone[module_name](feat)
feats[module_name].append(feat_prop)
if self.cpu_cache:
feats[module_name][-1] = feats[module_name][-1].cpu()
torch.cuda.empty_cache()
if 'backward' in module_name:
feats[module_name] = feats[module_name][::-1]
return feats
def upsample(self, lqs, feats):
"""Compute the output image given the features.
Args:
lqs (tensor): Input low quality (LQ) sequence with
shape (n, t, c, h, w).
feats (dict): The features from the propagation branches.
Returns:
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
"""
outputs = []
num_outputs = len(feats['spatial'])
mapping_idx = list(range(0, num_outputs))
mapping_idx += mapping_idx[::-1]
for i in range(0, lqs.size(1)):
hr = [feats[k].pop(0) for k in feats if k != 'spatial']
hr.insert(0, feats['spatial'][mapping_idx[i]])
hr = torch.cat(hr, dim=1)
if self.cpu_cache:
hr = hr.cuda()
hr = self.reconstruction(hr)
hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
hr = self.lrelu(self.conv_hr(hr))
hr = self.conv_last(hr)
if self.is_low_res_input:
hr += self.img_upsample(lqs[:, i, :, :, :])
else:
hr += lqs[:, i, :, :, :]
if self.cpu_cache:
hr = hr.cpu()
torch.cuda.empty_cache()
outputs.append(hr)
return torch.stack(outputs, dim=1)
def forward(self, lqs):
"""Forward function for BasicVSR++.
Args:
lqs (tensor): Input low quality (LQ) sequence with
shape (n, t, c, h, w).
Returns:
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
"""
n, t, c, h, w = lqs.size()
# whether to cache the features in CPU
self.cpu_cache = True if t > self.cpu_cache_length else False
if self.is_low_res_input:
lqs_downsample = lqs.clone()
else:
lqs_downsample = F.interpolate(
lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
# check whether the input is an extended sequence
self.check_if_mirror_extended(lqs)
feats = {}
# compute spatial features
if self.cpu_cache:
feats['spatial'] = []
for i in range(0, t):
feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
feats['spatial'].append(feat)
torch.cuda.empty_cache()
else:
feats_ = self.feat_extract(lqs.view(-1, c, h, w))
h, w = feats_.shape[2:]
feats_ = feats_.view(n, t, -1, h, w)
feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
# compute optical flow using the low-res inputs
assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
'The height and width of low-res inputs must be at least 64, '
f'but got {h} and {w}.')
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
# feature propagation
for iter_ in [1, 2]:
for direction in ['backward', 'forward']:
module = f'{direction}_{iter_}'
feats[module] = []
if direction == 'backward':
flows = flows_backward
elif flows_forward is not None:
flows = flows_forward
else:
flows = flows_backward.flip(1)
feats = self.propagate(feats, flows, module)
if self.cpu_cache:
del flows
torch.cuda.empty_cache()
return self.upsample(lqs, feats)
class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
"""Second-order deformable alignment module.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
max_residue_magnitude (int): The maximum magnitude of the offset
residue (Eq. 6 in paper). Default: 10.
"""
def __init__(self, *args, **kwargs):
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
self.conv_offset = nn.Sequential(
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
)
self.init_offset()
def init_offset(self):
def _constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
_constant_init(self.conv_offset[-1], val=0, bias=0)
def forward(self, x, extra_feat, flow_1, flow_2):
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
out = self.conv_offset(extra_feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
# offset
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
offset = torch.cat([offset_1, offset_2], dim=1)
# mask
mask = torch.sigmoid(mask)
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
self.dilation, mask)
# if __name__ == '__main__':
# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
# input = torch.rand(1, 2, 3, 64, 64).cuda()
# output = model(input)
# print('===================')
# print(output.shape)
from torch import nn as nn
from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class DEResNet(nn.Module):
"""Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
As shown in the paper 'Towards Flexible Blind JPEG Artifacts Removal',
a ResNet architecture works well for image quality estimation.
Args:
num_in_ch (int): channel number of inputs. Default: 3.
num_degradation (int): number of degradations the estimator should predict. Default: 2 (blur + noise).
degradation_degree_actv (str): activation function for the degradation degree scalar. Default: 'sigmoid'.
num_feats (list): channel number of each stage.
num_blocks (list): residual block of each stage.
downscales (list): downscales of each stage.
"""
def __init__(self,
num_in_ch=3,
num_degradation=2,
degradation_degree_actv='sigmoid',
num_feats=(64, 128, 256, 512),
num_blocks=(2, 2, 2, 2),
downscales=(2, 2, 2, 1)):
super(DEResNet, self).__init__()
# accept both lists and tuples (the defaults are tuples)
assert isinstance(num_feats, (list, tuple))
assert isinstance(num_blocks, (list, tuple))
assert isinstance(downscales, (list, tuple))
assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
num_stage = len(num_feats)
self.conv_first = nn.ModuleList()
for _ in range(num_degradation):
self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
self.body = nn.ModuleList()
for _ in range(num_degradation):
body = list()
for stage in range(num_stage):
for _ in range(num_blocks[stage]):
body.append(ResidualBlockNoBN(num_feats[stage]))
if downscales[stage] == 1:
if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
continue
elif downscales[stage] == 2:
body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
else:
raise NotImplementedError
self.body.append(nn.Sequential(*body))
# self.body = nn.Sequential(*body)
self.num_degradation = num_degradation
self.fc_degree = nn.ModuleList()
if degradation_degree_actv == 'sigmoid':
actv = nn.Sigmoid
elif degradation_degree_actv == 'tanh':
actv = nn.Tanh
else:
raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
f'{degradation_degree_actv} is not supported yet.')
for _ in range(num_degradation):
self.fc_degree.append(
nn.Sequential(
nn.Linear(num_feats[-1], 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1),
actv(),
))
self.avg_pool = nn.AdaptiveAvgPool2d(1)
default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
def forward(self, x):
degrees = []
for i in range(self.num_degradation):
x_out = self.conv_first[i](x)
feat = self.body[i](x_out)
feat = self.avg_pool(feat)
feat = feat.squeeze(-1).squeeze(-1)
# for i in range(self.num_degradation):
degrees.append(self.fc_degree[i](feat).squeeze(-1))
return degrees
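# Usage sketch (illustrative): the estimator returns one scalar in [0, 1] per
# degradation per image (here: blur degree and noise degree).
# >>> net = DEResNet(num_in_ch=3, num_degradation=2, num_feats=[64, 128, 256, 512],
# ...                num_blocks=[2, 2, 2, 2], downscales=[2, 2, 2, 1])
# >>> degrees = net(torch.rand(1, 3, 64, 64))  # list of 2 tensors, each of shape (1,)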
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor
class SFTUpBlock(nn.Module):
"""Spatial feature transform (SFT) with upsampling block.
Args:
in_channel (int): Number of input channels.
out_channel (int): Number of output channels.
kernel_size (int): Kernel size in convolutions. Default: 3.
padding (int): Padding in convolutions. Default: 1.
"""
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
super(SFTUpBlock, self).__init__()
self.conv1 = nn.Sequential(
Blur(in_channel),
spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
nn.LeakyReLU(0.04, True),
# The official code applies LeakyReLU(0.2) twice here; a single
# LeakyReLU(0.04) is equivalent for negative inputs (0.2 * 0.2 = 0.04).
)
self.convup = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
nn.LeakyReLU(0.2, True),
)
# for SFT scale and shift
self.scale_block = nn.Sequential(
spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
self.shift_block = nn.Sequential(
spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
# The official code uses a sigmoid for the shift block; the reason is unclear.
def forward(self, x, updated_feat):
out = self.conv1(x)
# SFT
scale = self.scale_block(updated_feat)
shift = self.shift_block(updated_feat)
out = out * scale + shift
# upsample
out = self.convup(out)
return out
@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
"""DFDNet: Deep Face Dictionary Network.
It only processes faces with 512x512 size.
Args:
num_feat (int): Number of feature channels.
dict_path (str): Path to the facial component dictionary.
"""
def __init__(self, num_feat, dict_path):
super().__init__()
self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
# part_sizes: [80, 80, 50, 110]
channel_sizes = [128, 256, 512, 512]
self.feature_sizes = np.array([256, 128, 64, 32])
self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
self.flag_dict_device = False
# dict
self.dict = torch.load(dict_path)
# vgg face extractor
self.vgg_extractor = VGGFeatureExtractor(
layer_name_list=self.vgg_layers,
vgg_type='vgg19',
use_input_norm=True,
range_norm=True,
requires_grad=False)
# attention block for fusing dictionary features and input features
self.attn_blocks = nn.ModuleDict()
for idx, feat_size in enumerate(self.feature_sizes):
for name in self.parts:
self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
# multi scale dilation block
self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
# upsampling and reconstruction
self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
self.upsample4 = nn.Sequential(
spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
"""swap the features from the dictionary."""
# get the original vgg features
part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
# resize original vgg features
part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
# use adaptive instance normalization to adjust color and illuminations
dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
# get similarity scores
similarity_score = F.conv2d(part_resize_feat, dict_feat)
similarity_score = F.softmax(similarity_score.view(-1), dim=0)
# select the most similar features in the dict (after norm)
select_idx = torch.argmax(similarity_score)
swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
# attention
attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
attn_feat = attn * swap_feat
# update features
updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
return updated_feat
def put_dict_to_device(self, x):
if self.flag_dict_device is False:
for k, v in self.dict.items():
for kk, vv in v.items():
self.dict[k][kk] = vv.to(x)
self.flag_dict_device = True
def forward(self, x, part_locations):
"""
Only supports testing with batch size = 1.
Args:
x (Tensor): Input faces with shape (b, c, 512, 512).
part_locations (list[Tensor]): Part locations.
"""
self.put_dict_to_device(x)
# extract vggface features
vgg_features = self.vgg_extractor(x)
# update vggface features using the dictionary for each part
updated_vgg_features = []
batch = 0  # batch index 0; only supports testing with batch size = 1
for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
dict_features = self.dict[f'{f_size}']
vgg_feat = vgg_features[vgg_layer]
updated_feat = vgg_feat.clone()
# swap features from dictionary
for part_idx, part_name in enumerate(self.parts):
location = (part_locations[part_idx][batch] // (512 / f_size)).int()
updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
f_size)
updated_vgg_features.append(updated_feat)
vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
# use updated vgg features to modulate the upsampled features with
# SFT (Spatial Feature Transform) scaling and shifting manner.
upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
out = self.upsample4(upsampled_feat)
return out
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm
class BlurFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
return grad_input
@staticmethod
def backward(ctx, gradgrad_output):
kernel, _ = ctx.saved_tensors
grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
return grad_input, None, None
class BlurFunction(Function):
@staticmethod
def forward(ctx, x, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
return output
@staticmethod
def backward(ctx, grad_output):
kernel, kernel_flip = ctx.saved_tensors
grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
return grad_input, None, None
blur = BlurFunction.apply
class Blur(nn.Module):
def __init__(self, channel):
super().__init__()
kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
kernel = kernel.view(1, 1, 3, 3)
kernel = kernel / kernel.sum()
kernel_flip = torch.flip(kernel, [2, 3])
self.kernel = kernel.repeat(channel, 1, 1, 1)
self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
def forward(self, x):
return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
n, c = size[:2]
feat_var = feat.view(n, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(n, c, 1, 1)
feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have similar color and illumination
to those of the degraded features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
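# Usage sketch (illustrative): transfer the per-channel mean/std statistics of
# style_feat onto content_feat.
# >>> content = torch.rand(1, 64, 32, 32)
# >>> style = torch.rand(1, 64, 32, 32)
# >>> out = adaptive_instance_normalization(content, style)  # shape (1, 64, 32, 32)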
def AttentionBlock(in_channel):
return nn.Sequential(
spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
"""Conv block used in MSDilationBlock."""
return nn.Sequential(
spectral_norm(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=((kernel_size - 1) // 2) * dilation,
bias=bias)),
nn.LeakyReLU(0.2),
spectral_norm(
nn.Conv2d(
out_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=((kernel_size - 1) // 2) * dilation,
bias=bias)),
)
class MSDilationBlock(nn.Module):
"""Multi-scale dilation block."""
def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
super(MSDilationBlock, self).__init__()
self.conv_blocks = nn.ModuleList()
for i in range(4):
self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
self.conv_fusion = spectral_norm(
nn.Conv2d(
in_channels * 4,
in_channels,
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
bias=bias))
def forward(self, x):
out = []
for i in range(4):
out.append(self.conv_blocks[i](x))
out = torch.cat(out, 1)
out = self.conv_fusion(out) + x
return out
class UpResBlock(nn.Module):
def __init__(self, in_channel):
super(UpResBlock, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channel, in_channel, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(in_channel, in_channel, 3, 1, 1),
)
def forward(self, x):
out = x + self.body(x)
return out
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class VGGStyleDiscriminator(nn.Module):
"""VGG style discriminator with input size 128 x 128 or 256 x 256.
It is used to train SRGAN, ESRGAN, and VideoGAN.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
"""
def __init__(self, num_in_ch, num_feat, input_size=128):
super(VGGStyleDiscriminator, self).__init__()
self.input_size = input_size
assert self.input_size == 128 or self.input_size == 256, (
f'input size must be 128 or 256, but received {input_size}')
self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
if self.input_size == 256:
self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
self.linear2 = nn.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
feat = self.lrelu(self.conv0_0(x))
feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
if self.input_size == 256:
feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
# spatial size: (4, 4)
feat = feat.view(feat.size(0), -1)
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
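# Usage sketch (illustrative): five stride-2 convolutions reduce a 128x128 input
# to a 4x4 feature map before the linear layers, giving one score per image.
# >>> d = VGGStyleDiscriminator(num_in_ch=3, num_feat=64, input_size=128)
# >>> d(torch.rand(1, 3, 128, 128)).shape  # torch.Size([1, 1])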
@ARCH_REGISTRY.register(suffix='basicsr')
class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN)
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
skip_connection (bool): Whether to use U-Net skip connections. Default: True.
"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection
norm = spectral_norm
# the first convolution
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
# downsample
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
# upsample
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra convolutions
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x):
# downsample
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
# upsample
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
if self.skip_connection:
x4 = x4 + x2
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
if self.skip_connection:
x5 = x5 + x1
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
if self.skip_connection:
x6 = x6 + x0
# extra convolutions
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out)
return out
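# Usage sketch (illustrative): the U-Net discriminator returns a per-pixel
# realness map with the same spatial size as the input.
# >>> d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
# >>> d(torch.rand(1, 3, 128, 128)).shape  # torch.Size([1, 1, 128, 128])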
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
class DenseBlocksTemporalReduce(nn.Module):
"""A concatenation of 3 dense blocks with reduction in temporal dimension.
Note that the output temporal dimension is 6 less than the input temporal dimension, since each of the 3 blocks reduces it by 2.
Args:
num_feat (int): Number of channels in the blocks. Default: 64.
num_grow_ch (int): Growing factor of the dense blocks. Default: 32
adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
Set to false if you want to train from scratch. Default: False.
"""
def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
super(DenseBlocksTemporalReduce, self).__init__()
if adapt_official_weights:
eps = 1e-3
momentum = 1e-3
else: # pytorch default values
eps = 1e-05
momentum = 0.1
self.temporal_reduce1 = nn.Sequential(
nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
self.temporal_reduce2 = nn.Sequential(
nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + num_grow_ch,
num_feat + num_grow_ch, (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
self.temporal_reduce3 = nn.Sequential(
nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + 2 * num_grow_ch,
num_feat + 2 * num_grow_ch, (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
def forward(self, x):
"""
Args:
x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
Returns:
Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
"""
x1 = self.temporal_reduce1(x)
x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
x2 = self.temporal_reduce2(x1)
x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
x3 = self.temporal_reduce3(x2)
x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
return x3
class DenseBlocks(nn.Module):
""" A concatenation of N dense blocks.
Args:
num_feat (int): Number of channels in the blocks. Default: 64.
num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
num_block (int): Number of dense blocks. The values are:
DUF-S (16 layers): 3
DUF-M (28 layers): 9
DUF-L (52 layers): 21
adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
Set to false if you want to train from scratch. Default: False.
"""
def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
super(DenseBlocks, self).__init__()
if adapt_official_weights:
eps = 1e-3
momentum = 1e-3
else: # pytorch default values
eps = 1e-05
momentum = 0.1
self.dense_blocks = nn.ModuleList()
for i in range(0, num_block):
self.dense_blocks.append(
nn.Sequential(
nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + i * num_grow_ch,
num_feat + i * num_grow_ch, (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + i * num_grow_ch,
num_grow_ch, (3, 3, 3),
stride=(1, 1, 1),
padding=(1, 1, 1),
bias=True)))
def forward(self, x):
"""
Args:
x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
Returns:
Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
"""
for i in range(0, len(self.dense_blocks)):
y = self.dense_blocks[i](x)
x = torch.cat((x, y), 1)
return x
class DynamicUpsamplingFilter(nn.Module):
"""Dynamic upsampling filter used in DUF.
Reference: https://github.com/yhjo09/VSR-DUF
It only supports inputs with 3 channels, and it applies the same filters to all 3 channels.
Args:
filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
"""
def __init__(self, filter_size=(5, 5)):
super(DynamicUpsamplingFilter, self).__init__()
if not isinstance(filter_size, tuple):
raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
if len(filter_size) != 2:
raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
# generate a local expansion filter, similar to im2col
self.filter_size = filter_size
filter_prod = np.prod(filter_size)
expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
def forward(self, x, filters):
"""Forward function for DynamicUpsamplingFilter.
Args:
x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
e.g., for x4 upsampling, upsampling_square = 4*4 = 16.
Returns:
Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
"""
n, filter_prod, upsampling_square, h, w = filters.size()
kh, kw = self.filter_size
expanded_input = F.conv2d(
x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
2) # (n, h, w, 3, filter_prod)
filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square)
out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
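# Illustrative usage sketch (not part of the original file): dynamic filtering applies a per-pixel
# 5x5 filter to each of the 3 input channels, giving 3 * upsampling_square output channels that a
# later pixel shuffle rearranges. The shapes below assume x4 upsampling (upsampling_square = 16).
def _demo_dynamic_upsampling_filter():
    duf_filter = DynamicUpsamplingFilter(filter_size=(5, 5))
    x = torch.rand(2, 3, 16, 16)  # (n, 3, h, w)
    filters = torch.rand(2, 25, 16, 16, 16)  # (n, filter_prod, upsampling_square, h, w)
    out = duf_filter(x, filters)
    print(out.shape)  # expected: (2, 48, 16, 16)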
@ARCH_REGISTRY.register()
class DUF(nn.Module):
"""Network architecture for DUF
``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
Reference: https://github.com/yhjo09/VSR-DUF
For all the models below, 'adapt_official_weights' is only necessary when
loading the weights converted from the official TensorFlow weights.
Please set it to False if you are training the model from scratch.
There are three models with different model sizes: DUF16Layers, DUF28Layers,
and DUF52Layers. This class is the base class for these models.
Args:
scale (int): The upsampling factor. Default: 4.
num_layer (int): The number of layers. Default: 52.
adapt_official_weights (bool): Whether to adapt the weights
translated from the official implementation. Set to false if you
want to train from scratch. Default: False.
"""
def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
super(DUF, self).__init__()
self.scale = scale
if adapt_official_weights:
eps = 1e-3
momentum = 1e-3
else: # pytorch default values
eps = 1e-05
momentum = 0.1
self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
if num_layer == 16:
num_block = 3
num_grow_ch = 32
elif num_layer == 28:
num_block = 9
num_grow_ch = 16
elif num_layer == 52:
num_block = 21
num_grow_ch = 16
else:
raise ValueError(f'Only supports (16, 28, 52) layers, but got {num_layer}.')
self.dense_block1 = DenseBlocks(
num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
adapt_official_weights=adapt_official_weights) # T = 7
self.dense_block2 = DenseBlocksTemporalReduce(
64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.conv3d_f2 = nn.Conv3d(
512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
def forward(self, x):
"""
Args:
x (Tensor): Input with shape (b, 7, c, h, w)
Returns:
Tensor: Output with shape (b, c, h * scale, w * scale)
"""
num_batches, num_imgs, _, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
x_center = x[:, :, num_imgs // 2, :, :]
x = self.conv3d1(x)
x = self.dense_block1(x)
x = self.dense_block2(x)
x = F.relu(self.bn3d2(x), inplace=True)
x = F.relu(self.conv3d2(x), inplace=True)
# residual image
res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
# filter
filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
# dynamic filter
out = self.dynamic_filter(x_center, filter_)
out += res.squeeze_(2)
out = F.pixel_shuffle(out, self.scale)
return out
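# Illustrative usage sketch (not part of the original file): an end-to-end DUF forward pass on a
# tiny 7-frame clip, using the 16-layer variant only to keep the demo light.
def _demo_duf():
    model = DUF(scale=4, num_layer=16).eval()
    x = torch.rand(1, 7, 3, 16, 16)  # (b, t=7, c, h, w)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: (1, 3, 64, 64)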
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
class SeqConv3x3(nn.Module):
"""The re-parameterizable block used in the ECBSR architecture.
``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
Reference: https://github.com/xindongzhang/ECBSR
Args:
seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
in_channels (int): Channel number of input.
out_channels (int): Channel number of output.
depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
"""
def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
super(SeqConv3x3, self).__init__()
self.seq_type = seq_type
self.in_channels = in_channels
self.out_channels = out_channels
if self.seq_type == 'conv1x1-conv3x3':
self.mid_planes = int(out_channels * depth_multiplier)
conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
self.k1 = conv1.weight
self.b1 = conv1.bias
elif self.seq_type == 'conv1x1-sobelx':
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale and bias
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(scale)
bias = torch.randn(self.out_channels) * 1e-3
bias = torch.reshape(bias, (self.out_channels, ))
self.bias = nn.Parameter(bias)
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_channels):
self.mask[i, 0, 0, 0] = 1.0
self.mask[i, 0, 1, 0] = 2.0
self.mask[i, 0, 2, 0] = 1.0
self.mask[i, 0, 0, 2] = -1.0
self.mask[i, 0, 1, 2] = -2.0
self.mask[i, 0, 2, 2] = -1.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
elif self.seq_type == 'conv1x1-sobely':
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale and bias
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
bias = torch.randn(self.out_channels) * 1e-3
bias = torch.reshape(bias, (self.out_channels, ))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_channels):
self.mask[i, 0, 0, 0] = 1.0
self.mask[i, 0, 0, 1] = 2.0
self.mask[i, 0, 0, 2] = 1.0
self.mask[i, 0, 2, 0] = -1.0
self.mask[i, 0, 2, 1] = -2.0
self.mask[i, 0, 2, 2] = -1.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
elif self.seq_type == 'conv1x1-laplacian':
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
self.k0 = conv0.weight
self.b0 = conv0.bias
# init scale and bias
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
bias = torch.randn(self.out_channels) * 1e-3
bias = torch.reshape(bias, (self.out_channels, ))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
for i in range(self.out_channels):
self.mask[i, 0, 0, 1] = 1.0
self.mask[i, 0, 1, 0] = 1.0
self.mask[i, 0, 1, 2] = 1.0
self.mask[i, 0, 2, 1] = 1.0
self.mask[i, 0, 1, 1] = -4.0
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
else:
raise ValueError('The type of seqconv is not supported!')
def forward(self, x):
if self.seq_type == 'conv1x1-conv3x3':
# conv-1x1
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
# explicitly padding with bias
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
b0_pad = self.b0.view(1, -1, 1, 1)
y0[:, :, 0:1, :] = b0_pad
y0[:, :, -1:, :] = b0_pad
y0[:, :, :, 0:1] = b0_pad
y0[:, :, :, -1:] = b0_pad
# conv-3x3
y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
else:
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
# explicitly padding with bias
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
b0_pad = self.b0.view(1, -1, 1, 1)
y0[:, :, 0:1, :] = b0_pad
y0[:, :, -1:, :] = b0_pad
y0[:, :, :, 0:1] = b0_pad
y0[:, :, :, -1:] = b0_pad
# conv-3x3
y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
return y1
def rep_params(self):
device = self.k0.get_device()
if device < 0:
device = None
if self.seq_type == 'conv1x1-conv3x3':
# re-param conv kernel
rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
else:
tmp = self.scale * self.mask
k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
for i in range(self.out_channels):
k1[i, i, :, :] = tmp[i, 0, :, :]
b1 = self.bias
# re-param conv kernel
rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
return rep_weight, rep_bias
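# Illustrative check (not part of the original file): the single 3x3 kernel and bias returned by
# rep_params() are expected to reproduce the sequential forward pass up to floating-point error
# when applied with zero padding of 1 (matching the explicit bias padding in forward()).
def _demo_seqconv_reparam():
    block = SeqConv3x3('conv1x1-conv3x3', in_channels=8, out_channels=8, depth_multiplier=2)
    x = torch.rand(1, 8, 12, 12)
    with torch.no_grad():
        y_seq = block(x)
        rep_weight, rep_bias = block.rep_params()
        y_rep = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
    print(torch.allclose(y_seq, y_rep, atol=1e-5))  # expected: True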
class ECB(nn.Module):
"""The ECB block used in the ECBSR architecture.
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
Ref git repo: https://github.com/xindongzhang/ECBSR
Args:
in_channels (int): Channel number of input.
out_channels (int): Channel number of output.
depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
with_idt (bool): Whether to use identity connection. Default: False.
"""
def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
super(ECB, self).__init__()
self.depth_multiplier = depth_multiplier
self.in_channels = in_channels
self.out_channels = out_channels
self.act_type = act_type
if with_idt and (self.in_channels == self.out_channels):
self.with_idt = True
else:
self.with_idt = False
self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
if self.act_type == 'prelu':
self.act = nn.PReLU(num_parameters=self.out_channels)
elif self.act_type == 'relu':
self.act = nn.ReLU(inplace=True)
elif self.act_type == 'rrelu':
self.act = nn.RReLU(lower=-0.05, upper=0.05)
elif self.act_type == 'softplus':
self.act = nn.Softplus()
elif self.act_type == 'linear':
pass
else:
raise ValueError('The type of activation is not supported!')
def forward(self, x):
if self.training:
y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
if self.with_idt:
y += x
else:
rep_weight, rep_bias = self.rep_params()
y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
if self.act_type != 'linear':
y = self.act(y)
return y
def rep_params(self):
weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
weight1, bias1 = self.conv1x1_3x3.rep_params()
weight2, bias2 = self.conv1x1_sbx.rep_params()
weight3, bias3 = self.conv1x1_sby.rep_params()
weight4, bias4 = self.conv1x1_lpl.rep_params()
rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
bias0 + bias1 + bias2 + bias3 + bias4)
if self.with_idt:
device = rep_weight.get_device()
if device < 0:
device = None
weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
for i in range(self.out_channels):
weight_idt[i, i, 1, 1] = 1.0
bias_idt = 0.0
rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
return rep_weight, rep_bias
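# Illustrative check (not part of the original file): in eval mode the ECB collapses all branches
# (and the optional identity) into one 3x3 convolution via rep_params(), so training-mode and
# eval-mode outputs should agree up to floating-point error.
def _demo_ecb_reparam():
    ecb = ECB(in_channels=8, out_channels=8, depth_multiplier=2, act_type='linear', with_idt=True)
    x = torch.rand(1, 8, 12, 12)
    with torch.no_grad():
        y_train = ecb.train()(x)
        y_eval = ecb.eval()(x)
    print(torch.allclose(y_train, y_eval, atol=1e-5))  # expected: True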
@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
"""ECBSR architecture.
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
Ref git repo: https://github.com/xindongzhang/ECBSR
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_block (int): Block number in the trunk network.
num_channel (int): Channel number.
with_idt (bool): Whether to use an identity connection in the convolution layers.
act_type (str): Activation type.
scale (int): Upsampling factor.
"""
def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
super(ECBSR, self).__init__()
self.num_in_ch = num_in_ch
self.scale = scale
backbone = []
backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
for _ in range(num_block):
backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
backbone += [
ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
]
self.backbone = nn.Sequential(*backbone)
self.upsampler = nn.PixelShuffle(scale)
def forward(self, x):
if self.num_in_ch > 1:
shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
else:
shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
y = self.backbone(x) + shortcut
y = self.upsampler(y)
return y
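# Illustrative usage sketch (not part of the original file): a tiny ECBSR configuration on a
# single-channel input; with num_in_ch == 1 the shortcut is broadcast over the scale * scale
# output channels before the pixel shuffle.
def _demo_ecbsr():
    model = ECBSR(num_in_ch=1, num_out_ch=1, num_block=4, num_channel=16, with_idt=True, act_type='prelu', scale=2)
    x = torch.rand(1, 1, 32, 32)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: (1, 1, 64, 64)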
import torch
from torch import nn as nn
from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class EDSR(nn.Module):
"""EDSR network structure.
Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64.
num_block (int): Block number in the trunk network. Default: 16.
upscale (int): Upsampling factor. Support 2^n and 3.
Default: 4.
res_scale (float): Used to scale the residual in residual block.
Default: 1.
img_range (float): Image range. Default: 255.
rgb_mean (tuple[float]): Image mean in RGB orders.
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
"""
def __init__(self,
num_in_ch,
num_out_ch,
num_feat=64,
num_block=16,
upscale=4,
res_scale=1,
img_range=255.,
rgb_mean=(0.4488, 0.4371, 0.4040)):
super(EDSR, self).__init__()
self.img_range = img_range
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
x = self.conv_first(x)
res = self.conv_after_body(self.body(x))
res += x
x = self.conv_last(self.upsample(res))
x = x / self.img_range + self.mean
return x
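# Illustrative usage sketch (not part of the original file): a small EDSR configuration; inputs
# are assumed to be in [0, 1] since the network applies its own mean/img_range normalization.
def _demo_edsr():
    model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=16, num_block=2, upscale=2)
    x = torch.rand(1, 3, 24, 24)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: (1, 3, 48, 48)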
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
class PCDAlignment(nn.Module):
"""Alignment module using Pyramid, Cascading and Deformable convolution
(PCD). It is used in EDVR.
``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
Args:
num_feat (int): Channel number of middle features. Default: 64.
deformable_groups (int): Deformable groups. Defaults: 8.
"""
def __init__(self, num_feat=64, deformable_groups=8):
super(PCDAlignment, self).__init__()
# Pyramid has three levels:
# L3: level 3, 1/4 spatial size
# L2: level 2, 1/2 spatial size
# L1: level 1, original spatial size
self.offset_conv1 = nn.ModuleDict()
self.offset_conv2 = nn.ModuleDict()
self.offset_conv3 = nn.ModuleDict()
self.dcn_pack = nn.ModuleDict()
self.feat_conv = nn.ModuleDict()
# Pyramids
for i in range(3, 0, -1):
level = f'l{i}'
self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
if i == 3:
self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
else:
self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
if i < 3:
self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
# Cascading dcn
self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, nbr_feat_l, ref_feat_l):
"""Align neighboring frame features to the reference frame features.
Args:
nbr_feat_l (list[Tensor]): Neighboring feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
ref_feat_l (list[Tensor]): Reference feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
Returns:
Tensor: Aligned features.
"""
# Pyramids
upsampled_offset, upsampled_feat = None, None
for i in range(3, 0, -1):
level = f'l{i}'
offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
offset = self.lrelu(self.offset_conv1[level](offset))
if i == 3:
offset = self.lrelu(self.offset_conv2[level](offset))
else:
offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
offset = self.lrelu(self.offset_conv3[level](offset))
feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
if i < 3:
feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
if i > 1:
feat = self.lrelu(feat)
if i > 1: # upsample offset and features
# x2: when we upsample the offset, we should also enlarge
# the magnitude.
upsampled_offset = self.upsample(offset) * 2
upsampled_feat = self.upsample(feat)
# Cascading
offset = torch.cat([feat, ref_feat_l[0]], dim=1)
offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
feat = self.lrelu(self.cas_dcnpack(feat, offset))
return feat
class TSAFusion(nn.Module):
"""Temporal Spatial Attention (TSA) fusion module.
Temporal: Calculate the correlation between center frame and
neighboring frames;
Spatial: It has 3 pyramid levels, the attention is similar to SFT.
(SFT: Recovering realistic texture in image super-resolution by deep
spatial feature transform.)
Args:
num_feat (int): Channel number of middle features. Default: 64.
num_frame (int): Number of frames. Default: 5.
center_frame_idx (int): The index of center frame. Default: 2.
"""
def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
super(TSAFusion, self).__init__()
self.center_frame_idx = center_frame_idx
# temporal attention (before fusion conv)
self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
# spatial attention (after fusion conv)
self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
def forward(self, aligned_feat):
"""
Args:
aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
Returns:
Tensor: Features after TSA with the shape (b, c, h, w).
"""
b, t, c, h, w = aligned_feat.size()
# temporal attention
embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
corr_l = [] # correlation list
for i in range(t):
emb_neighbor = embedding[:, i, :, :, :]
corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
# fusion
feat = self.lrelu(self.feat_fusion(aligned_feat))
# spatial attention
attn = self.lrelu(self.spatial_attn1(aligned_feat))
attn_max = self.max_pool(attn)
attn_avg = self.avg_pool(attn)
attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
# pyramid levels
attn_level = self.lrelu(self.spatial_attn_l1(attn))
attn_max = self.max_pool(attn_level)
attn_avg = self.avg_pool(attn_level)
attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
attn_level = self.upsample(attn_level)
attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
attn = self.lrelu(self.spatial_attn4(attn))
attn = self.upsample(attn)
attn = self.spatial_attn5(attn)
attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
attn = torch.sigmoid(attn)
# after initialization, * 2 makes (attn * 2) to be close to 1.
feat = feat * attn * 2 + attn_add
return feat
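# Illustrative usage sketch (not part of the original file): TSA fusion collapses the temporal
# dimension of the aligned features into a single feature map; spatial sizes should be divisible
# by 4 because of the two pooling/upsampling levels in the spatial attention.
def _demo_tsa_fusion():
    fusion = TSAFusion(num_feat=16, num_frame=5, center_frame_idx=2)
    aligned_feat = torch.rand(1, 5, 16, 32, 32)  # (b, t, c, h, w)
    with torch.no_grad():
        out = fusion(aligned_feat)
    print(out.shape)  # expected: (1, 16, 32, 32)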
class PredeblurModule(nn.Module):
"""Pre-dublur module.
Args:
num_in_ch (int): Channel number of input image. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
hr_in (bool): Whether the input has high resolution. Default: False.
"""
def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
super(PredeblurModule, self).__init__()
self.hr_in = hr_in
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
if self.hr_in:
# downsample x4 by stride conv
self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
# generate feature pyramid
self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
feat_l1 = self.lrelu(self.conv_first(x))
if self.hr_in:
feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
# generate feature pyramid
feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
feat_l3 = self.upsample(self.resblock_l3(feat_l3))
feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
for i in range(2):
feat_l1 = self.resblock_l1[i](feat_l1)
feat_l1 = feat_l1 + feat_l2
for i in range(2, 5):
feat_l1 = self.resblock_l1[i](feat_l1)
return feat_l1
@ARCH_REGISTRY.register()
class EDVR(nn.Module):
"""EDVR network structure for video super-resolution.
Currently only supports x4 upsampling.
``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
Args:
num_in_ch (int): Channel number of input image. Default: 3.
num_out_ch (int): Channel number of output image. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_frame (int): Number of input frames. Default: 5.
deformable_groups (int): Deformable groups. Defaults: 8.
num_extract_block (int): Number of blocks for feature extraction.
Default: 5.
num_reconstruct_block (int): Number of blocks for reconstruction.
Default: 10.
center_frame_idx (int): The index of center frame. Frame counting from
0. Default: Middle of input frames.
hr_in (bool): Whether the input has high resolution. Default: False.
with_predeblur (bool): Whether to use the predeblur module.
Default: False.
with_tsa (bool): Whether to use the TSA module. Default: True.
"""
def __init__(self,
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_frame=5,
deformable_groups=8,
num_extract_block=5,
num_reconstruct_block=10,
center_frame_idx=None,
hr_in=False,
with_predeblur=False,
with_tsa=True):
super(EDVR, self).__init__()
if center_frame_idx is None:
self.center_frame_idx = num_frame // 2
else:
self.center_frame_idx = center_frame_idx
self.hr_in = hr_in
self.with_predeblur = with_predeblur
self.with_tsa = with_tsa
# extract features for each frame
if self.with_predeblur:
self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
else:
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
# extract pyramid features
self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# pcd and tsa module
self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
if self.with_tsa:
self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
else:
self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
# reconstruction
self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
# upsample
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
b, t, c, h, w = x.size()
if self.hr_in:
assert h % 16 == 0 and w % 16 == 0, ('The height and width must be a multiple of 16.')
else:
assert h % 4 == 0 and w % 4 == 0, ('The height and width must be a multiple of 4.')
x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
# extract features for each frame
# L1
if self.with_predeblur:
feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
if self.hr_in:
h, w = h // 4, w // 4
else:
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
feat_l1 = self.feature_extraction(feat_l1)
# L2
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
# L3
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
feat_l1 = feat_l1.view(b, t, -1, h, w)
feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
# PCD alignment
ref_feat_l = [ # reference feature list
feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
feat_l3[:, self.center_frame_idx, :, :, :].clone()
]
aligned_feat = []
for i in range(t):
nbr_feat_l = [ # neighboring feature list
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
if not self.with_tsa:
aligned_feat = aligned_feat.view(b, -1, h, w)
feat = self.fusion(aligned_feat)
out = self.reconstruction(feat)
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
if self.hr_in:
base = x_center
else:
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
out += base
return out
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
class SPADEGenerator(BaseNetwork):
"""Generator with SPADEResBlock"""
def __init__(self,
num_in_ch=3,
num_feat=64,
use_vae=False,
z_dim=256,
crop_size=512,
norm_g='spectralspadesyncbatch3x3',
is_train=True,
init_train_phase=3): # progressive training disabled
super().__init__()
self.nf = num_feat
self.input_nc = num_in_ch
self.is_train = is_train
self.train_phase = init_train_phase
self.scale_ratio = 5 # hardcoded now
self.sw = crop_size // (2**self.scale_ratio)
self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
if use_vae:
# In case of VAE, we will sample from random z vector
self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
else:
# Otherwise, we make the network deterministic by starting with
# downsampled segmentation map instead of random z
self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
self.ups = nn.ModuleList([
SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
])
self.to_rgbs = nn.ModuleList([
nn.Conv2d(8 * self.nf, 3, 3, padding=1),
nn.Conv2d(4 * self.nf, 3, 3, padding=1),
nn.Conv2d(2 * self.nf, 3, 3, padding=1),
nn.Conv2d(1 * self.nf, 3, 3, padding=1)
])
self.up = nn.Upsample(scale_factor=2)
def encode(self, input_tensor):
"""
Encode input_tensor into feature maps, can be overridden in derived classes
Default: nearest downsampling of 2**5 = 32 times
"""
h, w = input_tensor.size()[-2:]
sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
x = F.interpolate(input_tensor, size=(sh, sw))
return self.fc(x)
def forward(self, x):
# In the original SPADE, seg is a segmentation map, but here we use x instead.
seg = x
x = self.encode(x)
x = self.head_0(x, seg)
x = self.up(x)
x = self.g_middle_0(x, seg)
x = self.g_middle_1(x, seg)
if self.is_train:
phase = self.train_phase + 1
else:
phase = len(self.to_rgbs)
for i in range(phase):
x = self.up(x)
x = self.ups[i](x, seg)
x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x
def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
"""
A helper class for subspace visualization. Input and seg are different images.
For the first n levels (including encoder) we use input, for the rest we use seg.
If mode = 'progressive', the output's like: AAABBB
If mode = 'one_plug', the output's like: AAABAA
If mode = 'one_ablate', the output's like: BBBABB
"""
if seg is None:
return self.forward(input_x)
if self.is_train:
phase = self.train_phase + 1
else:
phase = len(self.to_rgbs)
if mode == 'progressive':
n = max(min(n, 4 + phase), 0)
guide_list = [input_x] * n + [seg] * (4 + phase - n)
elif mode == 'one_plug':
n = max(min(n, 4 + phase - 1), 0)
guide_list = [seg] * (4 + phase)
guide_list[n] = input_x
elif mode == 'one_ablate':
if n > 3 + phase:
return self.forward(input_x)
guide_list = [input_x] * (4 + phase)
guide_list[n] = seg
x = self.encode(guide_list[0])
x = self.head_0(x, guide_list[1])
x = self.up(x)
x = self.g_middle_0(x, guide_list[2])
x = self.g_middle_1(x, guide_list[3])
for i in range(phase):
x = self.up(x)
x = self.ups[i](x, guide_list[4 + i])
x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x
@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
"""
HiFaceGAN: SPADEGenerator with a learnable feature encoder
Current encoder design: LIPEncoder
"""
def __init__(self,
num_in_ch=3,
num_feat=64,
use_vae=False,
z_dim=256,
crop_size=512,
norm_g='spectralspadesyncbatch3x3',
is_train=True,
init_train_phase=3):
super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
def encode(self, input_tensor):
return self.lip_encoder(input_tensor)
@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
"""
Inspired by pix2pixHD multiscale discriminator.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
conditional_d (bool): Whether to use a conditional discriminator.
Default: True.
num_d (int): Number of multiscale discriminators. Default: 2.
n_layers_d (int): Number of downsample layers in each D. Default: 4.
num_feat (int): Channel number of base intermediate features.
Default: 64.
norm_d (str): String to determine normalization layers in D.
Choices: [spectral][instance/batch/syncbatch]
Default: 'spectralinstance'.
keep_features (bool): Keep intermediate features for matching loss, etc.
Default: True.
"""
def __init__(self,
num_in_ch=3,
num_out_ch=3,
conditional_d=True,
num_d=2,
n_layers_d=4,
num_feat=64,
norm_d='spectralinstance',
keep_features=True):
super().__init__()
self.num_d = num_d
input_nc = num_in_ch
if conditional_d:
input_nc += num_out_ch
for i in range(num_d):
subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
self.add_module(f'discriminator_{i}', subnet_d)
def downsample(self, x):
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
# Returns a list of lists of discriminator outputs.
# The final result is of size num_d x n_layers_d.
def forward(self, x):
result = []
for _, _net_d in self.named_children():
out = _net_d(x)
result.append(out)
x = self.downsample(x)
return result
class NLayerDiscriminator(BaseNetwork):
"""Defines the PatchGAN discriminator with the specified arguments."""
def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
super().__init__()
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
nf = num_feat
self.keep_features = keep_features
norm_layer = get_nonspade_norm_layer(norm_d)
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
for n in range(1, n_layers_d):
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == n_layers_d - 1 else 2
sequence += [[
norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
nn.LeakyReLU(0.2, False)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
# We divide the layers into groups to extract intermediate layer outputs
for n in range(len(sequence)):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
def forward(self, x):
results = [x]
for submodel in self.children():
intermediate_output = submodel(results[-1])
results.append(intermediate_output)
if self.keep_features:
return results[1:]
else:
return results[-1]
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
from torch.nn.utils import spectral_norm
class SPADE(nn.Module):
def __init__(self, config_text, norm_nc, label_nc):
super().__init__()
assert config_text.startswith('spade')
parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
param_free_norm_type = str(parsed.group(1))
ks = int(parsed.group(2))
if param_free_norm_type == 'instance':
self.param_free_norm = nn.InstanceNorm2d(norm_nc)
elif param_free_norm_type == 'syncbatch':
print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
self.param_free_norm = nn.InstanceNorm2d(norm_nc)
elif param_free_norm_type == 'batch':
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
else:
raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
# The dimension of the intermediate embedding space. Yes, hardcoded.
nhidden = 128 if norm_nc > 128 else norm_nc
pw = ks // 2
self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
def forward(self, x, segmap):
# Part 1. generate parameter-free normalized activations
normalized = self.param_free_norm(x)
# Part 2. produce scaling and bias conditioned on semantic map
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
# apply scale and bias
out = normalized * gamma + beta
return out
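# Illustrative usage sketch (not part of the original file): SPADE normalizes the activations with
# a parameter-free norm, then modulates them with per-pixel gamma/beta predicted from the (resized)
# conditioning map. In HiFaceGAN the conditioning map is the input image itself, so label_nc is 3.
def _demo_spade():
    spade = SPADE('spadeinstance3x3', norm_nc=32, label_nc=3)
    x = torch.rand(1, 32, 16, 16)
    segmap = torch.rand(1, 3, 64, 64)  # resized internally to x's spatial size
    out = spade(x, segmap)
    print(out.shape)  # expected: (1, 32, 16, 16)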
class SPADEResnetBlock(nn.Module):
"""
ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
it takes in the segmentation map as input, learns the skip connection if necessary,
and applies normalization first and then convolution.
This architecture resembles a standard residual-block design for unconditional or
class-conditional GANs.
The code was inspired by https://github.com/LMescheder/GAN_stability.
"""
def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if 'spectral' in norm_g:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = norm_g.replace('spectral', '')
self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg):
x_s = self.shortcut(x, seg)
dx = self.conv_0(self.act(self.norm_0(x, seg)))
dx = self.conv_1(self.act(self.norm_1(dx, seg)))
out = x_s + dx
return out
def shortcut(self, x, seg):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg))
else:
x_s = x
return x_s
def act(self, x):
return F.leaky_relu(x, 2e-1)
class BaseNetwork(nn.Module):
""" A basis for hifacegan archs with custom initialization """
def init_weights(self, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if classname.find('BatchNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init.normal_(m.weight.data, 1.0, gain)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
def forward(self, x):
pass
def lip2d(x, logit, kernel=3, stride=2, padding=1):
weight = logit.exp()
return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
class SoftGate(nn.Module):
COEFF = 12.0
def forward(self, x):
return torch.sigmoid(x).mul(self.COEFF)
class SimplifiedLIP(nn.Module):
def __init__(self, channels):
super(SimplifiedLIP, self).__init__()
self.logit = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
SoftGate())
def init_layer(self):
self.logit[0].weight.data.fill_(0.0)
def forward(self, x):
frac = lip2d(x, self.logit(x))
return frac
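# Illustrative usage sketch (not part of the original file): lip2d is a logit-weighted average
# pooling, and SimplifiedLIP predicts those logits from the input itself, downsampling by 2.
def _demo_simplified_lip():
    lip = SimplifiedLIP(channels=8)
    x = torch.rand(1, 8, 16, 16)
    out = lip(x)
    print(out.shape)  # expected: (1, 8, 8, 8)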
class LIPEncoder(BaseNetwork):
"""Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
super().__init__()
self.sw = sw
self.sh = sh
self.max_ratio = 16
# 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
kw = 3
pw = (kw - 1) // 2
model = [
nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
norm_layer(ngf),
nn.ReLU(),
]
cur_ratio = 1
for i in range(n_2xdown):
next_ratio = min(cur_ratio * 2, self.max_ratio)
model += [
SimplifiedLIP(ngf * cur_ratio),
nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
norm_layer(ngf * next_ratio),
]
cur_ratio = next_ratio
if i < n_2xdown - 1:
model += [nn.ReLU(inplace=True)]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
def get_nonspade_norm_layer(norm_type='instance'):
# helper function to get the number of output channels of the previous layer
def get_out_channel(layer):
if hasattr(layer, 'out_channels'):
return getattr(layer, 'out_channels')
return layer.weight.size(0)
# this function will be returned
def add_norm_layer(layer):
nonlocal norm_type
if norm_type.startswith('spectral'):
layer = spectral_norm(layer)
subnorm_type = norm_type[len('spectral'):]
if subnorm_type == 'none' or len(subnorm_type) == 0:
return layer
# remove bias in the previous layer, which is meaningless
# since it has no effect after normalization
if getattr(layer, 'bias', None) is not None:
delattr(layer, 'bias')
layer.register_parameter('bias', None)
if subnorm_type == 'batch':
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == 'sync_batch':
print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
# norm_layer = SynchronizedBatchNorm2d(
# get_out_channel(layer), affine=True)
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
elif subnorm_type == 'instance':
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
else:
raise ValueError(f'normalization layer {subnorm_type} is not recognized')
return nn.Sequential(layer, norm_layer)
print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
return add_norm_layer
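# Illustrative usage sketch (not part of the original file): the returned closure optionally wraps
# a layer with spectral norm, drops its now-redundant bias, and appends a parameter-free norm.
def _demo_nonspade_norm_layer():
    norm_layer = get_nonspade_norm_layer('spectralinstance')
    conv = norm_layer(nn.Conv2d(3, 16, kernel_size=3, padding=1))
    x = torch.rand(1, 3, 32, 32)
    print(conv(x).shape)  # expected: (1, 16, 32, 32); conv is nn.Sequential(Conv2d, InstanceNorm2d)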
# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
# For FID metric
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.model_zoo import load_url
from torchvision import models
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
class InceptionV3(nn.Module):
"""Pretrained InceptionV3 network returning feature maps"""
# Index of default block of inception to return,
# corresponds to output of final average pooling
DEFAULT_BLOCK_INDEX = 3
# Maps feature dimensionality to their output blocks indices
BLOCK_INDEX_BY_DIM = {
64: 0, # First max pooling features
192: 1, # Second max pooling features
768: 2, # Pre-aux classifier features
2048: 3 # Final average pooling features
}
def __init__(self,
output_blocks=(DEFAULT_BLOCK_INDEX, ),
resize_input=True,
normalize_input=True,
requires_grad=False,
use_fid_inception=True):
"""Build pretrained InceptionV3.
Args:
output_blocks (list[int]): Indices of blocks to return features of.
Possible values are:
- 0: corresponds to output of first max pooling
- 1: corresponds to output of second max pooling
- 2: corresponds to output which is fed to aux classifier
- 3: corresponds to output of final average pooling
resize_input (bool): If true, bilinearly resizes input to width and
height 299 before feeding input to model. As the network
without fully connected layers is fully convolutional, it
should be able to handle inputs of arbitrary size, so resizing
might not be strictly needed. Default: True.
normalize_input (bool): If true, scales the input from range (0, 1)
to the range the pretrained Inception network expects,
namely (-1, 1). Default: True.
requires_grad (bool): If true, parameters of the model require
gradients. Possibly useful for finetuning the network.
Default: False.
use_fid_inception (bool): If true, uses the pretrained Inception
model used in Tensorflow's FID implementation.
If false, uses the pretrained Inception model available in
torchvision. The FID Inception model has different weights
and a slightly different structure from torchvision's
Inception model. If you want to compute FID scores, you are
strongly advised to set this parameter to true to get
comparable results. Default: True.
"""
super(InceptionV3, self).__init__()
self.resize_input = resize_input
self.normalize_input = normalize_input
self.output_blocks = sorted(output_blocks)
self.last_needed_block = max(output_blocks)
assert self.last_needed_block <= 3, ('Last possible output block index is 3')
self.blocks = nn.ModuleList()
if use_fid_inception:
inception = fid_inception_v3()
else:
try:
inception = models.inception_v3(pretrained=True, init_weights=False)
except TypeError:
# pytorch < 1.5 does not have init_weights for inception_v3
inception = models.inception_v3(pretrained=True)
# Block 0: input to maxpool1
block0 = [
inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
nn.MaxPool2d(kernel_size=3, stride=2)
]
self.blocks.append(nn.Sequential(*block0))
# Block 1: maxpool1 to maxpool2
if self.last_needed_block >= 1:
block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
self.blocks.append(nn.Sequential(*block1))
# Block 2: maxpool2 to aux classifier
if self.last_needed_block >= 2:
block2 = [
inception.Mixed_5b,
inception.Mixed_5c,
inception.Mixed_5d,
inception.Mixed_6a,
inception.Mixed_6b,
inception.Mixed_6c,
inception.Mixed_6d,
inception.Mixed_6e,
]
self.blocks.append(nn.Sequential(*block2))
# Block 3: aux classifier to final avgpool
if self.last_needed_block >= 3:
block3 = [
inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
nn.AdaptiveAvgPool2d(output_size=(1, 1))
]
self.blocks.append(nn.Sequential(*block3))
for param in self.parameters():
param.requires_grad = requires_grad
def forward(self, x):
"""Get Inception feature maps.
Args:
x (Tensor): Input tensor of shape (b, 3, h, w).
Values are expected to be in the range (-1, 1); inputs in the range
(0, 1) are also accepted when normalize_input is set to True.
Returns:
list[Tensor]: Corresponding to the selected output block, sorted
ascending by index.
"""
output = []
if self.resize_input:
x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
if self.normalize_input:
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
for idx, block in enumerate(self.blocks):
x = block(x)
if idx in self.output_blocks:
output.append(x)
if idx == self.last_needed_block:
break
return output
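# Illustrative usage sketch (not part of the original file): extracting the 2048-dim pooled
# features commonly used for FID. Constructing the model downloads the pretrained FID Inception
# weights, so this helper is only meant to be run manually.
def _demo_inception_features():
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    model = InceptionV3(output_blocks=[block_idx]).eval()
    x = torch.rand(4, 3, 64, 64)  # images in (0, 1); resized to 299x299 internally
    with torch.no_grad():
        feat = model(x)[0]
    print(feat.shape)  # expected: (4, 2048, 1, 1)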
def fid_inception_v3():
"""Build pretrained Inception model for FID computation.
The Inception model for FID computation uses a different set of weights
and has a slightly different structure than torchvision's Inception.
This method first constructs torchvision's Inception and then patches the
necessary parts that are different in the FID Inception model.
"""
try:
inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
except TypeError:
# pytorch < 1.5 does not have init_weights for inception_v3
inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
inception.Mixed_7b = FIDInceptionE_1(1280)
inception.Mixed_7c = FIDInceptionE_2(2048)
if os.path.exists(LOCAL_FID_WEIGHTS):
state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
else:
state_dict = load_url(FID_WEIGHTS_URL, progress=True)
inception.load_state_dict(state_dict)
return inception
class FIDInceptionA(models.inception.InceptionA):
"""InceptionA block patched for FID computation"""
def __init__(self, in_channels, pool_features):
super(FIDInceptionA, self).__init__(in_channels, pool_features)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
# Patch: Tensorflow's average pool does not use the padded zeros in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionC(models.inception.InceptionC):
"""InceptionC block patched for FID computation"""
def __init__(self, in_channels, channels_7x7):
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)
branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
# Patch: Tensorflow's average pool does not use the padded zeros in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionE_1(models.inception.InceptionE):
"""First InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_1, self).__init__(in_channels)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
# Patch: Tensorflow's average pool does not use the padded zeros in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionE_2(models.inception.InceptionE):
"""Second InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_2, self).__init__(in_channels)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
# Patch: The FID Inception model uses max pooling instead of average
# pooling. This is likely an error in this specific Inception
# implementation, as other Inception models use average pooling here
# (which matches the description in the paper).
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
import torch
from torch import nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import Upsample, make_layer
class ChannelAttention(nn.Module):
"""Channel attention used in RCAN.
Args:
num_feat (int): Channel number of intermediate features.
squeeze_factor (int): Channel squeeze factor. Default: 16.
"""
def __init__(self, num_feat, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
def forward(self, x):
y = self.attention(x)
return x * y
class RCAB(nn.Module):
"""Residual Channel Attention Block (RCAB) used in RCAN.
Args:
num_feat (int): Channel number of intermediate features.
squeeze_factor (int): Channel squeeze factor. Default: 16.
res_scale (float): Scale the residual. Default: 1.
"""
def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
super(RCAB, self).__init__()
self.res_scale = res_scale
self.rcab = nn.Sequential(
nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
ChannelAttention(num_feat, squeeze_factor))
def forward(self, x):
res = self.rcab(x) * self.res_scale
return res + x
class ResidualGroup(nn.Module):
"""Residual Group of RCAB.
Args:
num_feat (int): Channel number of intermediate features.
num_block (int): Block number in the body network.
squeeze_factor (int): Channel squeeze factor. Default: 16.
res_scale (float): Scale the residual. Default: 1.
"""
def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
super(ResidualGroup, self).__init__()
self.residual_group = make_layer(
RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
def forward(self, x):
res = self.conv(self.residual_group(x))
return res + x
@ARCH_REGISTRY.register()
class RCAN(nn.Module):
"""Residual Channel Attention Networks.
``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
Reference: https://github.com/yulunzhang/RCAN
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64.
num_group (int): Number of ResidualGroup. Default: 10.
num_block (int): Number of RCAB in ResidualGroup. Default: 16.
squeeze_factor (int): Channel squeeze factor. Default: 16.
upscale (int): Upsampling factor. Supports 2^n and 3.
Default: 4.
res_scale (float): Used to scale the residual in residual block.
Default: 1.
img_range (float): Image range. Default: 255.
rgb_mean (tuple[float]): Image mean in RGB orders.
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
"""
def __init__(self,
num_in_ch,
num_out_ch,
num_feat=64,
num_group=10,
num_block=16,
squeeze_factor=16,
upscale=4,
res_scale=1,
img_range=255.,
rgb_mean=(0.4488, 0.4371, 0.4040)):
super(RCAN, self).__init__()
self.img_range = img_range
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(
ResidualGroup,
num_group,
num_feat=num_feat,
num_block=num_block,
squeeze_factor=squeeze_factor,
res_scale=res_scale)
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
x = self.conv_first(x)
res = self.conv_after_body(self.body(x))
res += x
x = self.conv_last(self.upsample(res))
x = x / self.img_range + self.mean
return x
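# A minimal usage sketch (not part of the original file): a small 2x RCAN built
# with made-up hyper-parameters, checking that an (n, 3, h, w) input comes out
# as (n, 3, 2h, 2w).
def _rcan_shape_check():
    import torch
    model = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=2, num_block=2, upscale=2)
    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        out = model(x)
    assert out.shape == (1, 3, 64, 64)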
import torch
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, make_layer
class MeanShift(nn.Conv2d):
""" Data normalization with mean and std.
Args:
rgb_range (int): Maximum value of RGB.
rgb_mean (list[float]): Mean for RGB channels.
rgb_std (list[float]): Std for RGB channels.
sign (int): For subtraction, sign is -1, for addition, sign is 1.
Default: -1.
requires_grad (bool): Whether to update the self.weight and self.bias.
Default: True.
"""
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1))
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
# Setting `self.requires_grad` on the module does not freeze the parameters;
# toggle requires_grad on each parameter instead.
for p in self.parameters():
    p.requires_grad = requires_grad
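# A minimal sketch (not part of the original file) checking that MeanShift with
# sign=-1 matches the per-channel formula (x - rgb_range * mean) / std.
# The values below are made up for the demo.
def _mean_shift_check():
    import torch
    rgb_range, rgb_mean, rgb_std = 255., (0.4488, 0.4371, 0.4040), (1.0, 1.0, 1.0)
    sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std, sign=-1)
    x = torch.rand(1, 3, 8, 8) * rgb_range
    mean = torch.tensor(rgb_mean).view(1, 3, 1, 1)
    std = torch.tensor(rgb_std).view(1, 3, 1, 1)
    expected = (x - rgb_range * mean) / std
    with torch.no_grad():
        assert torch.allclose(sub_mean(x), expected, atol=1e-3)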
class EResidualBlockNoBN(nn.Module):
"""Enhanced Residual block without BN.
There are three convolution layers in residual branch.
"""
def __init__(self, in_channels, out_channels):
super(EResidualBlockNoBN, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 1, 1, 0),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.body(x)
out = self.relu(out + x)
return out
class MergeRun(nn.Module):
""" Merge-and-run unit.
This unit contains two branches with different dilated convolutions,
followed by a convolution to process the concatenated features.
Paper: Real Image Denoising with Feature Attention
Ref git repo: https://github.com/saeed-anwar/RIDNet
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(MergeRun, self).__init__()
self.dilation1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
self.dilation2 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
self.aggregation = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
def forward(self, x):
dilation1 = self.dilation1(x)
dilation2 = self.dilation2(x)
out = torch.cat([dilation1, dilation2], dim=1)
out = self.aggregation(out)
out = out + x
return out
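# A minimal sketch (not part of the original file) verifying that both dilated
# branches keep the spatial size (padding equals dilation for the 3x3 kernels),
# so the concatenation and the residual addition `out + x` line up.
def _merge_run_shape_check():
    import torch
    block = MergeRun(in_channels=64, out_channels=64)
    x = torch.rand(1, 64, 24, 24)
    with torch.no_grad():
        out = block(x)
    assert out.shape == x.shape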
class ChannelAttention(nn.Module):
"""Channel attention.
Args:
mid_channels (int): Channel number of intermediate features.
squeeze_factor (int): Channel squeeze factor. Default: 16.
"""
def __init__(self, mid_channels, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
def forward(self, x):
y = self.attention(x)
return x * y
class EAM(nn.Module):
"""Enhancement attention modules (EAM) in RIDNet.
This module contains a merge-and-run unit, a residual block,
an enhanced residual block and a feature attention unit.
Attributes:
merge: The merge-and-run unit.
block1: The residual block.
block2: The enhanced residual block.
ca: The feature/channel attention unit.
"""
def __init__(self, in_channels, mid_channels, out_channels):
super(EAM, self).__init__()
self.merge = MergeRun(in_channels, mid_channels)
self.block1 = ResidualBlockNoBN(mid_channels)
self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
self.ca = ChannelAttention(out_channels)
# The residual block in the paper contains a relu after addition.
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.merge(x)
out = self.relu(self.block1(out))
out = self.block2(out)
out = self.ca(out)
return out
@ARCH_REGISTRY.register()
class RIDNet(nn.Module):
"""RIDNet: Real Image Denoising with Feature Attention.
Ref git repo: https://github.com/saeed-anwar/RIDNet
Args:
in_channels (int): Channel number of inputs.
mid_channels (int): Channel number of EAM modules.
Default: 64.
out_channels (int): Channel number of outputs.
num_block (int): Number of EAM. Default: 4.
img_range (float): Image range. Default: 255.
rgb_mean (tuple[float]): Image mean in RGB orders.
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
rgb_std (tuple[float]): Image std in RGB orders. Default: (1.0, 1.0, 1.0).
"""
def __init__(self,
in_channels,
mid_channels,
out_channels,
num_block=4,
img_range=255.,
rgb_mean=(0.4488, 0.4371, 0.4040),
rgb_std=(1.0, 1.0, 1.0)):
super(RIDNet, self).__init__()
self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
self.body = make_layer(
EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
res = self.sub_mean(x)
res = self.tail(self.body(self.relu(self.head(res))))
res = self.add_mean(res)
out = x + res
return out
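# A minimal usage sketch (not part of the original file): RIDNet predicts a
# residual that is added back to the noisy input, so the output keeps the input
# resolution. The block count and tensor sizes below are made up.
def _ridnet_demo():
    import torch
    model = RIDNet(in_channels=3, mid_channels=64, out_channels=3, num_block=2)
    noisy = torch.rand(1, 3, 48, 48)
    with torch.no_grad():
        denoised = model(noisy)
    assert denoised.shape == noisy.shape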
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Empirically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
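# A minimal sketch (not part of the original file): every conv in the block
# sees the input plus all previously produced features, so its input channels
# grow by num_grow_ch at each step, while the output keeps num_feat channels
# for the residual addition.
def _rdb_shape_check():
    import torch
    block = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
    x = torch.rand(1, 64, 16, 16)
    with torch.no_grad():
        out = block(x)
    assert out.shape == x.shape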
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
We first employ the pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
scale (int): Upsampling factor. Supports 1, 2 and 4. Default: 4.
num_feat (int): Channel number of intermediate features.
Default: 64.
num_block (int): Block number in the trunk network. Default: 23.
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
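# A minimal usage sketch (not part of the original file). For scale=2 the input
# is first pixel-unshuffled (channels x4, spatial /2) and then upsampled twice
# by a factor of 2, giving an overall 2x output. The sizes below are made up.
def _rrdbnet_demo():
    import torch
    model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2, num_feat=64, num_block=2)
    x = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        out = model(x)
    assert out.shape == (1, 3, 64, 64)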