Commit 2136e796 authored by mashun1's avatar mashun1
Browse files

codeformer

parents
Pipeline #699 canceled with stages
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY
__all__ = ['build_network']
# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
def build_network(opt):
opt = deepcopy(opt)
network_type = opt.pop('type')
net = ARCH_REGISTRY.get(network_type)(**opt)
logger = get_root_logger()
logger.info(f'Network [{net.__class__.__name__}] is created.')
return net
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class IRBlock(nn.Module):
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
self.bn0 = nn.BatchNorm2d(inplanes)
self.conv1 = conv3x3(inplanes, inplanes)
self.bn1 = nn.BatchNorm2d(inplanes)
self.prelu = nn.PReLU()
self.conv2 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.use_se = use_se
if self.use_se:
self.se = SEBlock(planes)
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.prelu(out)
return out
class Bottleneck(nn.Module):
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SEBlock(nn.Module):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
reduction (int): Channel reduction ration. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True):
if block == 'IRBlock':
block = IRBlock
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn4 = nn.BatchNorm2d(512)
self.dropout = nn.Dropout()
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
self.inplanes = planes
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn4(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc5(x)
x = self.bn5(x)
return x
\ No newline at end of file
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
class ResidualBlockNoBN(nn.Module):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
num_feat (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
"""
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.relu = nn.ReLU(inplace=True)
if not pytorch_init:
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
"""Warp an image or feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
padding_mode (str): 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Before pytorch 1.3, the default value is
align_corners=True. After pytorch 1.3, the default value is
align_corners=False. Here, we use the True as default.
Returns:
Tensor: Warped image or feature map.
"""
assert x.size()[-2:] == flow.size()[1:3]
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
# TODO, what if align_corners=False
return output
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
"""Resize a flow according to ratio or shape.
Args:
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
size_type (str): 'ratio' or 'shape'.
sizes (list[int | float]): the ratio for resizing or the final output
shape.
1) The order of ratio should be [ratio_h, ratio_w]. For
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
ratio > 1.0).
2) The order of output_size should be [out_h, out_w].
interp_mode (str): The mode of interpolation for resizing.
Default: 'bilinear'.
align_corners (bool): Whether align corners. Default: False.
Returns:
Tensor: Resized flow.
"""
_, _, flow_h, flow_w = flow.size()
if size_type == 'ratio':
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
elif size_type == 'shape':
output_h, output_w = sizes[0], sizes[1]
else:
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
input_flow = flow.clone()
ratio_h = output_h / flow_h
ratio_w = output_w / flow_w
input_flow[:, 0, :, :] *= ratio_w
input_flow[:, 1, :, :] *= ratio_h
resized_flow = F.interpolate(
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
return resized_flow
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
class DCNv2Pack(ModulatedDeformConvPack):
"""Modulated deformable conv for deformable alignment.
Different from the official DCNv2Pack, which generates offsets and masks
from the preceding features, this DCNv2Pack takes another different
features to generate offsets and masks.
Ref:
Delving Deep into Deformable Alignment in Video Super-Resolution.
"""
def forward(self, x, feat):
out = self.conv_offset(feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_absmean = torch.mean(torch.abs(offset))
if offset_absmean > 50:
logger = get_root_logger()
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
self.dilation, mask)
else:
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
\ No newline at end of file
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List
from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator'], vqgan_path=None):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if vqgan_path is not None:
self.load_state_dict(
torch.load(vqgan_path, map_location='cpu')['params_ema'])
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>= 0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat
\ No newline at end of file
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Emperically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Emperically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
num_block (int): Block number in the trunk network. Defaults: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
\ No newline at end of file
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg
from basicsr.utils.registry import ARCH_REGISTRY
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
'vgg11': [
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
'pool5'
],
'vgg13': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
],
'vgg16': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
'pool5'
],
'vgg19': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
]
}
def insert_bn(names):
"""Insert bn layer after each conv.
Args:
names (list): The list of layer names.
Returns:
list: The list of layer names with bn layers.
"""
names_bn = []
for name in names:
names_bn.append(name)
if 'conv' in name:
position = name.replace('conv', '')
names_bn.append('bn' + position)
return names_bn
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
"""VGG network for feature extraction.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): Forward function returns the corresponding
features according to the layer_name_list.
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image. Importantly,
the input feature must in the range [0, 1]. Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
requires_grad (bool): If true, the parameters of VGG network will be
optimized. Default: False.
remove_pooling (bool): If true, the max pooling operations in VGG net
will be removed. Default: False.
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
def __init__(self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
requires_grad=False,
remove_pooling=False,
pooling_stride=2):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.range_norm = range_norm
self.names = NAMES[vgg_type.replace('_bn', '')]
if 'bn' in vgg_type:
self.names = insert_bn(self.names)
# only borrow layers that will be used to avoid unused params
max_idx = 0
for v in layer_name_list:
idx = self.names.index(v)
if idx > max_idx:
max_idx = idx
if os.path.exists(VGG_PRETRAIN_PATH):
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
vgg_net.load_state_dict(state_dict)
else:
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
features = vgg_net.features[:max_idx + 1]
modified_net = OrderedDict()
for k, v in zip(self.names, features):
if 'pool' in k:
# if remove_pooling is true, pooling operation will be removed
if remove_pooling:
continue
else:
# in some cases, we may want to change the default stride
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
else:
modified_net[k] = v
self.vgg_net = nn.Sequential(modified_net)
if not requires_grad:
self.vgg_net.eval()
for param in self.parameters():
param.requires_grad = False
else:
self.vgg_net.train()
for param in self.parameters():
param.requires_grad = True
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.range_norm:
x = (x + 1) / 2
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for key, layer in self.vgg_net._modules.items():
x = layer(x)
if key in self.layer_name_list:
output[key] = x.clone()
return output
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
# min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
# min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
return self.main(x)
\ No newline at end of file
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must constain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
import math
import torch
from torch.utils.data.sampler import Sampler
class EnlargedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for iteration-based training, for saving
time when restart the dataloader after each epoch
Args:
dataset (torch.utils.data.Dataset): Dataset used for sampling.
num_replicas (int | None): Number of processes participating in
the training. It is usually the world_size.
rank (int | None): Rank of the current process within num_replicas.
ratio (int): Enlarging ratio. Default: 1.
"""
def __init__(self, dataset, num_replicas, rank, ratio=1):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dataset_size = len(self.dataset)
indices = [v % dataset_size for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
import cv2
import math
import numpy as np
import torch
from os import path as osp
from PIL import Image, ImageDraw
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
"""Generate an index list for reading `num_frames` frames from a sequence
of images.
Args:
crt_idx (int): Current center index.
max_frame_num (int): Max number of the sequence of images (from 1).
num_frames (int): Reading num_frames frames.
padding (str): Padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
list[int]: A list of indices.
"""
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
max_frame_num = max_frame_num - 1 # start from 0
num_pad = num_frames // 2
indices = []
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
if i < 0:
if padding == 'replicate':
pad_idx = 0
elif padding == 'reflection':
pad_idx = -i
elif padding == 'reflection_circle':
pad_idx = crt_idx + num_pad - i
else:
pad_idx = num_frames + i
elif i > max_frame_num:
if padding == 'replicate':
pad_idx = max_frame_num
elif padding == 'reflection':
pad_idx = max_frame_num * 2 - i
elif padding == 'reflection_circle':
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
else:
pad_idx = i - num_frames
else:
pad_idx = i
indices.append(pad_idx)
return indices
def paired_paths_from_lmdb(folders, keys):
"""Generate paired paths from lmdb files.
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
lq.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records
1)image name (with extension),
2)image shape,
3)compression level, separated by a white space.
Example: `baboon.png (120,125,3) 1`
We use the image name without extension as the lmdb key.
Note that we use the same key for the corresponding lq and gt images.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
Note that this key is different from lmdb keys.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
f'formats. But received {input_key}: {input_folder}; '
f'{gt_key}: {gt_folder}')
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
gt_lmdb_keys = [line.split('.')[0] for line in fin]
if set(input_lmdb_keys) != set(gt_lmdb_keys):
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
else:
paths = []
for lmdb_key in sorted(input_lmdb_keys):
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
"""Generate paired paths from an meta information file.
Each line in the meta information file contains the image names and
image shape (usually for gt), separated by a white space.
Example of an meta information file:
```
0001_s001.png (480,480,3)
0001_s002.png (480,480,3)
```
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
with open(meta_info_file, 'r') as fin:
gt_names = [line.split(' ')[0] for line in fin]
paths = []
for gt_name in gt_names:
basename, ext = osp.splitext(osp.basename(gt_name))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
gt_path = osp.join(gt_folder, gt_name)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
gt_path = osp.join(gt_folder, gt_path)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
paths = list(scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths
def paths_from_lmdb(folder):
"""Generate paths from lmdb.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
if not folder.endswith('.lmdb'):
raise ValueError(f'Folder {folder}folder should in lmdb format.')
with open(osp.join(folder, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Args:
kernel_size (int): Kernel size. Default: 13.
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
Returns:
np.array: The Gaussian kernel.
"""
from scipy.ndimage import filters as filters
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter
return filters.gaussian_filter(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code.
Args:
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
kernel_size (int): Kernel size. Default: 13.
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
Default: 4.
Returns:
Tensor: DUF downsampled frames.
"""
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
squeeze_flag = False
if x.ndim == 4:
squeeze_flag = True
x = x.unsqueeze(0)
b, t, c, h, w = x.size()
x = x.view(-1, 1, h, w)
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(b, t, c, x.size(2), x.size(3))
if squeeze_flag:
x = x.squeeze(0)
return x
def brush_stroke_mask(img, color=(255,255,255)):
min_num_vertex = 8
max_num_vertex = 28
mean_angle = 2*math.pi / 5
angle_range = 2*math.pi / 12
# training large mask ratio (training setting)
min_width = 30
max_width = 70
# very large mask ratio (test setting and refine after 200k)
# min_width = 80
# max_width = 120
def generate_mask(H, W, img=None):
average_radius = math.sqrt(H*H+W*W) / 8
mask = Image.new('RGB', (W, H), 0)
if img is not None: mask = img # Image.fromarray(img)
for _ in range(np.random.randint(1, 4)):
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
angle_min = mean_angle - np.random.uniform(0, angle_range)
angle_max = mean_angle + np.random.uniform(0, angle_range)
angles = []
vertex = []
for i in range(num_vertex):
if i % 2 == 0:
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
else:
angles.append(np.random.uniform(angle_min, angle_max))
h, w = mask.size
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
for i in range(num_vertex):
r = np.clip(
np.random.normal(loc=average_radius, scale=average_radius//2),
0, 2*average_radius)
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
vertex.append((int(new_x), int(new_y)))
draw = ImageDraw.Draw(mask)
width = int(np.random.uniform(min_width, max_width))
draw.line(vertex, fill=color, width=width)
for v in vertex:
draw.ellipse((v[0] - width//2,
v[1] - width//2,
v[0] + width//2,
v[1] + width//2),
fill=color)
return mask
width, height = img.size
mask = generate_mask(height, width, img)
return mask
def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
"""Generate a random free form mask with configuration.
Args:
config: Config should have configuration including IMG_SHAPES,
VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
Returns:
tuple: (top, left, height, width)
Link:
https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
"""
height = shape[0]
width = shape[1]
mask = np.zeros((height, width), np.float32)
times = np.random.randint(times-5, times)
for i in range(times):
start_x = np.random.randint(width)
start_y = np.random.randint(height)
for j in range(1 + np.random.randint(5)):
angle = 0.01 + np.random.randint(max_angle)
if i % 2 == 0:
angle = 2 * 3.1415926 - angle
length = 10 + np.random.randint(max_len-20, max_len)
brush_w = 5 + np.random.randint(max_width-30, max_width)
end_x = (start_x + length * np.sin(angle)).astype(np.int32)
end_y = (start_y + length * np.cos(angle)).astype(np.int32)
cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
start_x, start_y = end_x, end_y
return mask.astype(np.float32)
\ No newline at end of file
import cv2
import math
import random
import numpy as np
import os.path as osp
from scipy.io import loadmat
from PIL import Image
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
adjust_hue, adjust_saturation, normalize)
from basicsr.data import gaussian_kernels as gaussian_kernels
from basicsr.data.transforms import augment
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQBlindDataset(data.Dataset):
def __init__(self, opt):
super(FFHQBlindDataset, self).__init__()
logger = get_root_logger()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.gt_size = opt.get('gt_size', 512)
self.in_size = opt.get('in_size', 512)
assert self.gt_size >= self.in_size, 'Wrong setting.'
self.mean = opt.get('mean', [0.5, 0.5, 0.5])
self.std = opt.get('std', [0.5, 0.5, 0.5])
self.component_path = opt.get('component_path', None)
self.latent_gt_path = opt.get('latent_gt_path', None)
if self.component_path is not None:
self.crop_components = True
self.components_dict = torch.load(self.component_path)
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
else:
self.crop_components = False
if self.latent_gt_path is not None:
self.load_latent_gt = True
self.latent_gt_dict = torch.load(self.latent_gt_path)
else:
self.load_latent_gt = False
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
self.paths = paths_from_folder(self.gt_folder)
# inpainting mask
self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
if self.gen_inpaint_mask:
logger.info(f'generate mask ...')
# self.mask_max_angle = opt.get('mask_max_angle', 10)
# self.mask_max_len = opt.get('mask_max_len', 150)
# self.mask_max_width = opt.get('mask_max_width', 50)
# self.mask_draw_times = opt.get('mask_draw_times', 10)
# # print
# logger.info(f'mask_max_angle: {self.mask_max_angle}')
# logger.info(f'mask_max_len: {self.mask_max_len}')
# logger.info(f'mask_max_width: {self.mask_max_width}')
# logger.info(f'mask_draw_times: {self.mask_draw_times}')
# perform corrupt
self.use_corrupt = opt.get('use_corrupt', True)
self.use_motion_kernel = False
# self.use_motion_kernel = opt.get('use_motion_kernel', True)
if self.use_motion_kernel:
self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
self.motion_kernels = torch.load(motion_kernel_path)
if self.use_corrupt and not self.gen_inpaint_mask:
# degradation configurations
self.blur_kernel_size = opt['blur_kernel_size']
self.blur_sigma = opt['blur_sigma']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.downsample_range = opt['downsample_range']
self.noise_range = opt['noise_range']
self.jpeg_range = opt['jpeg_range']
# print
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
# color jitter
self.color_jitter_prob = opt.get('color_jitter_prob', None)
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
# to gray
self.gray_prob = opt.get('gray_prob', 0.0)
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_locations(self, name, status):
components_bbox = self.components_dict[name]
if status[0]: # hflip
# exchange right and left eye
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
# modify the width coordinate
components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
locations_gt = {}
locations_in = {}
for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
mean = components_bbox[part][0:2]
half_len = components_bbox[part][2]
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
elif part == 'nose':
half_len *= self.nose_enlarge_ratio
elif part == 'mouth':
half_len *= self.mouth_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = torch.from_numpy(loc).float()
locations_gt[part] = loc
loc_in = loc/(self.gt_size//self.in_size)
locations_in[part] = loc_in
return locations_gt, locations_in
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
name = osp.basename(gt_path)[:-4]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
if self.load_latent_gt:
if status[0]:
latent_gt = self.latent_gt_dict['hflip'][name]
else:
latent_gt = self.latent_gt_dict['orig'][name]
if self.crop_components:
locations_gt, locations_in = self.get_component_locations(name, status)
# generate in image
img_in = img_gt
if self.use_corrupt and not self.gen_inpaint_mask:
# motion blur
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
m_i = random.randint(0,31)
k = self.motion_kernels[f'{m_i:02d}']
img_in = cv2.filter2D(img_in,-1,k)
# gaussian blur
kernel = gaussian_kernels.random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma,
self.blur_sigma,
[-math.pi, math.pi],
noise_range=None)
img_in = cv2.filter2D(img_in, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range is not None:
noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
img_in = img_in + noise
img_in = np.clip(img_in, 0, 1)
# jpeg
if self.jpeg_range is not None:
# jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
jpeg_p = np.random.randint(self.jpeg_range[0], self.jpeg_range[1])
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
_, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
# resize to in_size
img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
# if self.gen_inpaint_mask:
# inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
# max_angle = self.mask_max_angle, max_len = self.mask_max_len,
# max_width = self.mask_max_width, times = self.mask_draw_times)
# img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
# 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
# inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
if self.gen_inpaint_mask:
img_in = (img_in*255).astype('uint8')
img_in = brush_stroke_mask(Image.fromarray(img_in))
img_in = np.array(img_in) / 255.
# random color jitter (only for lq)
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
img_in = self.color_jitter(img_in, self.color_jitter_shift)
# random to gray (only for lq)
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
img_in = np.tile(img_in[:, :, None], [1, 1, 3])
# BGR to RGB, HWC to CHW, numpy to tensor
img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
# random color jitter (pytorch version) (only for lq)
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
brightness = self.opt.get('brightness', (0.5, 1.5))
contrast = self.opt.get('contrast', (0.5, 1.5))
saturation = self.opt.get('saturation', (0, 1.5))
hue = self.opt.get('hue', (-0.1, 0.1))
img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
# round and clip
img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
# Set vgg range_norm=True if use the normalization here
# normalize
normalize(img_in, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
if self.crop_components:
return_dict['locations_in'] = locations_in
return_dict['locations_gt'] = locations_gt
if self.load_latent_gt:
return_dict['latent_gt'] = latent_gt
# if self.gen_inpaint_mask:
# return_dict['inpaint_mask'] = inpaint_mask
return return_dict
def __len__(self):
return len(self.paths)
\ No newline at end of file
import cv2
import math
import random
import numpy as np
import os.path as osp
from scipy.io import loadmat
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
adjust_hue, adjust_saturation, normalize)
from basicsr.data import gaussian_kernels as gaussian_kernels
from basicsr.data.transforms import augment
from basicsr.data.data_util import paths_from_folder
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQBlindJointDataset(data.Dataset):
def __init__(self, opt):
super(FFHQBlindJointDataset, self).__init__()
logger = get_root_logger()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.gt_size = opt.get('gt_size', 512)
self.in_size = opt.get('in_size', 512)
assert self.gt_size >= self.in_size, 'Wrong setting.'
self.mean = opt.get('mean', [0.5, 0.5, 0.5])
self.std = opt.get('std', [0.5, 0.5, 0.5])
self.component_path = opt.get('component_path', None)
self.latent_gt_path = opt.get('latent_gt_path', None)
if self.component_path is not None:
self.crop_components = True
self.components_dict = torch.load(self.component_path)
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
else:
self.crop_components = False
if self.latent_gt_path is not None:
self.load_latent_gt = True
self.latent_gt_dict = torch.load(self.latent_gt_path)
else:
self.load_latent_gt = False
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
self.paths = paths_from_folder(self.gt_folder)
# perform corrupt
self.use_corrupt = opt.get('use_corrupt', True)
self.use_motion_kernel = False
# self.use_motion_kernel = opt.get('use_motion_kernel', True)
if self.use_motion_kernel:
self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
self.motion_kernels = torch.load(motion_kernel_path)
if self.use_corrupt:
# degradation configurations
self.blur_kernel_size = self.opt['blur_kernel_size']
self.kernel_list = self.opt['kernel_list']
self.kernel_prob = self.opt['kernel_prob']
# Small degradation
self.blur_sigma = self.opt['blur_sigma']
self.downsample_range = self.opt['downsample_range']
self.noise_range = self.opt['noise_range']
self.jpeg_range = self.opt['jpeg_range']
# Large degradation
self.blur_sigma_large = self.opt['blur_sigma_large']
self.downsample_range_large = self.opt['downsample_range_large']
self.noise_range_large = self.opt['noise_range_large']
self.jpeg_range_large = self.opt['jpeg_range_large']
# print
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
# color jitter
self.color_jitter_prob = opt.get('color_jitter_prob', None)
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
# to gray
self.gray_prob = opt.get('gray_prob', 0.0)
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_locations(self, name, status):
components_bbox = self.components_dict[name]
if status[0]: # hflip
# exchange right and left eye
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
# modify the width coordinate
components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
locations_gt = {}
locations_in = {}
for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
mean = components_bbox[part][0:2]
half_len = components_bbox[part][2]
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
elif part == 'nose':
half_len *= self.nose_enlarge_ratio
elif part == 'mouth':
half_len *= self.mouth_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = torch.from_numpy(loc).float()
locations_gt[part] = loc
loc_in = loc/(self.gt_size//self.in_size)
locations_in[part] = loc_in
return locations_gt, locations_in
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
name = osp.basename(gt_path)[:-4]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
if self.load_latent_gt:
if status[0]:
latent_gt = self.latent_gt_dict['hflip'][name]
else:
latent_gt = self.latent_gt_dict['orig'][name]
if self.crop_components:
locations_gt, locations_in = self.get_component_locations(name, status)
# generate in image
img_in = img_gt
if self.use_corrupt:
# motion blur
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
m_i = random.randint(0,31)
k = self.motion_kernels[f'{m_i:02d}']
img_in = cv2.filter2D(img_in,-1,k)
# gaussian blur
kernel = gaussian_kernels.random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma,
self.blur_sigma,
[-math.pi, math.pi],
noise_range=None)
img_in = cv2.filter2D(img_in, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range is not None:
noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
img_in = img_in + noise
img_in = np.clip(img_in, 0, 1)
# jpeg
if self.jpeg_range is not None:
# jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
jpeg_p = np.random.randint(self.jpeg_range[0], self.jpeg_range[1])
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
_, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
# resize to in_size
img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
# generate in_large with large degradation
img_in_large = img_gt
if self.use_corrupt:
# motion blur
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
m_i = random.randint(0,31)
k = self.motion_kernels[f'{m_i:02d}']
img_in_large = cv2.filter2D(img_in_large,-1,k)
# gaussian blur
kernel = gaussian_kernels.random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma_large,
self.blur_sigma_large,
[-math.pi, math.pi],
noise_range=None)
img_in_large = cv2.filter2D(img_in_large, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range_large is not None:
noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
img_in_large = img_in_large + noise
img_in_large = np.clip(img_in_large, 0, 1)
# jpeg
if self.jpeg_range_large is not None:
# jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
jpeg_p = np.random.randint(self.jpeg_range_large[0], self.jpeg_range_large[1])
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
_, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
# resize to in_size
img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
# random color jitter (only for lq)
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
img_in = self.color_jitter(img_in, self.color_jitter_shift)
img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
# random to gray (only for lq)
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
img_in = np.tile(img_in[:, :, None], [1, 1, 3])
img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
# BGR to RGB, HWC to CHW, numpy to tensor
img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
# random color jitter (pytorch version) (only for lq)
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
brightness = self.opt.get('brightness', (0.5, 1.5))
contrast = self.opt.get('contrast', (0.5, 1.5))
saturation = self.opt.get('saturation', (0, 1.5))
hue = self.opt.get('hue', (-0.1, 0.1))
img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
# round and clip
img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
# Set vgg range_norm=True if use the normalization here
# normalize
normalize(img_in, self.mean, self.std, inplace=True)
normalize(img_in_large, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
if self.crop_components:
return_dict['locations_in'] = locations_in
return_dict['locations_gt'] = locations_gt
if self.load_latent_gt:
return_dict['latent_gt'] = latent_gt
return return_dict
def __len__(self):
return len(self.paths)
This diff is collapsed.
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
GT image pairs.
There are three modes:
1. 'lmdb': Use lmdb files.
If opt['io_backend'] == lmdb.
2. 'meta_info_file': Use meta information file to generate paths.
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. 'folder': Scan folders to generate paths.
The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_flip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(PairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
# TODO: color space transform
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Ref:
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Ref:
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
num_prefetch_queue (int): Number of prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Ref:
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()
import cv2
import random
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Ratotation. Default: True.
flows (list[ndarray]: Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
from copy import deepcopy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import LOSS_REGISTRY
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
gradient_penalty_loss, r1_penalty)
__all__ = [
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
'r1_penalty', 'g_path_regularize'
]
def build_loss(opt):
"""Build loss from options.
Args:
opt (dict): Configuration. It must constain:
type (str): Model type.
"""
opt = deepcopy(opt)
loss_type = opt.pop('type')
loss = LOSS_REGISTRY.get(loss_type)(**opt)
logger = get_root_logger()
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
return loss
import functools
from torch.nn import functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are 'none', 'mean' and 'sum'.
Returns:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
else:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean'):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights. Default: None.
reduction (str): Same as built-in losses of PyTorch. Options are
'none', 'mean' and 'sum'. Default: 'mean'.
Returns:
Tensor: Loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
assert weight.dim() == loss.dim()
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
loss = loss * weight
# if weight is not specified or reduction is sum, just reduce the loss
if weight is None or reduction == 'sum':
loss = reduce_loss(loss, reduction)
# if reduction is mean, then compute mean over weight region
elif reduction == 'mean':
if weight.size(1) > 1:
weight = weight.sum()
else:
weight = weight.sum() * loss.size(1)
loss = loss.sum() / weight
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
**kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)
"""
@functools.wraps(loss_func)
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction)
return loss
return wrapper
import math
import lpips
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target)**2 + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
"""MSE (L2) loss.
Args:
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
variant of L1Loss).
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
Super-Resolution".
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
eps (float): A value used to control the curvature near zero.
Default: 1e-12.
"""
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
super(CharbonnierLoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
self.eps = eps
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
"""Weighted TV loss.
Args:
loss_weight (float): Loss weight. Default: 1.0.
"""
def __init__(self, loss_weight=1.0):
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
def forward(self, pred, weight=None):
y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
loss = x_diff + y_diff
return loss
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
1.0 in calculting losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and the loss will multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and the loss will multiplied by the weight.
Default: 0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
perceptual_weight=1.0,
style_weight=0.,
criterion='l1'):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
range_norm=range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
self.criterion = torch.nn.L2loss()
elif self.criterion_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='mean')
elif self.criterion_type == 'fro':
self.criterion = None
else:
raise NotImplementedError(f'{criterion} criterion has not been supported.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
else:
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
style_loss += torch.norm(
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
else:
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
def __init__(self,
loss_weight=1.0,
use_input_norm=True,
range_norm=False,):
super(LPIPSLoss, self).__init__()
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
self.loss_weight = loss_weight
self.use_input_norm = use_input_norm
self.range_norm = range_norm
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, pred, target):
if self.range_norm:
pred = (pred + 1) / 2
target = (target + 1) / 2
if self.use_input_norm:
pred = (pred - self.mean) / self.std
target = (target - self.mean) / self.std
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
return self.loss_weight * lpips_loss.mean()
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan_softplus':
self.loss = self._wgan_softplus_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input, target):
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_softplus_loss(self, input, target):
"""wgan loss with soft plus. softplus is a smooth approximation to the
ReLU function.
In StyleGAN2, it is called:
Logistic loss for discriminator;
Non-saturating loss for generator.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-input).mean() if target else F.softplus(input).mean()
def get_target_label(self, input, target_is_real):
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
return input.new_ones(input.size()) * target_val
def forward(self, input, target_is_real, is_disc=False):
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the targe is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight
def r1_penalty(real_pred, real_img):
"""R1 regularization for discriminator. The core idea is to
penalize the gradient on real data alone: when the
generator distribution produces the true data distribution
and the discriminator is equal to 0 on the data manifold, the
gradient penalty ensures that the discriminator cannot create
a non-zero gradient orthogonal to the data manifold without
suffering a loss in the GAN game.
Ref:
Eq. 9 in Which training methods for GANs do actually converge.
"""
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
"""Calculate gradient penalty for wgan-gp.
Args:
discriminator (nn.Module): Network for the discriminator.
real_data (Tensor): Real input data.
fake_data (Tensor): Fake input data.
weight (Tensor): Weight tensor. Default: None.
Returns:
Tensor: A tensor for gradient penalty.
"""
batch_size = real_data.size(0)
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
# interpolate between real_data and fake_data
interpolates = alpha * real_data + (1. - alpha) * fake_data
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = discriminator(interpolates)
gradients = autograd.grad(
outputs=disc_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
if weight is not None:
gradients = gradients * weight
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
if weight is not None:
gradients_penalty /= torch.mean(weight)
return gradients_penalty
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment