Commit 76b9024b authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3145 failed with stages
in 0 seconds
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp
class BasicModule(nn.Module):
"""Basic Module for SpyNet.
"""
def __init__(self):
super(BasicModule, self).__init__()
self.basic_module = nn.Sequential(
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
def forward(self, tensor_input):
return self.basic_module(tensor_input)
@ARCH_REGISTRY.register()
class SpyNet(nn.Module):
"""SpyNet architecture.
Args:
load_path (str): path for pretrained SpyNet. Default: None.
"""
def __init__(self, load_path=None):
super(SpyNet, self).__init__()
self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
if load_path:
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def preprocess(self, tensor_input):
tensor_output = (tensor_input - self.mean) / self.std
return tensor_output
def process(self, ref, supp):
flow = []
ref = [self.preprocess(ref)]
supp = [self.preprocess(supp)]
for level in range(5):
ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
flow = ref[0].new_zeros(
[ref[0].size(0), 2,
int(math.floor(ref[0].size(2) / 2.0)),
int(math.floor(ref[0].size(3) / 2.0))])
for level in range(len(ref)):
upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
if upsampled_flow.size(2) != ref[level].size(2):
upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
if upsampled_flow.size(3) != ref[level].size(3):
upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
flow = self.basic_module[level](torch.cat([
ref[level],
flow_warp(
supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
upsampled_flow
], 1)) + upsampled_flow
return flow
def forward(self, ref, supp):
assert ref.size() == supp.size()
h, w = ref.size(2), ref.size(3)
w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
flow[:, 0, :, :] *= float(w) / float(w_floor)
flow[:, 1, :, :] *= float(h) / float(h_floor)
return flow
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
"""Modified SRResNet.
A compacted version modified from SRResNet in
"Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
It uses residual blocks without BN, similar to EDSR.
Currently, it supports x2, x3 and x4 upsampling scale factor.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_block (int): Block number in the body network. Default: 16.
upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
super(MSRResNet, self).__init__()
self.upscale = upscale
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
# upsampling
if self.upscale in [2, 3]:
self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(self.upscale)
elif self.upscale == 4:
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
# initialization
default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
if self.upscale == 4:
default_init_weights(self.upconv2, 0.1)
def forward(self, x):
feat = self.lrelu(self.conv_first(x))
out = self.body(feat)
if self.upscale == 4:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
elif self.upscale in [2, 3]:
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
out += base
return out
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register(suffix='basicsr')
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
It is a compact network structure, which performs upsampling in the last layer and no convolution is
conducted on the HR feature space.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_conv (int): Number of convolution layers in the body network. Default: 16.
upscale (int): Upsampling factor. Default: 4.
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
super(SRVGGNetCompact, self).__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
self.num_conv = num_conv
self.upscale = upscale
self.act_type = act_type
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
# the first activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the body structure
for _ in range(num_conv):
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
# activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the last conv
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
# upsample
self.upsampler = nn.PixelShuffle(upscale)
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
out += base
return out
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.ops.upfirdn2d import upfirdn2d
from basicsr.utils.registry import ARCH_REGISTRY
class NormStyleCode(nn.Module):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
def make_resample_kernel(k):
"""Make resampling kernel for UpFirDn.
Args:
k (list[int]): A list indicating the 1D resample kernel magnitude.
Returns:
Tensor: 2D resampled kernel.
"""
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None] # to 2D kernel, outer product
# normalize
k /= k.sum()
return k
class UpFirDnUpsample(nn.Module):
"""Upsample, FIR filter, and downsample (upsampole version).
References:
1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
Args:
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude.
factor (int): Upsampling scale factor. Default: 2.
"""
def __init__(self, resample_kernel, factor=2):
super(UpFirDnUpsample, self).__init__()
self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
self.factor = factor
pad = self.kernel.shape[0] - factor
self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
def forward(self, x):
out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(factor={self.factor})')
class UpFirDnDownsample(nn.Module):
"""Upsample, FIR filter, and downsample (downsampole version).
Args:
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude.
factor (int): Downsampling scale factor. Default: 2.
"""
def __init__(self, resample_kernel, factor=2):
super(UpFirDnDownsample, self).__init__()
self.kernel = make_resample_kernel(resample_kernel)
self.factor = factor
pad = self.kernel.shape[0] - factor
self.pad = ((pad + 1) // 2, pad // 2)
def forward(self, x):
out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(factor={self.factor})')
class UpFirDnSmooth(nn.Module):
"""Upsample, FIR filter, and downsample (smooth version).
Args:
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude.
upsample_factor (int): Upsampling scale factor. Default: 1.
downsample_factor (int): Downsampling scale factor. Default: 1.
kernel_size (int): Kernel size: Default: 1.
"""
def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
super(UpFirDnSmooth, self).__init__()
self.upsample_factor = upsample_factor
self.downsample_factor = downsample_factor
self.kernel = make_resample_kernel(resample_kernel)
if upsample_factor > 1:
self.kernel = self.kernel * (upsample_factor**2)
if upsample_factor > 1:
pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
elif downsample_factor > 1:
pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
self.pad = ((pad + 1) // 2, pad // 2)
else:
raise NotImplementedError
def forward(self, x):
out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
f', downsample_factor={self.downsample_factor})')
class EqualLinear(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Size of each sample.
out_channels (int): Size of each output sample.
bias (bool): If set to ``False``, the layer will not learn an additive
bias. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
lr_mul (float): Learning rate multiplier. Default: 1.
activation (None | str): The activation after ``linear`` operation.
Supported: 'fused_lrelu', None. Default: None.
"""
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
super(EqualLinear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
"Supported ones are: ['fused_lrelu', None].")
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
if self.bias is None:
bias = None
else:
bias = self.bias * self.lr_mul
if self.activation == 'fused_lrelu':
out = F.linear(x, self.weight * self.scale)
out = fused_leaky_relu(out, bias)
else:
out = F.linear(x, self.weight * self.scale, bias=bias)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, bias={self.bias is not None})')
class ModulatedConv2d(nn.Module):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. Default: (1, 3, 3, 1).
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
resample_kernel=(1, 3, 3, 1),
eps=1e-8):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
if self.sample_mode == 'upsample':
self.smooth = UpFirDnSmooth(
resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
elif self.sample_mode == 'downsample':
self.smooth = UpFirDnSmooth(
resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
elif self.sample_mode is None:
pass
else:
raise ValueError(f'Wrong sample mode {self.sample_mode}, '
"supported ones are ['upsample', 'downsample', None].")
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape # c = c_in
# weight modulation
style = self.modulation(style).view(b, 1, c, 1, 1)
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
if self.sample_mode == 'upsample':
x = x.view(1, b * c, h, w)
weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
out = self.smooth(out)
elif self.sample_mode == 'downsample':
x = self.smooth(x)
x = x.view(1, b * c, *x.shape[2:4])
out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
else:
x = x.view(1, b * c, h, w)
# weight: (b*c_out, c_in, k, k), groups=b
out = F.conv2d(x, weight, padding=self.padding, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. Default: (1, 3, 3, 1).
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
resample_kernel=(1, 3, 3, 1)):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
resample_kernel=resample_kernel)
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
def forward(self, x, style, noise=None):
# modulate
out = self.modulated_conv(x, style)
# noise injection
if noise is None:
b, _, h, w = out.shape
noise = out.new_empty(b, 1, h, w).normal_()
out = out + self.weight * noise
# activation (with bias)
out = self.activate(out)
return out
class ToRGB(nn.Module):
"""To RGB from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. Default: (1, 3, 3, 1).
"""
def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
super(ToRGB, self).__init__()
if upsample:
self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
else:
self.upsample = None
self.modulated_conv = ModulatedConv2d(
in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = self.upsample(skip)
out = out + skip
return out
class ConstantInput(nn.Module):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
def forward(self, batch):
out = self.weight.repeat(batch, 1, 1, 1)
return out
@ARCH_REGISTRY.register()
class StyleGAN2Generator(nn.Module):
"""StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. A cross production will be applied to extent 1D resample
kernel to 2D resample kernel. Default: (1, 3, 3, 1).
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
resample_kernel=(1, 3, 3, 1),
lr_mlp=0.01,
narrow=1):
super(StyleGAN2Generator, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
activation='fused_lrelu'))
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
self.channels = channels
self.constant_input = ConstantInput(channels['4'], size=4)
self.style_conv1 = StyleConv(
channels['4'],
channels['4'],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
resample_kernel=resample_kernel)
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
resolution = 2**((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.style_convs.append(
StyleConv(
in_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode='upsample',
resample_kernel=resample_kernel,
))
self.style_convs.append(
StyleConv(
out_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
resample_kernel=resample_kernel))
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
device = self.constant_input.weight.device
noises = [torch.randn(1, 1, 4, 4, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ScaledLeakyReLU(nn.Module):
"""Scaled LeakyReLU.
Args:
negative_slope (float): Negative slope. Default: 0.2.
"""
def __init__(self, negative_slope=0.2):
super(ScaledLeakyReLU, self).__init__()
self.negative_slope = negative_slope
def forward(self, x):
out = F.leaky_relu(x, negative_slope=self.negative_slope)
return out * math.sqrt(2)
class EqualConv2d(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
super(EqualConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
out = F.conv2d(
x,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size},'
f' stride={self.stride}, padding={self.padding}, '
f'bias={self.bias is not None})')
class ConvLayer(nn.Sequential):
"""Conv Layer used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Kernel size.
downsample (bool): Whether downsample by a factor of 2.
Default: False.
resample_kernel (list[int]): A list indicating the 1D resample
kernel magnitude. A cross production will be applied to
extent 1D resample kernel to 2D resample kernel.
Default: (1, 3, 3, 1).
bias (bool): Whether with bias. Default: True.
activate (bool): Whether use activateion. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True):
layers = []
# downsample
if downsample:
layers.append(
UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
# conv
layers.append(
EqualConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
and not activate))
# activation
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channels))
else:
layers.append(ScaledLeakyReLU(0.2))
super(ConvLayer, self).__init__(*layers)
class ResBlock(nn.Module):
"""Residual block used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
resample_kernel (list[int]): A list indicating the 1D resample
kernel magnitude. A cross production will be applied to
extent 1D resample kernel to 2D resample kernel.
Default: (1, 3, 3, 1).
"""
def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
super(ResBlock, self).__init__()
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvLayer(
in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
self.skip = ConvLayer(
in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
skip = self.skip(x)
out = (out + skip) / math.sqrt(2)
return out
@ARCH_REGISTRY.register()
class StyleGAN2Discriminator(nn.Module):
"""StyleGAN2 Discriminator.
Args:
out_size (int): The spatial size of outputs.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. A cross production will be applied to extent 1D resample
kernel to 2D resample kernel. Default: (1, 3, 3, 1).
stddev_group (int): For group stddev statistics. Default: 4.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
super(StyleGAN2Discriminator, self).__init__()
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
log_size = int(math.log(out_size, 2))
conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
in_channels = channels[f'{out_size}']
for i in range(log_size, 2, -1):
out_channels = channels[f'{2**(i - 1)}']
conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
in_channels = out_channels
self.conv_body = nn.Sequential(*conv_body)
self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
self.final_linear = nn.Sequential(
EqualLinear(
channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
)
self.stddev_group = stddev_group
self.stddev_feat = 1
def forward(self, x):
out = self.conv_body(x)
b, c, h, w = out.shape
# concatenate a group stddev statistics to out
group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
stddev = stddev.repeat(group, 1, h, w)
out = torch.cat([out, stddev], 1)
out = self.final_conv(out)
out = out.view(b, -1)
out = self.final_linear(out)
return out
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY
class NormStyleCode(nn.Module):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
class EqualLinear(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Size of each sample.
out_channels (int): Size of each output sample.
bias (bool): If set to ``False``, the layer will not learn an additive
bias. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
lr_mul (float): Learning rate multiplier. Default: 1.
activation (None | str): The activation after ``linear`` operation.
Supported: 'fused_lrelu', None. Default: None.
"""
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
super(EqualLinear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
"Supported ones are: ['fused_lrelu', None].")
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
if self.bias is None:
bias = None
else:
bias = self.bias * self.lr_mul
if self.activation == 'fused_lrelu':
out = F.linear(x, self.weight * self.scale)
out = fused_leaky_relu(out, bias)
else:
out = F.linear(x, self.weight * self.scale, bias=bias)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, bias={self.bias is not None})')
class ModulatedConv2d(nn.Module):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
eps=1e-8,
interpolation_mode='bilinear'):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape # c = c_in
# weight modulation
style = self.modulation(style).view(b, 1, c, 1, 1)
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
if self.sample_mode == 'upsample':
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
elif self.sample_mode == 'downsample':
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
b, c, h, w = x.shape
x = x.view(1, b * c, h, w)
# weight: (b*c_out, c_in, k, k), groups=b
out = F.conv2d(x, weight, padding=self.padding, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode='bilinear'):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
interpolation_mode=interpolation_mode)
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
def forward(self, x, style, noise=None):
# modulate
out = self.modulated_conv(x, style)
# noise injection
if noise is None:
b, _, h, w = out.shape
noise = out.new_empty(b, 1, h, w).normal_()
out = out + self.weight * noise
# activation (with bias)
out = self.activate(out)
return out
class ToRGB(nn.Module):
"""To RGB from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
"""
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
super(ToRGB, self).__init__()
self.upsample = upsample
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.modulated_conv = ModulatedConv2d(
in_channels,
3,
kernel_size=1,
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
out = out + skip
return out
class ConstantInput(nn.Module):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
def forward(self, batch):
out = self.weight.repeat(batch, 1, 1, 1)
return out
@ARCH_REGISTRY.register(suffix='basicsr')
class StyleGAN2GeneratorBilinear(nn.Module):
"""StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
lr_mlp=0.01,
narrow=1,
interpolation_mode='bilinear'):
super(StyleGAN2GeneratorBilinear, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
activation='fused_lrelu'))
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
self.channels = channels
self.constant_input = ConstantInput(channels['4'], size=4)
self.style_conv1 = StyleConv(
channels['4'],
channels['4'],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
resolution = 2**((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.style_convs.append(
StyleConv(
in_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode='upsample',
interpolation_mode=interpolation_mode))
self.style_convs.append(
StyleConv(
out_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode))
self.to_rgbs.append(
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
device = self.constant_input.weight.device
noises = [torch.randn(1, 1, 4, 4, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ScaledLeakyReLU(nn.Module):
"""Scaled LeakyReLU.
Args:
negative_slope (float): Negative slope. Default: 0.2.
"""
def __init__(self, negative_slope=0.2):
super(ScaledLeakyReLU, self).__init__()
self.negative_slope = negative_slope
def forward(self, x):
out = F.leaky_relu(x, negative_slope=self.negative_slope)
return out * math.sqrt(2)
class EqualConv2d(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
super(EqualConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
out = F.conv2d(
x,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size},'
f' stride={self.stride}, padding={self.padding}, '
f'bias={self.bias is not None})')
class ConvLayer(nn.Sequential):
"""Conv Layer used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Kernel size.
downsample (bool): Whether downsample by a factor of 2.
Default: False.
bias (bool): Whether with bias. Default: True.
activate (bool): Whether use activateion. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
bias=True,
activate=True,
interpolation_mode='bilinear'):
layers = []
self.interpolation_mode = interpolation_mode
# downsample
if downsample:
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
layers.append(
torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
stride = 1
self.padding = kernel_size // 2
# conv
layers.append(
EqualConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
and not activate))
# activation
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channels))
else:
layers.append(ScaledLeakyReLU(0.2))
super(ConvLayer, self).__init__(*layers)
class ResBlock(nn.Module):
"""Residual block used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
"""
def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
super(ResBlock, self).__init__()
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvLayer(
in_channels,
out_channels,
3,
downsample=True,
interpolation_mode=interpolation_mode,
bias=True,
activate=True)
self.skip = ConvLayer(
in_channels,
out_channels,
1,
downsample=True,
interpolation_mode=interpolation_mode,
bias=False,
activate=False)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
skip = self.skip(x)
out = (out + skip) / math.sqrt(2)
return out
# Modified from https://github.com/JingyunLiang/SwinIR
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import to_2tuple, trunc_normal_
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (b, h, w, c)
window_size (int): window size
Returns:
windows: (num_windows*b, window_size, window_size, c)
"""
b, h, w, c = x.shape
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
return windows
def window_reverse(windows, window_size, h, w):
"""
Args:
windows: (num_windows*b, window_size, window_size, c)
window_size (int): Window size
h (int): Height of image
w (int): Width of image
Returns:
x: (b, h, w, c)
"""
b = int(windows.shape[0] / (h * w / window_size / window_size))
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer('relative_position_index', relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*b, n, c)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
b_, n, c = x.shape
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nw = mask.shape[0]
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, n, n)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, n):
# calculate flops for 1 window with token length of n
flops = 0
# qkv = self.qkv(x)
flops += n * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * n * (self.dim // self.num_heads) * n
# x = (attn @ v)
flops += self.num_heads * n * n * (self.dim // self.num_heads)
# x = self.proj(x)
flops += n * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer('attn_mask', attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
h, w = x_size
img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
h_slices = (slice(0, -self.window_size), slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size), slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
h, w = x_size
b, _, c = x.shape
# assert seq_len == h * w, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(b, h, w, c)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(b, h * w, c)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
def flops(self):
flops = 0
h, w = self.input_resolution
# norm1
flops += self.dim * h * w
# W-MSA/SW-MSA
nw = h * w / self.window_size / self.window_size
flops += nw * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * h * w
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: b, h*w, c
"""
h, w = self.input_resolution
b, seq_len, c = x.shape
assert seq_len == h * w, 'input feature has wrong size'
assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
x = x.view(b, h, w, c)
x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f'input_resolution={self.input_resolution}, dim={self.dim}'
def flops(self):
h, w = self.input_resolution
flops = h * w * self.dim
flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer) for i in range(depth)
])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
img_size=224,
patch_size=4,
resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(
dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(
nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
h, w = self.input_resolution
flops += h * w * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
h, w = self.img_size
if self.norm is not None:
flops += h * w * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
h, w = self.input_resolution
flops = h * w * self.num_feat * 3 * 9
return flops
@ARCH_REGISTRY.register()
class SwinIR(nn.Module):
r""" SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self,
img_size=64,
patch_size=1,
in_chans=3,
embed_dim=96,
depths=(6, 6, 6, 6),
num_heads=(6, 6, 6, 6),
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
upscale=2,
img_range=1.,
upsampler='',
resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
# ------------------------- 1, shallow feature extraction ------------------------- #
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
# ------------------------- 2, deep feature extraction ------------------------- #
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
# ------------------------- 3, high quality image reconstruction ------------------------- #
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # b seq_len c
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x
def flops(self):
flops = 0
h, w = self.patches_resolution
flops += h * w * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for layer in self.layers:
flops += layer.flops()
flops += h * w * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(
upscale=2,
img_size=(height, width),
window_size=window_size,
img_range=1.,
depths=[6, 6, 6, 6],
embed_dim=60,
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp
class BasicModule(nn.Module):
"""Basic module of SPyNet.
Note that unlike the architecture in spynet_arch.py, the basic module
here contains batch normalization.
"""
def __init__(self):
super(BasicModule, self).__init__()
self.basic_module = nn.Sequential(
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
nn.BatchNorm2d(16), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
def forward(self, tensor_input):
"""
Args:
tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
8 channels contain:
[reference image (3), neighbor image (3), initial flow (2)].
Returns:
Tensor: Estimated flow with shape (b, 2, h, w)
"""
return self.basic_module(tensor_input)
class SPyNetTOF(nn.Module):
"""SPyNet architecture for TOF.
Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
They differ in the following aspects:
1. The basic modules here contain BatchNorm.
2. Normalization and denormalization are not done here, as they are done in TOFlow.
``Paper: Optical Flow Estimation using a Spatial Pyramid Network``
Reference: https://github.com/Coldog2333/pytoflow
Args:
load_path (str): Path for pretrained SPyNet. Default: None.
"""
def __init__(self, load_path=None):
super(SPyNetTOF, self).__init__()
self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
if load_path:
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
def forward(self, ref, supp):
"""
Args:
ref (Tensor): Reference image with shape of (b, 3, h, w).
supp: The supporting image to be warped: (b, 3, h, w).
Returns:
Tensor: Estimated optical flow: (b, 2, h, w).
"""
num_batches, _, h, w = ref.size()
ref = [ref]
supp = [supp]
# generate downsampled frames
for _ in range(3):
ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
# flow computation
flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
for i in range(4):
flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module[i](
torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
return flow
@ARCH_REGISTRY.register()
class TOFlow(nn.Module):
"""PyTorch implementation of TOFlow.
In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames.
``Paper: Video Enhancement with Task-Oriented Flow``
Reference: https://github.com/anchen1011/toflow
Reference: https://github.com/Coldog2333/pytoflow
Args:
adapt_official_weights (bool): Whether to adapt the weights translated
from the official implementation. Set to false if you want to
train from scratch. Default: False
"""
def __init__(self, adapt_official_weights=False):
super(TOFlow, self).__init__()
self.adapt_official_weights = adapt_official_weights
self.ref_idx = 0 if adapt_official_weights else 3
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
# flow estimation module
self.spynet = SPyNetTOF()
# reconstruction module
self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
self.conv_3 = nn.Conv2d(64, 64, 1)
self.conv_4 = nn.Conv2d(64, 3, 1)
# activation function
self.relu = nn.ReLU(inplace=True)
def normalize(self, img):
return (img - self.mean) / self.std
def denormalize(self, img):
return img * self.std + self.mean
def forward(self, lrs):
"""
Args:
lrs: Input lr frames: (b, 7, 3, h, w).
Returns:
Tensor: SR frame: (b, 3, h, w).
"""
# In the official implementation, the 0-th frame is the reference frame
if self.adapt_official_weights:
lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
num_batches, num_lrs, _, h, w = lrs.size()
lrs = self.normalize(lrs.view(-1, 3, h, w))
lrs = lrs.view(num_batches, num_lrs, 3, h, w)
lr_ref = lrs[:, self.ref_idx, :, :, :]
lr_aligned = []
for i in range(7): # 7 frames
if i == self.ref_idx:
lr_aligned.append(lr_ref)
else:
lr_supp = lrs[:, i, :, :, :]
flow = self.spynet(lr_ref, lr_supp)
lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
# reconstruction
hr = torch.stack(lr_aligned, dim=1)
hr = hr.view(num_batches, -1, h, w)
hr = self.relu(self.conv_1(hr))
hr = self.relu(self.conv_2(hr))
hr = self.relu(self.conv_3(hr))
hr = self.conv_4(hr) + lr_ref
return self.denormalize(hr)
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg
from basicsr.utils.registry import ARCH_REGISTRY
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
'vgg11': [
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
'pool5'
],
'vgg13': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
],
'vgg16': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
'pool5'
],
'vgg19': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
]
}
def insert_bn(names):
"""Insert bn layer after each conv.
Args:
names (list): The list of layer names.
Returns:
list: The list of layer names with bn layers.
"""
names_bn = []
for name in names:
names_bn.append(name)
if 'conv' in name:
position = name.replace('conv', '')
names_bn.append('bn' + position)
return names_bn
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
"""VGG network for feature extraction.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): Forward function returns the corresponding
features according to the layer_name_list.
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image. Importantly,
the input feature must in the range [0, 1]. Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
requires_grad (bool): If true, the parameters of VGG network will be
optimized. Default: False.
remove_pooling (bool): If true, the max pooling operations in VGG net
will be removed. Default: False.
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
def __init__(self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
requires_grad=False,
remove_pooling=False,
pooling_stride=2):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.range_norm = range_norm
self.names = NAMES[vgg_type.replace('_bn', '')]
if 'bn' in vgg_type:
self.names = insert_bn(self.names)
# only borrow layers that will be used to avoid unused params
max_idx = 0
for v in layer_name_list:
idx = self.names.index(v)
if idx > max_idx:
max_idx = idx
if os.path.exists(VGG_PRETRAIN_PATH):
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
vgg_net.load_state_dict(state_dict)
else:
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
features = vgg_net.features[:max_idx + 1]
modified_net = OrderedDict()
for k, v in zip(self.names, features):
if 'pool' in k:
# if remove_pooling is true, pooling operation will be removed
if remove_pooling:
continue
else:
# in some cases, we may want to change the default stride
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
else:
modified_net[k] = v
self.vgg_net = nn.Sequential(modified_net)
if not requires_grad:
self.vgg_net.eval()
for param in self.parameters():
param.requires_grad = False
else:
self.vgg_net.train()
for param in self.parameters():
param.requires_grad = True
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.range_norm:
x = (x + 1) / 2
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for key, layer in self.vgg_net._modules.items():
x = layer(x)
if key in self.layer_name_list:
output[key] = x.clone()
return output
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must contain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment