Commit b7536f78 authored by limm

add another part of mmgeneration code

parent 57e0e891
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
from torch.nn import init
def generation_init_weights(module, init_type='normal', init_gain=0.02):
"""Default initialization of network weights for image generation.
By default, we use normal init, but xavier and kaiming might work
better for some applications.
Args:
module (nn.Module): Module to be initialized.
init_type (str): The name of an initialization method:
normal | xavier | kaiming | orthogonal.
init_gain (float): Scaling factor for normal, xavier and
orthogonal.
"""
def init_func(m):
"""Initialization function.
Args:
m (nn.Module): Module to be initialized.
"""
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1
or classname.find('Linear') != -1):
if init_type == 'normal':
normal_init(m, 0.0, init_gain)
elif init_type == 'xavier':
xavier_init(m, gain=init_gain, distribution='normal')
elif init_type == 'kaiming':
kaiming_init(
m,
a=0,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='normal')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight, gain=init_gain)
                if m.bias is not None:  # conv/linear may be built bias-free
                    init.constant_(m.bias.data, 0.0)
else:
raise NotImplementedError(
f"Initialization method '{init_type}' is not implemented")
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
module.apply(init_func)
class UnetSkipConnectionBlock(nn.Module):
"""Construct a Unet submodule with skip connections, with the following.
structure: downsampling - `submodule` - upsampling.
Args:
outer_channels (int): Number of channels at the outer conv layer.
inner_channels (int): Number of channels at the inner conv layer.
        in_channels (int): Number of channels in input images/features. If
            None, it equals ``outer_channels``. Default: None.
submodule (UnetSkipConnectionBlock): Previously constructed submodule.
Default: None.
is_outermost (bool): Whether this module is the outermost module.
Default: False.
is_innermost (bool): Whether this module is the innermost module.
Default: False.
norm_cfg (dict): Config dict to build norm layer. Default:
`dict(type='BN')`.
use_dropout (bool): Whether to use dropout layers. Default: False.
"""
def __init__(self,
outer_channels,
inner_channels,
in_channels=None,
submodule=None,
is_outermost=False,
is_innermost=False,
norm_cfg=dict(type='BN'),
use_dropout=False):
super().__init__()
# cannot be both outermost and innermost
assert not (is_outermost and is_innermost), (
"'is_outermost' and 'is_innermost' cannot be True"
'at the same time.')
self.is_outermost = is_outermost
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
# We use norm layers in the unet skip connection block.
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_cfg['type'] == 'IN'
kernel_size = 4
stride = 2
padding = 1
if in_channels is None:
in_channels = outer_channels
down_conv_cfg = dict(type='Conv2d')
down_norm_cfg = norm_cfg
down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
up_conv_cfg = dict(type='deconv')
up_norm_cfg = norm_cfg
up_act_cfg = dict(type='ReLU')
up_in_channels = inner_channels * 2
up_bias = use_bias
middle = [submodule]
upper = []
if is_outermost:
down_act_cfg = None
down_norm_cfg = None
up_bias = True
up_norm_cfg = None
upper = [nn.Tanh()]
elif is_innermost:
down_norm_cfg = None
up_in_channels = inner_channels
middle = []
else:
upper = [nn.Dropout(0.5)] if use_dropout else []
down = [
ConvModule(
in_channels=in_channels,
out_channels=inner_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=use_bias,
conv_cfg=down_conv_cfg,
norm_cfg=down_norm_cfg,
act_cfg=down_act_cfg,
order=('act', 'conv', 'norm'))
]
up = [
ConvModule(
in_channels=up_in_channels,
out_channels=outer_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=up_bias,
conv_cfg=up_conv_cfg,
norm_cfg=up_norm_cfg,
act_cfg=up_act_cfg,
order=('act', 'conv', 'norm'))
]
model = down + middle + up + upper
self.model = nn.Sequential(*model)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.is_outermost:
return self.model(x)
# add skip connections
return torch.cat([x, self.model(x)], 1)
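# A minimal usage sketch (not part of the library; channel sizes are
# illustrative): build a three-level U-Net by nesting blocks from the
# innermost one outward, then apply the default weight initialization.
if __name__ == '__main__':
    inner = UnetSkipConnectionBlock(
        outer_channels=128, inner_channels=256, is_innermost=True)
    middle = UnetSkipConnectionBlock(
        outer_channels=64, inner_channels=128, submodule=inner)
    unet = UnetSkipConnectionBlock(
        outer_channels=3, inner_channels=64, submodule=middle,
        is_outermost=True)
    generation_init_weights(unet, init_type='normal', init_gain=0.02)
    # the outermost block returns the plain output; inner blocks
    # concatenate their input to realize the skip connections
    out = unet(torch.randn(1, 3, 32, 32))
    assert out.shape == (1, 3, 32, 32)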
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmgen.models.builder import MODULES
@MODULES.register_module('SPE')
@MODULES.register_module('SPE2d')
class SinusoidalPositionalEmbedding(nn.Module):
"""Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).
This module is a modified from:
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa
    Based on the original SPE in a single dimension, we implement a 2D
    sinusoidal positional encoding (SPE2d), as introduced in Positional
    Encoding as Spatial Inductive Bias in GANs, CVPR'2021.
Args:
embedding_dim (int): The number of dimensions for the positional
encoding.
padding_idx (int | list[int]): The index for the padding contents. The
padding positions will obtain an encoding vector filling in zeros.
init_size (int, optional): The initial size of the positional buffer.
Defaults to 1024.
div_half_dim (bool, optional): If true, the embedding will be divided
by :math:`d/2`. Otherwise, it will be divided by
:math:`(d/2 -1)`. Defaults to False.
center_shift (int | None, optional): Shift the center point to some
index. Defaults to None.
"""
def __init__(self,
embedding_dim,
padding_idx,
init_size=1024,
div_half_dim=False,
center_shift=None):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.div_half_dim = div_half_dim
self.center_shift = center_shift
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size, embedding_dim, padding_idx, self.div_half_dim)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
self.max_positions = int(1e5)
@staticmethod
def get_embedding(num_embeddings,
embedding_dim,
padding_idx=None,
div_half_dim=False):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
        assert embedding_dim % 2 == 0, (
            'In this version, we require `embedding_dim` to be divisible by '
            f'2, but got {embedding_dim}')
# there is a little difference from the original paper.
half_dim = embedding_dim // 2
if not div_half_dim:
emb = np.log(10000) / (half_dim - 1)
else:
emb = np.log(1e4) / half_dim
# compute exp(-log10000 / d * i)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(
num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)],
dim=1).view(num_embeddings, -1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, **kwargs):
"""Input is expected to be of size [bsz x seqlen].
Returned tensor is expected to be of size [bsz x seq_len x emb_dim]
"""
        assert input.dim() == 2 or input.dim() == 4, (
            'Input dimension should be 2 (1D) or 4 (2D)')
if input.dim() == 4:
return self.make_grid2d_like(input, **kwargs)
b, seq_len = input.shape
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embedding if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos, self.embedding_dim, self.padding_idx)
self.weights = self.weights.to(self._float_tensor)
positions = self.make_positions(input, self.padding_idx).to(
self._float_tensor.device)
return self.weights.index_select(0, positions.view(-1)).view(
b, seq_len, self.embedding_dim).detach()
def make_positions(self, input, padding_idx):
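        """Compute position indices, counting from ``padding_idx + 1``;
        entries equal to ``padding_idx`` keep the value ``padding_idx``."""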
mask = input.ne(padding_idx).int()
return (torch.cumsum(mask, dim=1).type_as(mask) *
mask).long() + padding_idx
def make_grid2d(self, height, width, num_batches=1, center_shift=None):
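        """Construct a 2D positional grid of shape (b, 2 x emb_dim, h, w)."""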
h, w = height, width
# if `center_shift` is not given from the outside, use
# `self.center_shift`
if center_shift is None:
center_shift = self.center_shift
h_shift = 0
w_shift = 0
# center shift to the input grid
if center_shift is not None:
# if h/w is even, the left center should be aligned with
# center shift
if h % 2 == 0:
h_left_center = h // 2
h_shift = center_shift - h_left_center
else:
h_center = h // 2 + 1
h_shift = center_shift - h_center
if w % 2 == 0:
w_left_center = w // 2
w_shift = center_shift - w_left_center
else:
w_center = w // 2 + 1
w_shift = center_shift - w_center
        # Note that the index starts from 1, since zero is the padding idx.
# axis -- (b, h or w)
x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches,
1) + w_shift
y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches,
1) + h_shift
# emb -- (b, emb_dim, h or w)
x_emb = self(x_axis).transpose(1, 2)
y_emb = self(y_axis).transpose(1, 2)
# make grid for x/y axis
        # Note that `repeat` will copy data. If using a learned embedding,
        # `expand` may be better.
x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1)
y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w)
# cat grid -- (b, 2 x emb_dim, h, w)
grid = torch.cat([x_grid, y_grid], dim=1)
return grid.detach()
def make_grid2d_like(self, x, center_shift=None):
"""Input tensor with shape of (b, ..., h, w) Return tensor with shape
of (b, 2 x emb_dim, h, w)
Note that the positional embedding highly depends on the the function,
``make_positions``.
"""
h, w = x.shape[-2:]
grid = self.make_grid2d(h, w, x.size(0), center_shift)
return grid.to(x)
@MODULES.register_module('CSG2d')
@MODULES.register_module('CSG')
@MODULES.register_module()
class CatersianGrid(nn.Module):
"""Catersian Grid for 2d tensor.
The Catersian Grid is a common-used positional encoding in deep learning.
In this implementation, we follow the convention of ``grid_sample`` in
PyTorch. In other words, ``[-1, -1]`` denotes the left-top corner while
``[1, 1]`` denotes the right-botton corner.
"""
def forward(self, x, **kwargs):
assert x.dim() == 4
return self.make_grid2d_like(x, **kwargs)
def make_grid2d(self, height, width, num_batches=1, requires_grad=False):
h, w = height, width
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
grid_x = 2 * grid_x / max(float(w) - 1., 1.) - 1.
grid_y = 2 * grid_y / max(float(h) - 1., 1.) - 1.
grid = torch.stack((grid_x, grid_y), 0)
grid.requires_grad = requires_grad
grid = torch.unsqueeze(grid, 0)
grid = grid.repeat(num_batches, 1, 1, 1)
return grid
def make_grid2d_like(self, x, requires_grad=False):
h, w = x.shape[-2:]
grid = self.make_grid2d(h, w, x.size(0), requires_grad=requires_grad)
return grid.to(x)
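# A minimal shape-check sketch (not part of the library): both modules map
# a 4D feature map to a positional grid of the same spatial size. SPE2d
# yields `embedding_dim` channels per axis (so 2 x embedding_dim in total),
# while the Cartesian grid always yields 2 channels.
if __name__ == '__main__':
    feat = torch.randn(2, 16, 8, 8)
    spe = SinusoidalPositionalEmbedding(embedding_dim=32, padding_idx=0)
    csg = CatersianGrid()
    assert spe(feat).shape == (2, 64, 8, 8)
    assert csg(feat).shape == (2, 2, 8, 8)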
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import (SinGANMultiScaleDiscriminator,
SinGANMultiScaleGenerator)
from .positional_encoding import SinGANMSGeneratorPE
__all__ = [
'SinGANMultiScaleDiscriminator', 'SinGANMultiScaleGenerator',
'SinGANMSGeneratorPE'
]
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_state_dict
from mmcv.utils import print_log
from mmgen.models.builder import MODULES
from mmgen.utils import get_root_logger
from .modules import DiscriminatorBlock, GeneratorBlock
@MODULES.register_module()
class SinGANMultiScaleGenerator(nn.Module):
"""Multi-Scale Generator used in SinGAN.
    More details can be found in: SinGAN: Learning a Generative Model from a
    Single Natural Image, ICCV'19.
Notes:
    - In this version, we adopt the interpolation function from the official
    PyTorch APIs, which is different from the original implementation by the
    authors. However, in our experiments, the influence of this change is
    negligible.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
num_scales (int): The number of scales/stages in generator. Note
that this number is counted from zero, which is the same as the
original paper.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
num_layers (int, optional): The number of convolutional layers in each
generator block. Defaults to 5.
base_channels (int, optional): The basic channels for convolutional
layers in the generator block. Defaults to 32.
min_feat_channels (int, optional): Minimum channels for the feature
maps in the generator block. Defaults to 32.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
"""
def __init__(self,
in_channels,
out_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
out_act_cfg=dict(type='Tanh'),
**kwargs):
super().__init__()
self.pad_head = int((kernel_size - 1) / 2 * num_layers)
self.blocks = nn.ModuleList()
self.upsample = partial(
F.interpolate, mode='bicubic', align_corners=True)
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
self.blocks.append(
GeneratorBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
out_act_cfg=out_act_cfg,
**kwargs))
self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
def forward(self,
input_sample,
fixed_noises,
noise_weights,
rand_mode,
curr_scale,
num_batches=1,
get_prev_res=False,
return_noise=False):
"""Forward function.
Args:
input_sample (Tensor | None): The input for generator. In the
original implementation, a tensor filled with zeros is adopted.
If None is given, we will construct it from the first fixed
noises.
fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
noise_weights (list[float]): List of the weights for random noises.
rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
it will sample from random noises. Otherwise, the
reconstruction for the single image will be returned.
curr_scale (int): The scale for the current inference or training.
num_batches (int, optional): The number of batches. Defaults to 1.
get_prev_res (bool, optional): Whether to return results from
previous stages. Defaults to False.
return_noise (bool, optional): Whether to return noises tensor.
Defaults to False.
Returns:
Tensor | dict: Generated image tensor or dictionary containing \
more data.
"""
if get_prev_res or return_noise:
prev_res_list = []
noise_list = []
if input_sample is None:
input_sample = torch.zeros(
(num_batches, 3, fixed_noises[0].shape[-2],
fixed_noises[0].shape[-1])).to(fixed_noises[0])
g_res = input_sample
for stage in range(curr_scale + 1):
if rand_mode == 'recon':
noise_ = fixed_noises[stage]
else:
noise_ = torch.randn(num_batches,
*fixed_noises[stage].shape[1:]).to(g_res)
if return_noise:
noise_list.append(noise_)
# add padding at head
pad_ = (self.pad_head, ) * 4
noise_ = F.pad(noise_, pad_)
g_res_pad = F.pad(g_res, pad_)
noise = noise_ * noise_weights[stage] + g_res_pad
g_res = self.blocks[stage](noise.detach(), g_res)
if get_prev_res and stage != curr_scale:
prev_res_list.append(g_res)
# upsample, here we use interpolation from PyTorch
if stage != curr_scale:
h_next, w_next = fixed_noises[stage + 1].shape[-2:]
g_res = self.upsample(g_res, (h_next, w_next))
if get_prev_res or return_noise:
output_dict = dict(
fake_img=g_res,
prev_res_list=prev_res_list,
noise_batch=noise_list)
return output_dict
return g_res
def check_and_load_prev_weight(self, curr_scale):
if curr_scale == 0:
return
prev_ch = self.blocks[curr_scale - 1].base_channels
curr_ch = self.blocks[curr_scale].base_channels
prev_in_ch = self.blocks[curr_scale - 1].in_channels
curr_in_ch = self.blocks[curr_scale].in_channels
if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
load_state_dict(
self.blocks[curr_scale],
self.blocks[curr_scale - 1].state_dict(),
logger=get_root_logger())
            print_log('Successfully loaded pretrained model from last scale.')
else:
print_log(
'Cannot load pretrained model from last scale since'
f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
@MODULES.register_module()
class SinGANMultiScaleDiscriminator(nn.Module):
"""Multi-Scale Discriminator used in SinGAN.
    More details can be found in: SinGAN: Learning a Generative Model from a
    Single Natural Image, ICCV'19.
Args:
in_channels (int): Input channels.
        num_scales (int): The number of scales/stages in the discriminator.
            Note that this number is counted from zero, which is the same as
            the original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            discriminator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the discriminator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the discriminator block. Defaults to 32.
"""
def __init__(self,
in_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
**kwargs):
super().__init__()
self.blocks = nn.ModuleList()
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
self.blocks.append(
DiscriminatorBlock(
in_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
**kwargs))
def forward(self, x, curr_scale):
"""Forward function.
Args:
x (Tensor): Input feature map.
curr_scale (int): Current scale for discriminator. If in testing,
you need to set it to the last scale.
Returns:
Tensor: Discriminative results.
"""
out = self.blocks[curr_scale](x)
return out
def check_and_load_prev_weight(self, curr_scale):
if curr_scale == 0:
return
prev_ch = self.blocks[curr_scale - 1].base_channels
curr_ch = self.blocks[curr_scale].base_channels
if prev_ch == curr_ch:
self.blocks[curr_scale].load_state_dict(
self.blocks[curr_scale - 1].state_dict())
            print_log('Successfully loaded pretrained model from last scale.')
else:
print_log('Cannot load pretrained model from last scale since'
f' prev_ch({prev_ch}) != curr_ch({curr_ch})')
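# A minimal shape-check sketch (not part of the library; sizes are
# illustrative): run two stages of the multi-scale generator in 'rand'
# mode. `fixed_noises` holds one tensor per scale, with spatial sizes
# growing across scales, and `noise_weights` scales the noise injected at
# each stage. With the default kernel_size=3 and num_layers=5, pad_head is
# 5, so each block consumes exactly the padding added at its head.
if __name__ == '__main__':
    gen = SinGANMultiScaleGenerator(in_channels=3, out_channels=3,
                                    num_scales=1)
    fixed_noises = [torch.randn(1, 3, 25, 25), torch.randn(1, 3, 32, 32)]
    fake = gen(None, fixed_noises, noise_weights=[1.0, 0.5],
               rand_mode='rand', curr_scale=1)
    assert fake.shape == (1, 3, 32, 32)
    disc = SinGANMultiScaleDiscriminator(in_channels=3, num_scales=1)
    pred = disc(fake, curr_scale=1)  # patch-level realness map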
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmgen.utils import get_root_logger
class GeneratorBlock(nn.Module):
"""Generator block used in SinGAN.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        padding (int): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`.
        num_layers (int): The number of convolutional layers in this block.
        base_channels (int): The basic channels for the convolutional layers
            in this block.
        min_feat_channels (int): Minimum channels for the feature maps in
            this block.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
stride (int, optional): Same as :obj:`nn.Conv2d`. Defaults to 1.
allow_no_residual (bool, optional): Whether to allow no residual link
in this block. Defaults to False.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding,
num_layers,
base_channels,
min_feat_channels,
out_act_cfg=dict(type='Tanh'),
stride=1,
allow_no_residual=False,
**kwargs):
super().__init__()
self.in_channels = in_channels
self.base_channels = base_channels
self.kernel_size = kernel_size
self.num_layers = num_layers
self.allow_no_residual = allow_no_residual
self.head = ConvModule(
in_channels=in_channels,
out_channels=base_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
**kwargs)
self.body = nn.Sequential()
for i in range(num_layers - 2):
feat_channels_ = int(base_channels / pow(2, (i + 1)))
block = ConvModule(
max(2 * feat_channels_, min_feat_channels),
max(feat_channels_, min_feat_channels),
kernel_size=kernel_size,
padding=padding,
stride=stride,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
**kwargs)
self.body.add_module(f'block{i+1}', block)
self.tail = ConvModule(
max(feat_channels_, min_feat_channels),
out_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=None,
act_cfg=out_act_cfg,
**kwargs)
self.init_weights()
def forward(self, x, prev):
"""Forward function.
Args:
x (Tensor): Input feature map.
prev (Tensor): Previous feature map.
Returns:
Tensor: Output feature map with the shape of (N, C, H, W).
"""
x = self.head(x)
x = self.body(x)
x = self.tail(x)
# if prev and x are not in the same shape at the channel dimension
if self.allow_no_residual and x.shape[1] != prev.shape[1]:
return x
return x + prev
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, 0, 0.02)
elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None but'
f' got {type(pretrained)} instead.')
class DiscriminatorBlock(nn.Module):
"""Discriminator Block used in SinGAN.
Args:
in_channels (int): Input channels.
base_channels (int): Base channels for this block.
min_feat_channels (int): The minimum channels for feature map.
kernel_size (int): Size of convolutional kernel, same as
:obj:`nn.Conv2d`.
padding (int): Padding for convolutional layer, same as
:obj:`nn.Conv2d`.
num_layers (int): The number of convolutional layers in this block.
norm_cfg (dict | None, optional): Config for the normalization layer.
Defaults to dict(type='BN').
act_cfg (dict | None, optional): Config for the activation layer.
Defaults to dict(type='LeakyReLU', negative_slope=0.2).
stride (int, optional): The stride for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 1.
"""
def __init__(self,
in_channels,
base_channels,
min_feat_channels,
kernel_size,
padding,
num_layers,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
stride=1,
**kwargs):
super().__init__()
self.base_channels = base_channels
self.stride = stride
self.head = ConvModule(
in_channels,
base_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.body = nn.Sequential()
for i in range(num_layers - 2):
feat_channels_ = int(base_channels / pow(2, (i + 1)))
block = ConvModule(
max(2 * feat_channels_, min_feat_channels),
max(feat_channels_, min_feat_channels),
kernel_size=kernel_size,
padding=padding,
stride=stride,
conv_cfg=None,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.body.add_module(f'block{i+1}', block)
self.tail = ConvModule(
max(feat_channels_, min_feat_channels),
1,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=None,
act_cfg=None,
**kwargs)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
x = self.head(x)
x = self.body(x)
x = self.tail(x)
return x
# TODO: study the effects of init functions
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, 0, 0.02)
elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None but'
f' got {type(pretrained)} instead.')
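# A minimal sketch (not part of the library) illustrating the per-layer
# channel schedule: with base_channels=128 and min_feat_channels=32, the
# body halves the channels per layer (128 -> 64 -> 32 -> 32) before the
# tail projects back to `out_channels`. padding=1 keeps the spatial size.
if __name__ == '__main__':
    import torch
    block = GeneratorBlock(
        in_channels=3, out_channels=3, kernel_size=3, padding=1,
        num_layers=5, base_channels=128, min_feat_channels=32)
    x = torch.randn(1, 3, 16, 16)
    out = block(x, prev=torch.zeros(1, 3, 16, 16))
    assert out.shape == (1, 3, 16, 16)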
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation for Positional Encoding as Spatial Inductive Bias in GANs.
In this module, we provide necessary components to conduct experiments
mentioned in the paper: Positional Encoding as Spatial Inductive Bias in GANs.
More details can be found in: https://arxiv.org/pdf/2012.05217.pdf
"""
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.builder import MODULES, build_module
from .generator_discriminator import SinGANMultiScaleGenerator
from .modules import GeneratorBlock
@MODULES.register_module()
class SinGANMSGeneratorPE(SinGANMultiScaleGenerator):
"""Multi-Scale Generator used in SinGAN with positional encoding.
    More details can be found in: Positional Encoding as Spatial Inductive
    Bias in GANs, CVPR'2021.
Notes:
    - In this version, we adopt the interpolation function from the official
    PyTorch APIs, which is different from the original implementation by the
    authors. However, in our experiments, the influence of this change is
    negligible.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
num_scales (int): The number of scales/stages in generator. Note
that this number is counted from zero, which is the same as the
original paper.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
num_layers (int, optional): The number of convolutional layers in each
generator block. Defaults to 5.
base_channels (int, optional): The basic channels for convolutional
layers in the generator block. Defaults to 32.
min_feat_channels (int, optional): Minimum channels for the feature
maps in the generator block. Defaults to 32.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
padding_mode (str, optional): The mode of convolutional padding, same
as :obj:`nn.Conv2d`. Defaults to 'zero'.
pad_at_head (bool, optional): Whether to add padding at head.
Defaults to True.
        interp_pad (bool, optional): Whether to use interpolation for the
            head padding of feature maps. Defaults to False.
noise_with_pad (bool, optional): Whether the input fixed noises are
with explicit padding. Defaults to False.
positional_encoding (dict | None, optional): Configs for the positional
encoding. Defaults to None.
first_stage_in_channels (int | None, optional): The input channel of
the first generator block. If None, the first stage will adopt the
same input channels as other stages. Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
out_act_cfg=dict(type='Tanh'),
padding_mode='zero',
pad_at_head=True,
interp_pad=False,
noise_with_pad=False,
positional_encoding=None,
first_stage_in_channels=None,
**kwargs):
super(SinGANMultiScaleGenerator, self).__init__()
self.pad_at_head = pad_at_head
self.interp_pad = interp_pad
self.noise_with_pad = noise_with_pad
self.with_positional_encode = positional_encoding is not None
if self.with_positional_encode:
self.head_position_encode = build_module(positional_encoding)
self.pad_head = int((kernel_size - 1) / 2 * num_layers)
self.blocks = nn.ModuleList()
self.upsample = partial(
F.interpolate, mode='bicubic', align_corners=True)
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
if scale == 0:
in_ch = (
first_stage_in_channels
if first_stage_in_channels else in_channels)
else:
in_ch = in_channels
self.blocks.append(
GeneratorBlock(
in_channels=in_ch,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
out_act_cfg=out_act_cfg,
padding_mode=padding_mode,
**kwargs))
if padding_mode == 'zero':
self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
elif padding_mode == 'reflect':
self.noise_padding_layer = nn.ReflectionPad2d(self.pad_head)
self.img_padding_layer = nn.ReflectionPad2d(self.pad_head)
self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
mmcv.print_log('Using Reflection padding', 'mmgen')
else:
raise NotImplementedError(
f'Padding mode {padding_mode} is not supported')
def forward(self,
input_sample,
fixed_noises,
noise_weights,
rand_mode,
curr_scale,
num_batches=1,
get_prev_res=False,
return_noise=False):
"""Forward function.
Args:
input_sample (Tensor | None): The input for generator. In the
original implementation, a tensor filled with zeros is adopted.
If None is given, we will construct it from the first fixed
noises.
fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
noise_weights (list[float]): List of the weights for random noises.
rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
it will sample from random noises. Otherwise, the
reconstruction for the single image will be returned.
curr_scale (int): The scale for the current inference or training.
num_batches (int, optional): The number of batches. Defaults to 1.
get_prev_res (bool, optional): Whether to return results from
previous stages. Defaults to False.
return_noise (bool, optional): Whether to return noises tensor.
Defaults to False.
Returns:
Tensor | dict: Generated image tensor or dictionary containing \
more data.
"""
if get_prev_res or return_noise:
prev_res_list = []
noise_list = []
if input_sample is None:
input_sample = torch.zeros(
(num_batches, 3, fixed_noises[0].shape[-2],
fixed_noises[0].shape[-1])).to(fixed_noises[0])
g_res = input_sample
for stage in range(curr_scale + 1):
if rand_mode == 'recon':
noise_ = fixed_noises[stage]
else:
noise_ = torch.randn(num_batches,
*fixed_noises[stage].shape[1:]).to(g_res)
if return_noise:
noise_list.append(noise_)
if self.with_positional_encode and stage == 0:
head_grid = self.head_position_encode(fixed_noises[0])
noise_ = noise_ + head_grid
# add padding at head
if self.pad_at_head:
if self.interp_pad:
if self.noise_with_pad:
size = noise_.shape[-2:]
else:
size = (noise_.size(2) + 2 * self.pad_head,
noise_.size(3) + 2 * self.pad_head)
noise_ = self.upsample(noise_, size)
g_res_pad = self.upsample(g_res, size)
else:
if not self.noise_with_pad:
noise_ = self.noise_padding_layer(noise_)
g_res_pad = self.img_padding_layer(g_res)
else:
g_res_pad = g_res
if stage == 0 and self.with_positional_encode:
noise = noise_ * noise_weights[stage]
else:
noise = noise_ * noise_weights[stage] + g_res_pad
g_res = self.blocks[stage](noise.detach(), g_res)
if get_prev_res and stage != curr_scale:
prev_res_list.append(g_res)
# upsample, here we use interpolation from PyTorch
if stage != curr_scale:
h_next, w_next = fixed_noises[stage + 1].shape[-2:]
if self.noise_with_pad:
# remove the additional padding if noise with pad
h_next -= 2 * self.pad_head
w_next -= 2 * self.pad_head
g_res = self.upsample(g_res, (h_next, w_next))
if get_prev_res or return_noise:
output_dict = dict(
fake_img=g_res,
prev_res_list=prev_res_list,
noise_batch=noise_list)
return output_dict
return g_res
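# A minimal sketch (not part of the library; the sizes follow the CSG setup
# used in the paper's configs and are an assumption here): with a Cartesian
# grid ('CSG') as head positional encoding, the first-stage noise has 2
# channels to match the 2-channel grid, while later stages use ordinary
# 3-channel noise.
if __name__ == '__main__':
    gen = SinGANMSGeneratorPE(
        in_channels=3, out_channels=3, num_scales=1,
        positional_encoding=dict(type='CSG'),
        first_stage_in_channels=2)
    fixed_noises = [torch.randn(1, 2, 25, 25), torch.randn(1, 3, 32, 32)]
    fake = gen(None, fixed_noises, noise_weights=[1.0, 0.5],
               rand_mode='rand', curr_scale=1)
    assert fake.shape == (1, 3, 32, 32)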
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import ProjDiscriminator, SNGANGenerator
from .modules import SNGANDiscHeadResBlock, SNGANDiscResBlock, SNGANGenResBlock
__all__ = [
'ProjDiscriminator', 'SNGANGenerator', 'SNGANGenResBlock',
'SNGANDiscResBlock', 'SNGANDiscHeadResBlock'
]
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
xavier_init)
from mmcv.runner import load_checkpoint
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmcv.utils import is_list_of
from torch.nn.init import xavier_uniform_
from torch.nn.utils import spectral_norm
from mmgen.models.builder import MODULES, build_module
from mmgen.utils import check_dist_init
from mmgen.utils.logger import get_root_logger
from ..common import get_module_device
@MODULES.register_module('SAGANGenerator')
@MODULES.register_module()
class SNGANGenerator(nn.Module):
r"""Generator for SNGAN / Proj-GAN. The implementation refers to
https://github.com/pfnet-research/sngan_projection/tree/master/gen_models
    In our implementation, we have two notable designs, namely
    ``channels_cfg`` and ``blocks_cfg``.
    ``channels_cfg``: In the default config of SNGAN / Proj-GAN, the number
        of ResBlocks and the channels of those blocks correspond to the
        resolution of the output image. Therefore, we allow users to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.
    ``blocks_cfg``: In the reference code, the generator consists of a group
        of ResBlocks. However, in our implementation, to make this model
        more general, we support defining ``blocks_cfg`` by users and
        loading the blocks by calling the ``build_module`` method.
Args:
output_scale (int): Output scale for the generated image.
        num_classes (int, optional): The number of classes you would like to
            generate. This argument influences the structure of the
            intermediate blocks and the label sampling operation in
            ``forward`` (e.g. if ``num_classes=0``, ConditionalNormalization
            layers degrade to unconditional ones). This argument is passed
            to the intermediate blocks by overwriting their config.
            Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to 64.
out_channels (int, optional): Channels of the output images.
Default to 3.
input_scale (int, optional): Input scale for the features.
Defaults to 4.
noise_size (int, optional): Size of the input noise vector.
Default to 128.
attention_cfg (dict, optional): Config for the self-attention block.
Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Add the
            self-attention block after which *ConvBlock*. If an ``int`` is
            passed, only one attention block will be added. If a ``list`` is
            passed, self-attention blocks will be added after multiple
            ConvBlocks. Note that if the input is smaller than ``1``, the
            self-attention block corresponding to this index will be
            ignored. Defaults to 0.
        channels_cfg (list | dict[list], optional): Config for the input
            channels of the intermediate blocks. If a list is passed, each
            element is the factor, relative to ``base_channels``, of the
            input channels of the corresponding block. For block ``i``, the
            input and output channel factors are ``channels_cfg[i]`` and
            ``channels_cfg[i+1]``. If a dict is provided, its keys should be
            output scales and each value should be a list defining the
            channels. Default: please refer to ``_default_channels_cfg``.
        blocks_cfg (dict, optional): Config for the intermediate blocks.
            Defaults to ``dict(type='SNGANGenResBlock')``.
act_cfg (dict, optional): Activation config for the final output
layer. Defaults to ``dict(type='ReLU')``.
        use_cbn (bool, optional): Whether to use conditional normalization.
            This argument is passed to the norm layers. Defaults to True.
        auto_sync_bn (bool, optional): Whether to convert BatchNorm to
            SyncBatchNorm when distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether to use spectral norm
            for the conv blocks. Defaults to False.
        with_embedding_spectral_norm (bool, optional): Whether to use
            spectral norm for the embedding layers in normalization blocks.
            If not specified (set as ``None``), it is set to the same value
            as ``with_spectral_norm``. Defaults to None.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
norm_eps (float, optional): eps for Normalization layers (both
conditional and non-conditional ones). Default to `1e-4`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to
            load the generator part from the whole state dict. Defaults to
            None.
"""
# default channel factors
_default_channels_cfg = {
32: [1, 1, 1],
64: [16, 8, 4, 2],
128: [16, 16, 8, 4, 2]
}
def __init__(self,
output_scale,
num_classes=0,
base_channels=64,
out_channels=3,
input_scale=4,
noise_size=128,
attention_cfg=dict(type='SelfAttentionBlock'),
attention_after_nth_block=0,
channels_cfg=None,
blocks_cfg=dict(type='SNGANGenResBlock'),
act_cfg=dict(type='ReLU'),
use_cbn=True,
auto_sync_bn=True,
with_spectral_norm=False,
with_embedding_spectral_norm=None,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN'),
pretrained=None):
super().__init__()
self.input_scale = input_scale
self.output_scale = output_scale
self.noise_size = noise_size
self.num_classes = num_classes
self.init_type = init_cfg.get('type', None)
self.blocks_cfg = deepcopy(blocks_cfg)
self.blocks_cfg.setdefault('num_classes', num_classes)
self.blocks_cfg.setdefault('act_cfg', act_cfg)
self.blocks_cfg.setdefault('use_cbn', use_cbn)
self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
# set `norm_spectral_norm` as `with_spectral_norm` if not defined
with_embedding_spectral_norm = with_embedding_spectral_norm \
if with_embedding_spectral_norm is not None else with_spectral_norm
self.blocks_cfg.setdefault('with_embedding_spectral_norm',
with_embedding_spectral_norm)
self.blocks_cfg.setdefault('init_cfg', init_cfg)
self.blocks_cfg.setdefault('sn_style', sn_style)
self.blocks_cfg.setdefault('norm_eps', norm_eps)
self.blocks_cfg.setdefault('sn_eps', sn_eps)
channels_cfg = deepcopy(self._default_channels_cfg) \
if channels_cfg is None else deepcopy(channels_cfg)
if isinstance(channels_cfg, dict):
if output_scale not in channels_cfg:
                raise KeyError(f'`output_scale={output_scale}` is not found '
                               'in `channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
self.channel_factor_list = channels_cfg[output_scale]
elif isinstance(channels_cfg, list):
self.channel_factor_list = channels_cfg
else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'but received {type(channels_cfg)}')
self.noise2feat = nn.Linear(
noise_size,
input_scale**2 * base_channels * self.channel_factor_list[0])
if with_spectral_norm:
self.noise2feat = spectral_norm(self.noise2feat)
# check `attention_after_nth_block`
if not isinstance(attention_after_nth_block, list):
attention_after_nth_block = [attention_after_nth_block]
if not is_list_of(attention_after_nth_block, int):
raise ValueError('`attention_after_nth_block` only support int or '
'a list of int. Please check your input type.')
self.conv_blocks = nn.ModuleList()
self.attention_block_idx = []
for idx in range(len(self.channel_factor_list)):
factor_input = self.channel_factor_list[idx]
factor_output = self.channel_factor_list[idx+1] \
if idx < len(self.channel_factor_list)-1 else 1
# get block-specific config
block_cfg_ = deepcopy(self.blocks_cfg)
block_cfg_['in_channels'] = factor_input * base_channels
block_cfg_['out_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(block_cfg_))
            # build self-attention block
            # `idx` starts from 0; add 1 to get the index
if idx + 1 in attention_after_nth_block:
self.attention_block_idx.append(len(self.conv_blocks))
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = factor_output * base_channels
attn_cfg_['sn_style'] = sn_style
self.conv_blocks.append(build_module(attn_cfg_))
to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
if check_dist_init() and auto_sync_bn:
to_rgb_norm_cfg['type'] = 'SyncBN'
self.to_rgb = ConvModule(
factor_output * base_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
norm_cfg=to_rgb_norm_cfg,
act_cfg=act_cfg,
order=('norm', 'act', 'conv'),
with_spectral_norm=with_spectral_norm)
self.final_act = build_activation_layer(dict(type='Tanh'))
self.init_weights(pretrained)
def forward(self, noise, num_batches=0, label=None, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
label (torch.Tensor | callable | None): You can directly give a
batch of label through a ``torch.Tensor`` or offer a callable
function to sample a batch of label data. Otherwise, the
``None`` indicates to use the default label sampler.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img``, ``noise_batch`` and ``label`` will be
                returned.
"""
if isinstance(noise, torch.Tensor):
            assert noise.ndim == 2, ('The noise should be in shape of (n, c),'
                                     f' but got {noise.shape}')
            assert noise.shape[1] == self.noise_size
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
if isinstance(label, torch.Tensor):
            assert label.ndim == 1, ('The label should be in shape of (n, ), '
                                     f'but got {label.shape}.')
label_batch = label
elif callable(label):
label_generator = label
assert num_batches > 0
label_batch = label_generator(num_batches)
elif self.num_classes == 0:
label_batch = None
else:
assert num_batches > 0
label_batch = torch.randint(0, self.num_classes, (num_batches, ))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
if label_batch is not None:
label_batch = label_batch.to(get_module_device(self))
x = self.noise2feat(noise_batch)
x = x.reshape(x.size(0), -1, self.input_scale, self.input_scale)
for idx, conv_block in enumerate(self.conv_blocks):
if idx in self.attention_block_idx:
x = conv_block(x)
else:
x = conv_block(x, label_batch)
out_feat = self.to_rgb(x)
out_img = self.final_act(out_feat)
if return_noise:
return dict(
fake_img=out_img, noise_batch=noise_batch, label=label_batch)
return out_img
def init_weights(self, pretrained=None, strict=True):
"""Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None``,
weight initialization would follow the ``INIT_TYPE`` in
``init_cfg=dict(type=INIT_TYPE)``.
        For SNGAN-Proj
(``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
we follow the initialization method in the official Chainer's
implementation (https://github.com/pfnet-research/sngan_projection).
For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
initialization method in official tensorflow's implementation
(https://github.com/brain-research/self-attention-gan).
        Besides the reimplementation of the official code's initialization,
        we provide BigGAN's and PyTorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == 'BIGGAN'`` and
        ``INIT_TYPE.upper() == 'STUDIO'``).
Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.
Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
                necessary key is 'ckpt_path'. Besides, you can also provide
                'prefix' to load the generator part from the whole state
                dict. Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif isinstance(pretrained, dict):
ckpt_path = pretrained.get('ckpt_path', None)
assert ckpt_path is not None
prefix = pretrained.get('prefix', '')
map_location = pretrained.get('map_location', 'cpu')
strict = pretrained.get('strict', True)
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
elif pretrained is None:
            if self.init_type.upper() == 'STUDIO':
# initialization method from Pytorch-StudioGAN
# * weight: orthogonal_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
nn.init.orthogonal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
# initialization method from BigGAN-pytorch
# * weight: xavier_init gain=1
# * bias : default
                for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_uniform_(m.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
# initialization method from official tensorflow code
# * weight : xavier_init gain=1
# * bias : 0
# * weight_embedding: 1
# * bias_embedding : 0
for n, m in self.named_modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, gain=1, distribution='uniform')
if isinstance(m, nn.Embedding):
# To be noted that here we initialize the embedding
# layer in cBN with specific prefix. If you implement
# your own cBN and want to use this initialization
# method, please make sure the embedding layers in
# your implementation have the same prefix as ours.
if 'weight' in n:
constant_init(m, 1)
if 'bias' in n:
constant_init(m, 0)
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
# initialization method from the official chainer code
# * conv.weight : xavier_init gain=sqrt(2)
# * shortcut.weight : xavier_init gain=1
# * bias : 0
# * weight_embedding: 1
# * bias_embedding : 0
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'shortcut' in n or 'to_rgb' in n:
xavier_init(m, gain=1, distribution='uniform')
else:
xavier_init(
m, gain=np.sqrt(2), distribution='uniform')
if isinstance(m, nn.Linear):
xavier_init(m, gain=1, distribution='uniform')
if isinstance(m, nn.Embedding):
# To be noted that here we initialize the embedding
# layer in cBN with specific prefix. If you implement
# your own cBN and want to use this initialization
# method, please make sure the embedding layers in
# your implementation have the same prefix as ours.
if 'weight' in n:
constant_init(m, 1)
if 'bias' in n:
constant_init(m, 0)
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
else:
raise TypeError("'pretrined' must be a str or None. "
f'But receive {type(pretrained)}.')
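# A minimal sampling sketch (not part of the library; values are
# illustrative and assume the registered default `SNGANGenResBlock`): an
# unconditional 32x32 generator built from the default channel config.
# Sampling starts from a 4x4 feature map and doubles the resolution in
# each of the three ResBlocks (4 -> 8 -> 16 -> 32).
if __name__ == '__main__':
    gen = SNGANGenerator(output_scale=32, num_classes=0)
    out_img = gen(None, num_batches=4)
    assert out_img.shape == (4, 3, 32, 32)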
@MODULES.register_module('SAGANDiscriminator')
@MODULES.register_module()
class ProjDiscriminator(nn.Module):
r"""Discriminator for SNGAN / Proj-GAN. The implementation is refer to
https://github.com/pfnet-research/sngan_projection/tree/master/dis_models
The overall structure of the projection discriminator can be split into a
``from_rgb`` layer, a group of ResBlocks, a linear decision layer, and a
projection layer. To support defining custom layers, we introduce
``from_rgb_cfg`` and ``blocks_cfg``.
    The design of the model structure closely corresponds to the output
    resolution. Therefore, we provide ``channels_cfg`` and
    ``downsample_cfg`` to control the input channels and the downsample
    behavior of the intermediate blocks.
    ``downsample_cfg``: In the default config of SNGAN / Proj-GAN, whether to
        apply downsampling in each intermediate block is quite flexible and
        corresponds to the resolution of the output image. Therefore, we
        allow users to define ``downsample_cfg`` by themselves to control
        the structure of the discriminator.
    ``channels_cfg``: In the default config of SNGAN / Proj-GAN, the number
        of ResBlocks and the channels of those blocks correspond to the
        resolution of the output image. Therefore, we allow users to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.
Args:
input_scale (int): The scale of the input image.
        num_classes (int, optional): The number of classes you would like to
            generate. If ``num_classes=0``, no label projection will be
            used. Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contain channels based on this
            number. Defaults to 128.
input_channels (int, optional): Channels of the input image.
Defaults to 3.
attention_cfg (dict, optional): Config for the self-attention block.
Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Add the
            self-attention block after which *ConvBlock* (including the head
            block). If an ``int`` is passed, only one attention block will
            be added. If a ``list`` is passed, self-attention blocks will be
            added after multiple ConvBlocks. Note that if the input is
            smaller than ``1``, the self-attention block corresponding to
            this index will be ignored. Defaults to -1.
        channels_cfg (list | dict[list], optional): Config for the input
            channels of the intermediate blocks. If a list is passed, each
            element is the factor, relative to ``base_channels``, of the
            input channels of the corresponding block. For block ``i``, the
            input and output channel factors are ``channels_cfg[i]`` and
            ``channels_cfg[i+1]``. If a dict is provided, its keys should be
            input scales and each value should be a list defining the
            channels. Default: please refer to ``_default_channels_cfg``.
        downsample_cfg (list[bool] | dict[list], optional): Config for the
            downsample behavior of the intermediate layers. If a list is
            passed, ``downsample_cfg[idx] == True`` means downsampling is
            applied in the idx-th block, and vice versa. If a dict is
            provided, its keys should be input scales of the image and each
            value should be a list defining the downsample behavior.
            Default: please refer to ``_default_downsample_cfg``.
from_rgb_cfg (dict, optional): Config for the first layer to convert
rgb image to feature map. Defaults to
``dict(type='SNGANDiscHeadResBlock')``.
blocks_cfg (dict, optional): Config for the intermedia blocks.
Defaults to ``dict(type='SNGANDiscResBlock')``
act_cfg (dict, optional): Activation config for the final output
layer. Defaults to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether to use spectral norm
            for all conv blocks. Defaults to True.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Default to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to
            load the discriminator part from the whole state dict. Defaults
            to None.
"""
# default channel factors
    _default_channels_cfg = {
32: [1, 1, 1],
64: [2, 4, 8, 16],
128: [2, 4, 8, 16, 16],
}
# default downsample behavior
    _default_downsample_cfg = {
32: [True, False, False],
64: [True, True, True, True],
128: [True, True, True, True, False]
}
def __init__(self,
input_scale,
num_classes=0,
base_channels=128,
input_channels=3,
attention_cfg=dict(type='SelfAttentionBlock'),
attention_after_nth_block=-1,
channels_cfg=None,
downsample_cfg=None,
from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
blocks_cfg=dict(type='SNGANDiscResBlock'),
act_cfg=dict(type='ReLU'),
with_spectral_norm=True,
sn_style='torch',
sn_eps=1e-12,
init_cfg=dict(type='BigGAN'),
pretrained=None):
super().__init__()
self.init_type = init_cfg.get('type', None)
# add SN options and activation function options to cfg
self.from_rgb_cfg = deepcopy(from_rgb_cfg)
self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
self.from_rgb_cfg.setdefault('sn_style', sn_style)
self.from_rgb_cfg.setdefault('init_cfg', init_cfg)
# add SN options and activation function options to cfg
self.blocks_cfg = deepcopy(blocks_cfg)
self.blocks_cfg.setdefault('act_cfg', act_cfg)
self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
self.blocks_cfg.setdefault('sn_style', sn_style)
self.blocks_cfg.setdefault('sn_eps', sn_eps)
self.blocks_cfg.setdefault('init_cfg', init_cfg)
        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
if isinstance(channels_cfg, dict):
if input_scale not in channels_cfg:
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
self.channel_factor_list = channels_cfg[input_scale]
elif isinstance(channels_cfg, list):
self.channel_factor_list = channels_cfg
else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'but received {type(channels_cfg)}')
        downsample_cfg = deepcopy(self._default_downsample_cfg) \
            if downsample_cfg is None else deepcopy(downsample_cfg)
if isinstance(downsample_cfg, dict):
if input_scale not in downsample_cfg:
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`downsample_cfg`, only support configs for '
                               f'{list(downsample_cfg.keys())}')
self.downsample_list = downsample_cfg[input_scale]
elif isinstance(downsample_cfg, list):
self.downsample_list = downsample_cfg
else:
            raise ValueError('Only support list or dict for `downsample_cfg`,'
                             f' but received {type(downsample_cfg)}')
if len(self.downsample_list) != len(self.channel_factor_list):
            raise ValueError('`downsample_cfg` should have the same length '
                             'as `channels_cfg`, but received '
                             f'{len(self.downsample_list)} and '
                             f'{len(self.channel_factor_list)}.')
# check `attention_after_nth_block`
if not isinstance(attention_after_nth_block, list):
attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
            raise ValueError('`attention_after_nth_block` only support int '
                             'or a list of int. Please check your input '
                             'type.')
self.from_rgb = build_module(
self.from_rgb_cfg,
dict(in_channels=input_channels, out_channels=base_channels))
self.conv_blocks = nn.ModuleList()
# add self-attention block after the first block
if 1 in attention_after_nth_block:
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = base_channels
attn_cfg_['sn_style'] = sn_style
self.conv_blocks.append(build_module(attn_cfg_))
for idx in range(len(self.downsample_list)):
factor_input = 1 if idx == 0 else self.channel_factor_list[idx - 1]
factor_output = self.channel_factor_list[idx]
# get block-specific config
block_cfg_ = deepcopy(self.blocks_cfg)
block_cfg_['downsample'] = self.downsample_list[idx]
block_cfg_['in_channels'] = factor_input * base_channels
block_cfg_['out_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(block_cfg_))
# build self-attention block
# the first ConvBlock is `from_rgb` block,
# add 2 to get the index of the ConvBlocks
if idx + 2 in attention_after_nth_block:
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(attn_cfg_))
self.decision = nn.Linear(factor_output * base_channels, 1)
if with_spectral_norm:
self.decision = spectral_norm(self.decision)
self.num_classes = num_classes
# In this case, discriminator is designed for conditional synthesis.
if num_classes > 0:
self.proj_y = nn.Embedding(num_classes,
factor_output * base_channels)
if with_spectral_norm:
self.proj_y = spectral_norm(self.proj_y)
self.activate = build_activation_layer(act_cfg)
self.init_weights(pretrained)
def forward(self, x, label=None):
"""Forward function. If `self.num_classes` is larger than 0, label
projection would be used.
Args:
x (torch.Tensor): Fake or real image tensor.
label (torch.Tensor, options): Label correspond to the input image.
Noted that, if `self.num_classed` is larger than 0,
`label` should not be None. Default to None.
Returns:
torch.Tensor: Prediction for the reality of the input image.
"""
h = self.from_rgb(x)
for conv_block in self.conv_blocks:
h = conv_block(h)
h = self.activate(h)
h = torch.sum(h, dim=[2, 3])
out = self.decision(h)
if self.num_classes > 0:
w_y = self.proj_y(label)
out = out + torch.sum(w_y * h, dim=1, keepdim=True)
return out.view(out.size(0), -1)
def init_weights(self, pretrained=None, strict=True):
"""Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None`` and
weight initialization would follow the ``INIT_TYPE`` in
``init_cfg=dict(type=INIT_TYPE)``.
For SNGAN-Proj
(``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
we follow the initialization method in the official Chainer's
implementation (https://github.com/pfnet-research/sngan_projection).
For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
initialization method in official tensorflow's implementation
(https://github.com/brain-research/self-attention-gan).
        Besides the reimplementation of the official code's initialization, we
        provide BigGAN's and PyTorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == 'BIGGAN'`` and
        ``INIT_TYPE.upper() == 'STUDIO'``).
Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.
Args:
pretrained (str | dict, optional): Path for the pretrained model or
                dict containing information for pretrained models whose
necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif isinstance(pretrained, dict):
ckpt_path = pretrained.get('ckpt_path', None)
assert ckpt_path is not None
prefix = pretrained.get('prefix', '')
map_location = pretrained.get('map_location', 'cpu')
strict = pretrained.get('strict', True)
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
elif pretrained is None:
if self.init_type.upper() == 'STUDIO':
# initialization method from Pytorch-StudioGAN
# * weight: orthogonal_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
nn.init.orthogonal_(m.weight, gain=1)
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
# initialization method from BigGAN-pytorch
# * weight: xavier_init gain=1
# * bias : default
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_uniform_(m.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
# initialization method from official tensorflow code
# * weight: xavier_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_init(m, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
# initialization method from the official chainer code
# * embedding.weight: xavier_init gain=1
# * conv.weight : xavier_init gain=sqrt(2)
# * shortcut.weight : xavier_init gain=1
# * bias : 0
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'shortcut' in n:
xavier_init(m, gain=1, distribution='uniform')
else:
xavier_init(
m, gain=np.sqrt(2), distribution='uniform')
if isinstance(m, (nn.Linear, nn.Embedding)):
xavier_init(m, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
else:
raise TypeError("'pretrained' must by a str or None. "
f'But receive {type(pretrained)}.')
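# A minimal usage sketch of the `pretrained` argument accepted by
# `init_weights` above (illustrative only: `disc` stands for an instance of
# the discriminator defined above, and the checkpoint path and prefix are
# placeholders, not files shipped with this repo):
#
#   >>> disc.init_weights(pretrained=dict(
#   ...     ckpt_path='work_dirs/sagan/latest.pth',
#   ...     prefix='discriminator',
#   ...     map_location='cpu',
#   ...     strict=False))
#
# With `pretrained=None`, the branch keyed on `init_cfg=dict(type=...)` runs
# instead; e.g. `dict(type='SAGAN')` reproduces the xavier-uniform init of
# the official tensorflow code.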
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_norm_layer,
build_upsample_layer, constant_init, xavier_init)
from torch.nn.init import xavier_uniform_
from torch.nn.utils import spectral_norm
from mmgen.models.architectures.biggan.biggan_snmodule import SNEmbedding
from mmgen.models.architectures.biggan.modules import SNConvModule
from mmgen.models.builder import MODULES
from mmgen.utils import check_dist_init
@MODULES.register_module()
class SNGANGenResBlock(nn.Module):
"""ResBlock used in Generator of SNGAN / Proj-GAN.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
hidden_channels (int, optional): Input channels of the second Conv
layer of the block. If ``None`` is given, would be set as
``out_channels``. Default to None.
        num_classes (int, optional): Number of classes to generate. This
            argument is passed to the norm layers and influences the
            structure and behavior of the normalization process. Default
            to 0.
        use_cbn (bool, optional): Whether to use conditional normalization.
            This argument is passed to the norm layers. Default to True.
        use_norm_affine (bool, optional): Whether to use learnable affine
            parameters in the norm operation when cbn is off. Default to
            False.
        act_cfg (dict, optional): Config for the activation function. Default
            to ``dict(type='ReLU')``.
        norm_cfg (dict, optional): Config for the normalization layers.
            Default to ``dict(type='BN')``.
        upsample_cfg (dict, optional): Config for the upsample method.
            Default to ``dict(type='nearest', scale_factor=2)``.
        upsample (bool, optional): Whether to apply the upsample operation in
            this module. Default to True.
        auto_sync_bn (bool, optional): Whether to convert Batch Norm to
            Synchronized ones when distributed training is on. Default to
            True.
        conv_cfg (dict | None): Config for conv blocks of this module. If
            ``None`` is passed, ``_default_conv_cfg`` will be used. Default
            to ``None``.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            conv blocks and norm layers. Default to False.
        with_embedding_spectral_norm (bool, optional): Whether to use spectral
            norm for the embedding layers in normalization blocks. If not
            specified (set to ``None``), it takes the same value as
            ``with_spectral_norm``. Default to None.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
norm_eps (float, optional): eps for Normalization layers (both
conditional and non-conditional ones). Default to `1e-4`.
sn_eps (float, optional): eps for spectral normalization operation.
Default to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Default to ``dict(type='BigGAN')``.
"""
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
num_classes=0,
use_cbn=True,
use_norm_affine=False,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
upsample_cfg=dict(type='nearest', scale_factor=2),
upsample=True,
auto_sync_bn=True,
conv_cfg=None,
with_spectral_norm=False,
with_embedding_spectral_norm=None,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
self.learnable_sc = in_channels != out_channels or upsample
self.with_upsample = upsample
self.init_type = init_cfg.get('type', None)
self.activate = build_activation_layer(act_cfg)
hidden_channels = out_channels if hidden_channels is None \
else hidden_channels
if self.with_upsample:
self.upsample = build_upsample_layer(upsample_cfg)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
        # set `with_embedding_spectral_norm` as `with_spectral_norm` if not
        # defined
with_embedding_spectral_norm = with_embedding_spectral_norm \
if with_embedding_spectral_norm is not None else with_spectral_norm
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
hidden_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
hidden_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.norm_1 = SNConditionNorm(in_channels, num_classes, use_cbn,
norm_cfg, use_norm_affine, auto_sync_bn,
with_embedding_spectral_norm, sn_style,
norm_eps, sn_eps, init_cfg)
self.norm_2 = SNConditionNorm(hidden_channels, num_classes, use_cbn,
norm_cfg, use_norm_affine, auto_sync_bn,
with_embedding_spectral_norm, sn_style,
norm_eps, sn_eps, init_cfg)
if self.learnable_sc:
# use hyperparameters-fixed shortcut here
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x, y=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
y (Tensor): Input label with shape (n, ).
Default None.
Returns:
Tensor: Forward results.
"""
out = self.norm_1(x, y)
out = self.activate(out)
if self.with_upsample:
out = self.upsample(out)
out = self.conv_1(out)
out = self.norm_2(out, y)
out = self.activate(out)
out = self.conv_2(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = x
if self.learnable_sc:
if self.with_upsample:
out = self.upsample(out)
out = self.shortcut(out)
return out
def init_weights(self):
"""Initialize weights for the model."""
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.conv_1.conv.weight)
nn.init.orthogonal_(self.conv_2.conv.weight)
self.conv_1.conv.bias.data.fill_(0.)
self.conv_2.conv.bias.data.fill_(0.)
if self.learnable_sc:
nn.init.orthogonal_(self.shortcut.conv.weight)
self.shortcut.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
if self.learnable_sc:
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
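# A minimal usage sketch of ``SNGANGenResBlock`` (illustrative; the channel
# sizes and class count below are assumptions, not values from any config):
#
#   >>> import torch
#   >>> block = SNGANGenResBlock(256, 128, num_classes=10,
#   ...                          with_spectral_norm=True)
#   >>> x = torch.randn(4, 256, 8, 8)
#   >>> y = torch.randint(0, 10, (4, ))
#   >>> out = block(x, y)  # upsample=True doubles the spatial size
#   >>> out.shape
#   torch.Size([4, 128, 16, 16])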
@MODULES.register_module()
class SNGANDiscResBlock(nn.Module):
"""resblock used in discriminator of sngan / proj-gan.
args:
in_channels (int): input channels.
out_channels (int): output channels.
hidden_channels (int, optional): input channels of the second conv
layer of the block. if ``none`` is given, would be set as
``out_channels``. Defaults to none.
downsample (bool, optional): whether apply downsample operation in this
module. Defaults to false.
act_cfg (dict, optional): config for activate function. default
to ``dict(type='relu')``.
conv_cfg (dict | none): config for conv blocks of this module. if pass
``none``, would use ``_default_conv_cfg``. default to ``none``.
with_spectral_norm (bool, optional): whether use spectral norm for
conv blocks and norm layers. Defaults to true.
sn_eps (float, optional): eps for spectral normalization operation.
Default to `1e-12`.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
init_cfg (dict, optional): Config for weight initialization.
Defaults to ``dict(type='BigGAN')``.
"""
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
downsample=False,
act_cfg=dict(type='ReLU'),
conv_cfg=None,
with_spectral_norm=True,
sn_style='torch',
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
hidden_channels = out_channels if hidden_channels is None \
else hidden_channels
self.with_downsample = downsample
self.init_type = init_cfg.get('type', None)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
self.activate = build_activation_layer(act_cfg)
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
hidden_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
hidden_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
if self.with_downsample:
self.downsample = nn.AvgPool2d(2, 2)
self.learnable_sc = in_channels != out_channels or downsample
if self.learnable_sc:
# use hyperparameters-fixed shortcut here
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.activate(x)
out = self.conv_1(out)
out = self.activate(out)
out = self.conv_2(out)
if self.with_downsample:
out = self.downsample(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = x
if self.learnable_sc:
out = self.shortcut(out)
if self.with_downsample:
out = self.downsample(out)
return out
def init_weights(self):
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.conv_1.conv.weight)
nn.init.orthogonal_(self.conv_2.conv.weight)
self.conv_1.conv.bias.data.fill_(0.)
self.conv_2.conv.bias.data.fill_(0.)
if self.learnable_sc:
nn.init.orthogonal_(self.shortcut.conv.weight)
self.shortcut.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
if self.learnable_sc:
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
@MODULES.register_module()
class SNGANDiscHeadResBlock(nn.Module):
"""The first ResBlock used in discriminator of sngan / proj-gan. Compared
to ``SNGANDisResBlock``, this module has a different forward order.
args:
in_channels (int): Input channels.
out_channels (int): Output channels.
downsample (bool, optional): whether apply downsample operation in this
module. default to false.
conv_cfg (dict | none): config for conv blocks of this module. if pass
``none``, would use ``_default_conv_cfg``. default to ``none``.
act_cfg (dict, optional): config for activate function. default
to ``dict(type='relu')``.
with_spectral_norm (bool, optional): whether use spectral norm for
conv blocks and norm layers. default to true.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Default to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Default to ``dict(type='BigGAN')``.
"""
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
act_cfg=dict(type='ReLU'),
with_spectral_norm=True,
sn_eps=1e-12,
sn_style='torch',
init_cfg=dict(type='BigGAN')):
super().__init__()
self.init_type = init_cfg.get('type', None)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
self.activate = build_activation_layer(act_cfg)
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
out_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.downsample = nn.AvgPool2d(2, 2)
# use hyperparameters-fixed shortcut here
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.conv_1(x)
out = self.activate(out)
out = self.conv_2(out)
out = self.downsample(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = self.downsample(x)
out = self.shortcut(out)
return out
def init_weights(self):
if self.init_type.upper() == 'STUDIO':
for m in [self.conv_1, self.conv_2, self.shortcut]:
nn.init.orthogonal_(m.conv.weight)
m.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
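# A minimal usage sketch chaining the two discriminator blocks above
# (illustrative; the shapes are assumptions):
#
#   >>> import torch
#   >>> head = SNGANDiscHeadResBlock(3, 64)  # always downsamples by 2
#   >>> block = SNGANDiscResBlock(64, 128, downsample=True)
#   >>> x = torch.randn(4, 3, 32, 32)
#   >>> h = block(head(x))
#   >>> h.shape
#   torch.Size([4, 128, 8, 8])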
@MODULES.register_module()
class SNConditionNorm(nn.Module):
"""Conditional Normalization for SNGAN / Proj-GAN. The implementation
refers to.
https://github.com/pfnet-research/sngan_projection/blob/master/source/links/conditional_batch_normalization.py # noda
and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/utils/model_ops.py # noqa
Args:
in_channels (int): Number of the channels of the input feature map.
        num_classes (int): Number of classes in the dataset. If ``use_cbn``
            is True, ``num_classes`` must be larger than 0.
        use_cbn (bool, optional): Whether to use conditional normalization.
            If ``use_cbn`` is True, two embedding layers would be used to
            map the label to the weight and bias used in the normalization
            process.
        norm_cfg (dict, optional): Config for the normalization method.
            Defaults to ``dict(type='BN')``.
        cbn_norm_affine (bool): Whether to set ``affine=True`` when using
            conditional batch norm. This argument only works when
            ``use_cbn`` is True. Defaults to False.
        auto_sync_bn (bool, optional): Whether to convert Batch Norm to
            Synchronized ones when distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether to use spectral norm
            for conv blocks and norm layers. Defaults to False.
norm_eps (float, optional): eps for Normalization layers (both
conditional and non-conditional ones). Defaults to `1e-4`.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Defaults to ``dict(type='BigGAN')``.
"""
def __init__(self,
in_channels,
num_classes,
use_cbn=True,
norm_cfg=dict(type='BN'),
cbn_norm_affine=False,
auto_sync_bn=True,
with_spectral_norm=False,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
self.use_cbn = use_cbn
self.init_type = init_cfg.get('type', None)
norm_cfg = deepcopy(norm_cfg)
norm_type = norm_cfg['type']
if norm_type not in ['IN', 'BN', 'SyncBN']:
            raise ValueError('Only support `IN` (InstanceNorm), '
                             '`BN` (BatchNorm) and `SyncBN` for '
                             'class-conditional normalization. '
                             f'Receive norm_type: {norm_type}')
if self.use_cbn:
norm_cfg.setdefault('affine', cbn_norm_affine)
norm_cfg.setdefault('eps', norm_eps)
if check_dist_init() and auto_sync_bn and norm_type == 'BN':
norm_cfg['type'] = 'SyncBN'
_, self.norm = build_norm_layer(norm_cfg, in_channels)
if self.use_cbn:
if num_classes <= 0:
raise ValueError('`num_classes` must be larger '
'than 0 with `use_cbn=True`')
self.reweight_embedding = (
self.init_type.upper() == 'BIGGAN'
or self.init_type.upper() == 'STUDIO')
if with_spectral_norm:
if sn_style == 'torch':
self.weight_embedding = spectral_norm(
nn.Embedding(num_classes, in_channels), eps=sn_eps)
self.bias_embedding = spectral_norm(
nn.Embedding(num_classes, in_channels), eps=sn_eps)
elif sn_style == 'ajbrock':
self.weight_embedding = SNEmbedding(
num_classes, in_channels, eps=sn_eps)
self.bias_embedding = SNEmbedding(
num_classes, in_channels, eps=sn_eps)
else:
raise NotImplementedError(
f'{sn_style} style spectral Norm is not '
'supported yet')
else:
self.weight_embedding = nn.Embedding(num_classes, in_channels)
self.bias_embedding = nn.Embedding(num_classes, in_channels)
self.init_weights()
def forward(self, x, y=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
y (Tensor, optional): Input label with shape (n, ).
Default None.
Returns:
Tensor: Forward results.
"""
out = self.norm(x)
if self.use_cbn:
weight = self.weight_embedding(y)[:, :, None, None]
bias = self.bias_embedding(y)[:, :, None, None]
if self.reweight_embedding:
weight = weight + 1.
out = out * weight + bias
return out
def init_weights(self):
if self.use_cbn:
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.weight_embedding.weight)
nn.init.orthogonal_(self.bias_embedding.weight)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.weight_embedding.weight, gain=1)
xavier_uniform_(self.bias_embedding.weight, gain=1)
elif self.init_type.upper() in [
'SNGAN', 'SNGAN-PROJ', 'GAN-PROJ', 'SAGAN'
]:
constant_init(self.weight_embedding, 1)
constant_init(self.bias_embedding, 0)
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
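# A minimal usage sketch of ``SNConditionNorm`` (illustrative; the shapes
# and class count are assumptions):
#
#   >>> import torch
#   >>> cbn = SNConditionNorm(in_channels=64, num_classes=10)
#   >>> x = torch.randn(4, 64, 8, 8)
#   >>> y = torch.randint(0, 10, (4, ))
#   >>> out = cbn(x, y)  # per-class scale and shift after BN
#
# With the default ``init_cfg=dict(type='BigGAN')``, ``reweight_embedding``
# is True, so the learned scale is applied as ``(weight + 1)``.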
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator_v1 import (StyleGAN1Discriminator,
StyleGANv1Generator)
from .generator_discriminator_v2 import (StyleGAN2Discriminator,
StyleGANv2Generator)
from .generator_discriminator_v3 import StyleGANv3Generator
from .mspie import MSStyleGAN2Discriminator, MSStyleGANv2Generator
__all__ = [
'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator',
'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator',
'MSStyleGANv2Generator', 'StyleGANv3Generator'
]
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import scipy.signal
import torch
from mmgen.ops import conv2d_gradfix
from . import grid_sample_gradfix, misc, upfirdn2d
# ----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
wavelets = {
'haar': [0.7071067811865476, 0.7071067811865476],
'db1': [0.7071067811865476, 0.7071067811865476],
'db2': [
-0.12940952255092145, 0.22414386804185735, 0.836516303737469,
0.48296291314469025
],
'db3': [
0.035226291882100656, -0.08544127388224149, -0.13501102001039084,
0.4598775021193313, 0.8068915093133388, 0.3326705529509569
],
'db4': [
-0.010597401784997278, 0.032883011666982945, 0.030841381835986965,
-0.18703481171888114, -0.02798376941698385, 0.6308807679295904,
0.7148465705525415, 0.23037781330885523
],
'db5': [
0.003335725285001549, -0.012580751999015526, -0.006241490213011705,
0.07757149384006515, -0.03224486958502952, -0.24229488706619015,
0.13842814590110342, 0.7243085284385744, 0.6038292697974729,
0.160102397974125
],
'db6': [
-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016,
-0.031582039318031156, 0.02752286553001629, 0.09750160558707936,
-0.12976686756709563, -0.22626469396516913, 0.3152503517092432,
0.7511339080215775, 0.4946238903983854, 0.11154074335008017
],
'db7': [
0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274,
0.012550998556013784, -0.01657454163101562, -0.03802993693503463,
0.0806126091510659, 0.07130921926705004, -0.22403618499416572,
-0.14390600392910627, 0.4697822874053586, 0.7291320908465551,
0.39653931948230575, 0.07785205408506236
],
'db8': [
-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771,
-0.00487035299301066, 0.008746094047015655, 0.013981027917015516,
-0.04408825393106472, -0.01736930100202211, 0.128747426620186,
0.00047248457399797254, -0.2840155429624281, -0.015829105256023893,
0.5853546836548691, 0.6756307362980128, 0.3128715909144659,
0.05441584224308161
],
'sym2': [
-0.12940952255092145, 0.22414386804185735, 0.836516303737469,
0.48296291314469025
],
'sym3': [
0.035226291882100656, -0.08544127388224149, -0.13501102001039084,
0.4598775021193313, 0.8068915093133388, 0.3326705529509569
],
'sym4': [
-0.07576571478927333, -0.02963552764599851, 0.49761866763201545,
0.8037387518059161, 0.29785779560527736, -0.09921954357684722,
-0.012603967262037833, 0.0322231006040427
],
'sym5': [
0.027333068345077982, 0.029519490925774643, -0.039134249302383094,
0.1993975339773936, 0.7234076904024206, 0.6339789634582119,
0.01660210576452232, -0.17532808990845047, -0.021101834024758855,
0.019538882735286728
],
'sym6': [
0.015404109327027373, 0.0034907120842174702, -0.11799011114819057,
-0.048311742585633, 0.4910559419267466, 0.787641141030194,
0.3379294217276218, -0.07263752278646252, -0.021060292512300564,
0.04472490177066578, 0.0017677118642428036, -0.007800708325034148
],
'sym7': [
0.002681814568257878, -0.0010473848886829163, -0.01263630340325193,
0.03051551316596357, 0.0678926935013727, -0.049552834937127255,
0.017441255086855827, 0.5361019170917628, 0.767764317003164,
0.2886296317515146, -0.14004724044296152, -0.10780823770381774,
0.004010244871533663, 0.010268176708511255
],
'sym8': [
-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298,
0.007607487324917605, -0.1432942383508097, -0.061273359067658524,
0.4813596512583722, 0.7771857517005235, 0.3644418948353314,
-0.05194583810770904, -0.027219029917056003, 0.049137179673607506,
0.003808752013890615, -0.01495225833704823, -0.0003029205147213668,
0.0018899503327594609
],
}
# ----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.
def matrix(*rows, device=None):
"""Constructing transformation matrices.
Args:
device (str|torch.device, optional): Matrix device. Defaults to None.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
assert all(len(row) == len(rows[0]) for row in rows)
elems = [x for row in rows for x in row]
ref = [x for x in elems if isinstance(x, torch.Tensor)]
if len(ref) == 0:
return misc.constant(np.asarray(rows), device=device)
assert device is None or device == ref[0].device
# change `x.float()` to support pt1.5
elems = [
x.float() if isinstance(x, torch.Tensor) else misc.constant(
x, shape=ref[0].shape, device=ref[0].device) for x in elems
]
return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
def translate2d(tx, ty, **kwargs):
"""Construct 2d translation matrix.
Args:
tx (float): X-direction translation amount.
ty (float): Y-direction translation amount.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return matrix([1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs)
def translate3d(tx, ty, tz, **kwargs):
"""Construct 3d translation matrix.
Args:
tx (float): X-direction translation amount.
ty (float): Y-direction translation amount.
tz (float): Z-direction translation amount.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return matrix([1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1],
**kwargs)
def scale2d(sx, sy, **kwargs):
"""Construct 2d scaling matrix.
Args:
sx (float): X-direction scaling coefficient.
sy (float): Y-direction scaling coefficient.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return matrix([sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs)
def scale3d(sx, sy, sz, **kwargs):
"""Construct 3d scaling matrix.
Args:
sx (float): X-direction scaling coefficient.
sy (float): Y-direction scaling coefficient.
sz (float): Z-direction scaling coefficient.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return matrix([sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1],
**kwargs)
def rotate2d(theta, **kwargs):
"""Construct 2d rotating matrix.
Args:
theta (float): Counter-clock wise rotation angle.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return matrix([torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0], [0, 0, 1], **kwargs)
def rotate3d(v, theta, **kwargs):
"""Constructing 3d rotating matrix.
Args:
v (torch.Tensor): Luma axis.
theta (float): Rotate theta counter-clock wise with ``v`` as the axis.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
vx = v[..., 0]
vy = v[..., 1]
vz = v[..., 2]
s = torch.sin(theta)
c = torch.cos(theta)
cc = 1 - c
return matrix(
[vx * vx * cc + c, vx * vy * cc - vz * s, vx * vz * cc + vy * s, 0],
[vy * vx * cc + vz * s, vy * vy * cc + c, vy * vz * cc - vx * s, 0],
[vz * vx * cc - vy * s, vz * vy * cc + vx * s, vz * vz * cc + c, 0],
[0, 0, 0, 1], **kwargs)
def translate2d_inv(tx, ty, **kwargs):
"""Construct inverse matrix of 2d translation matrix.
Args:
tx (float): X-direction translation amount.
ty (float): Y-direction translation amount.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return translate2d(-tx, -ty, **kwargs)
def scale2d_inv(sx, sy, **kwargs):
"""Construct inverse matrix of 2d scaling matrix.
Args:
sx (float): X-direction scaling coefficient.
sy (float): Y-direction scaling coefficient.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return scale2d(1 / sx, 1 / sy, **kwargs)
def rotate2d_inv(theta, **kwargs):
"""Construct inverse matrix of 2d rotating matrix.
Args:
theta (float): Counter-clock wise rotation angle.
Returns:
ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
format.
"""
return rotate2d(-theta, **kwargs)
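# ----------------------------------------------------------------------------
# A minimal sketch of composing these helpers into a batched inverse
# transform, as done in ``AugmentPipe.forward`` below (the angles and batch
# size are illustrative):
#
#   >>> import torch
#   >>> theta = torch.rand(4) * 2 * np.pi      # one angle per sample
#   >>> G_inv = scale2d_inv(2., 2.) @ rotate2d_inv(theta)
#   >>> G_inv.shape                            # broadcast to a batch
#   torch.Size([4, 3, 3])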
# ----------------------------------------------------------------------------
# Versatile image augmentation pipeline from the paper
# "Training Generative Adversarial Networks with Limited Data".
#
# All augmentations are disabled by default; individual augmentations can
# be enabled by setting their probability multipliers to 1.
class AugmentPipe(torch.nn.Module):
"""Augmentation pipeline include multiple geometric and color
transformations.
Note: The meaning of arguments are written in the comments of
``__init__`` function.
"""
def __init__(
self,
xflip=0,
rotate90=0,
xint=0,
xint_max=0.125,
scale=0,
rotate=0,
aniso=0,
xfrac=0,
scale_std=0.2,
rotate_max=1,
aniso_std=0.2,
xfrac_std=0.125,
brightness=0,
contrast=0,
lumaflip=0,
hue=0,
saturation=0,
brightness_std=0.2,
contrast_std=0.5,
hue_max=1,
saturation_std=1,
imgfilter=0,
imgfilter_bands=[1, 1, 1, 1],
imgfilter_std=1,
noise=0,
cutout=0,
noise_std=0.1,
cutout_size=0.5,
):
super().__init__()
self.register_buffer('p', torch.ones(
[])) # Overall multiplier for augmentation probability.
# Pixel blitting.
self.xflip = float(xflip) # Probability multiplier for x-flip.
self.rotate90 = float(
rotate90) # Probability multiplier for 90 degree rotations.
self.xint = float(
xint) # Probability multiplier for integer translation.
self.xint_max = float(
xint_max
) # Range of integer translation, relative to image dimensions.
# General geometric transformations.
self.scale = float(
scale) # Probability multiplier for isotropic scaling.
self.rotate = float(
rotate) # Probability multiplier for arbitrary rotation.
self.aniso = float(
aniso) # Probability multiplier for anisotropic scaling.
self.xfrac = float(
xfrac) # Probability multiplier for fractional translation.
self.scale_std = float(
scale_std) # Log2 standard deviation of isotropic scaling.
self.rotate_max = float(
rotate_max) # Range of arbitrary rotation, 1 = full circle.
self.aniso_std = float(
aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(
            xfrac_std
        )  # Standard deviation of fractional translation, relative to img dims.
# Color transformations.
self.brightness = float(
brightness) # Probability multiplier for brightness.
self.contrast = float(contrast) # Probability multiplier for contrast.
self.lumaflip = float(
lumaflip) # Probability multiplier for luma flip.
self.hue = float(hue) # Probability multiplier for hue rotation.
self.saturation = float(
saturation) # Probability multiplier for saturation.
self.brightness_std = float(
brightness_std) # Standard deviation of brightness.
self.contrast_std = float(
contrast_std) # Log2 standard deviation of contrast.
self.hue_max = float(
hue_max) # Range of hue rotation, 1 = full circle.
self.saturation_std = float(
saturation_std) # Log2 standard deviation of saturation.
# Image-space filtering.
self.imgfilter = float(
imgfilter) # Probability multiplier for image-space filtering.
self.imgfilter_bands = list(
imgfilter_bands
) # Probability multipliers for individual frequency bands.
self.imgfilter_std = float(
imgfilter_std
) # Log2 standard deviation of image-space filter amplification.
# Image-space corruptions.
self.noise = float(
noise) # Probability multiplier for additive RGB noise.
self.cutout = float(cutout) # Probability multiplier for cutout.
self.noise_std = float(
noise_std) # Standard deviation of additive RGB noise.
self.cutout_size = float(
cutout_size
) # Size of the cutout rectangle, relative to image dimensions.
# Setup orthogonal lowpass filter for geometric augmentations.
self.register_buffer('Hz_geom',
upfirdn2d.setup_filter(wavelets['sym6']))
# Construct filter bank for image-space filtering.
Hz_lo = np.asarray(wavelets['sym2']) # H(z)
Hz_hi = Hz_lo * ((-1)**np.arange(Hz_lo.size)) # H(-z)
Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
for i in range(1, Hz_fbank.shape[0]):
Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)
]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) //
2:(Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
self.register_buffer('Hz_fbank',
torch.as_tensor(Hz_fbank, dtype=torch.float32))
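        # At this point `Hz_fbank` has shape [4, taps]: one FIR filter per
        # frequency band, with row 0 covering the lowest band (cf.
        # `expected_power` in `forward` below).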
def forward(self, images, debug_percentile=None):
assert isinstance(images, torch.Tensor) and images.ndim == 4
batch_size, num_channels, height, width = images.shape
device = images.device
if debug_percentile is not None:
debug_percentile = torch.as_tensor(
debug_percentile, dtype=torch.float32, device=device)
# -------------------------------------
# Select parameters for pixel blitting.
# -------------------------------------
# Initialize inverse homogeneous 2D transform:
# G_inv @ pixel_out ==> pixel_in
I_3 = torch.eye(3, device=device)
G_inv = I_3
# Apply x-flip with probability (xflip * strength).
if self.xflip > 0:
i = torch.floor(torch.rand([batch_size], device=device) * 2)
i = torch.where(
torch.rand([batch_size], device=device) < self.xflip * self.p,
i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 2))
G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
# Apply 90 degree rotations with probability (rotate90 * strength).
if self.rotate90 > 0:
i = torch.floor(torch.rand([batch_size], device=device) * 4)
i = torch.where(
torch.rand([batch_size], device=device) <
self.rotate90 * self.p, i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 4))
G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
# Apply integer translation with probability (xint * strength).
if self.xint > 0:
t = (torch.rand([batch_size, 2], device=device) * 2 -
1) * self.xint_max
t = torch.where(
torch.rand([batch_size, 1], device=device) <
self.xint * self.p, t, torch.zeros_like(t))
if debug_percentile is not None:
t = torch.full_like(t,
(debug_percentile * 2 - 1) * self.xint_max)
G_inv = G_inv @ translate2d_inv(
torch.round(t[:, 0] * width), torch.round(t[:, 1] * height))
# --------------------------------------------------------
# Select parameters for general geometric transformations.
# --------------------------------------------------------
# support for pt1.5 (pt1.5 does not contain exp2)
_scalor_log2 = torch.log(
torch.tensor(2., device=images.device, dtype=images.dtype))
# Apply isotropic scaling with probability (scale * strength).
if self.scale > 0:
s = torch.exp(
torch.randn([batch_size], device=device) * self.scale_std *
_scalor_log2)
s = torch.where(
torch.rand([batch_size], device=device) < self.scale * self.p,
s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(
s,
torch.exp2(
torch.erfinv(debug_percentile * 2 - 1) *
self.scale_std))
G_inv = G_inv @ scale2d_inv(s, s)
# Apply pre-rotation with probability p_rot.
p_rot = 1 - torch.sqrt(
(1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
if self.rotate > 0:
theta = (torch.rand([batch_size], device=device) * 2 -
1) * np.pi * self.rotate_max
theta = torch.where(
torch.rand([batch_size], device=device) < p_rot, theta,
torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
np.pi * self.rotate_max)
G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
# Apply anisotropic scaling with probability (aniso * strength).
if self.aniso > 0:
s = torch.exp(
torch.randn([batch_size], device=device) * self.aniso_std *
_scalor_log2)
s = torch.where(
torch.rand([batch_size], device=device) < self.aniso * self.p,
s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(
s,
torch.exp2(
torch.erfinv(debug_percentile * 2 - 1) *
self.aniso_std))
G_inv = G_inv @ scale2d_inv(s, 1 / s)
# Apply post-rotation with probability p_rot.
if self.rotate > 0:
theta = (torch.rand([batch_size], device=device) * 2 -
1) * np.pi * self.rotate_max
theta = torch.where(
torch.rand([batch_size], device=device) < p_rot, theta,
torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.zeros_like(theta)
G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
# Apply fractional translation with probability (xfrac * strength).
if self.xfrac > 0:
t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
t = torch.where(
torch.rand([batch_size, 1], device=device) <
self.xfrac * self.p, t, torch.zeros_like(t))
if debug_percentile is not None:
t = torch.full_like(
t,
torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)
# ----------------------------------
# Execute geometric transformations.
# ----------------------------------
# Execute if the transform is not identity.
if G_inv is not I_3:
# Calculate padding.
cx = (width - 1) / 2
cy = (height - 1) / 2
cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1],
device=device) # [idx, xyz]
cp = G_inv @ cp.t() # [batch, xyz, idx]
Hz_pad = self.Hz_geom.shape[0] // 4
margin = cp[:, :2, :].permute(1, 0,
2).flatten(1) # [xy, batch * idx]
margin = torch.cat([-margin,
margin]).max(dim=1).values # [x0, y0, x1, y1]
margin = margin + misc.constant(
[Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
margin = margin.max(misc.constant([0, 0] * 2, device=device))
margin = margin.min(
misc.constant([width - 1, height - 1] * 2, device=device))
mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
# Pad image and adjust origin.
images = torch.nn.functional.pad(
input=images, pad=[mx0, mx1, my0, my1], mode='reflect')
G_inv = translate2d(
torch.true_divide(mx0 - mx1, 2), torch.true_divide(
my0 - my1, 2)) @ G_inv
# Upsample.
images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
G_inv = scale2d(
2, 2, device=device) @ G_inv @ scale2d_inv(
2, 2, device=device)
G_inv = translate2d(
-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(
-0.5, -0.5, device=device)
# Execute transformation.
shape = [
batch_size, num_channels, (height + Hz_pad * 2) * 2,
(width + Hz_pad * 2) * 2
]
G_inv = scale2d(
2 / images.shape[3], 2 / images.shape[2],
device=device) @ G_inv @ scale2d_inv(
2 / shape[3], 2 / shape[2], device=device)
grid = torch.nn.functional.affine_grid(
theta=G_inv[:, :2, :], size=shape, align_corners=False)
images = grid_sample_gradfix.grid_sample(images, grid)
# Downsample and crop.
images = upfirdn2d.downsample2d(
x=images,
f=self.Hz_geom,
down=2,
padding=-Hz_pad * 2,
flip_filter=True)
# --------------------------------------------
# Select parameters for color transformations.
# --------------------------------------------
# Initialize homogeneous 3D transformation matrix:
# C @ color_in ==> color_out
I_4 = torch.eye(4, device=device)
C = I_4
# Apply brightness with probability (brightness * strength).
if self.brightness > 0:
b = torch.randn([batch_size], device=device) * self.brightness_std
b = torch.where(
torch.rand([batch_size], device=device) <
self.brightness * self.p, b, torch.zeros_like(b))
if debug_percentile is not None:
b = torch.full_like(
b,
torch.erfinv(debug_percentile * 2 - 1) *
self.brightness_std)
C = translate3d(b, b, b) @ C
# Apply contrast with probability (contrast * strength).
if self.contrast > 0:
c = torch.exp(
torch.randn([batch_size], device=device) * self.contrast_std *
_scalor_log2)
c = torch.where(
torch.rand([batch_size], device=device) <
self.contrast * self.p, c, torch.ones_like(c))
if debug_percentile is not None:
c = torch.full_like(
c,
torch.exp2(
torch.erfinv(debug_percentile * 2 - 1) *
self.contrast_std))
C = scale3d(c, c, c) @ C
# Apply luma flip with probability (lumaflip * strength).
v = misc.constant(
np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
if self.lumaflip > 0:
i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
i = torch.where(
torch.rand([batch_size, 1, 1], device=device) <
self.lumaflip * self.p, i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 2))
C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
# Apply hue rotation with probability (hue * strength).
if self.hue > 0 and num_channels > 1:
theta = (torch.rand([batch_size], device=device) * 2 -
1) * np.pi * self.hue_max
theta = torch.where(
torch.rand([batch_size], device=device) < self.hue * self.p,
theta, torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.full_like(theta, (debug_percentile * 2 - 1) *
np.pi * self.hue_max)
C = rotate3d(v, theta) @ C # Rotate around v.
# Apply saturation with probability (saturation * strength).
if self.saturation > 0 and num_channels > 1:
s = torch.exp(
torch.randn([batch_size, 1, 1], device=device) *
self.saturation_std * _scalor_log2)
s = torch.where(
torch.rand([batch_size, 1, 1], device=device) <
self.saturation * self.p, s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(
s,
torch.exp2(
torch.erfinv(debug_percentile * 2 - 1) *
self.saturation_std))
C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
# ------------------------------
# Execute color transformations.
# ------------------------------
# Execute if the transform is not identity.
if C is not I_4:
images = images.reshape([batch_size, num_channels, height * width])
if num_channels == 3:
images = C[:, :3, :3] @ images + C[:, :3, 3:]
elif num_channels == 1:
C = C[:, :3, :].mean(dim=1, keepdims=True)
images = images * C[:, :, :3].sum(
dim=2, keepdims=True) + C[:, :, 3:]
else:
raise ValueError(
'Image must be RGB (3 channels) or L (1 channel)')
images = images.reshape([batch_size, num_channels, height, width])
# ----------------------
# Image-space filtering.
# ----------------------
if self.imgfilter > 0:
num_bands = self.Hz_fbank.shape[0]
assert len(self.imgfilter_bands) == num_bands
expected_power = misc.constant(
np.array([10, 1, 1, 1]) / 13,
device=device) # Expected power spectrum (1/f).
# Apply amplification for each band with probability
# (imgfilter * strength * band_strength).
g = torch.ones([batch_size, num_bands],
device=device) # Global gain vector (identity).
for i, band_strength in enumerate(self.imgfilter_bands):
t_i = torch.exp(
torch.randn([batch_size], device=device) *
self.imgfilter_std * _scalor_log2)
t_i = torch.where(
torch.rand([batch_size], device=device) <
self.imgfilter * self.p * band_strength, t_i,
torch.ones_like(t_i))
if debug_percentile is not None:
t_i = torch.full_like(
t_i,
torch.exp2(
torch.erfinv(debug_percentile * 2 - 1) *
self.imgfilter_std)
) if band_strength > 0 else torch.ones_like(t_i)
t = torch.ones([batch_size, num_bands],
device=device) # Temporary gain vector.
t[:, i] = t_i # Replace i'th element.
t = t / (expected_power * t.square()).sum(
dim=-1, keepdims=True).sqrt() # Normalize power.
g = g * t # Accumulate into global gain.
# Construct combined amplification filter.
Hz_prime = g @ self.Hz_fbank # [batch, tap]
Hz_prime = Hz_prime.unsqueeze(1).repeat(
[1, num_channels, 1]) # [batch, channels, tap]
Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1,
-1]) # [batch * channels, 1, tap]
# Apply filter.
p = self.Hz_fbank.shape[1] // 2
images = images.reshape(
[1, batch_size * num_channels, height, width])
images = torch.nn.functional.pad(
input=images, pad=[p, p, p, p], mode='reflect')
images = conv2d_gradfix.conv2d(
input=images,
weight=Hz_prime.unsqueeze(2),
groups=batch_size * num_channels)
images = conv2d_gradfix.conv2d(
input=images,
weight=Hz_prime.unsqueeze(3),
groups=batch_size * num_channels)
images = images.reshape([batch_size, num_channels, height, width])
# ------------------------
# Image-space corruptions.
# ------------------------
# Apply additive RGB noise with probability (noise * strength).
if self.noise > 0:
sigma = torch.randn([batch_size, 1, 1, 1],
device=device).abs() * self.noise_std
sigma = torch.where(
torch.rand([batch_size, 1, 1, 1], device=device) <
self.noise * self.p, sigma, torch.zeros_like(sigma))
if debug_percentile is not None:
sigma = torch.full_like(
sigma,
torch.erfinv(debug_percentile) * self.noise_std)
images = images + torch.randn(
[batch_size, num_channels, height, width],
device=device) * sigma
# Apply cutout with probability (cutout * strength).
if self.cutout > 0:
size = torch.full([batch_size, 2, 1, 1, 1],
self.cutout_size,
device=device)
size = torch.where(
torch.rand([batch_size, 1, 1, 1, 1], device=device) <
self.cutout * self.p, size, torch.zeros_like(size))
center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
if debug_percentile is not None:
size = torch.full_like(size, self.cutout_size)
center = torch.full_like(center, debug_percentile)
coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
coord_y = torch.arange(
height, device=device).reshape([1, 1, -1, 1])
mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >=
size[:, 0] / 2)
mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >=
size[:, 1] / 2)
mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
images = images * mask
return images
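# ----------------------------------------------------------------------------
# A minimal usage sketch of ``AugmentPipe`` (illustrative; the probabilities
# and shapes are assumptions, and the geometric branch relies on the
# ``upfirdn2d`` op from mmcv, which may require a CUDA build):
#
#   >>> import torch
#   >>> pipe = AugmentPipe(xflip=1, rotate90=1, xint=1, brightness=1)
#   >>> pipe.p.copy_(torch.as_tensor(0.6))  # overall augment probability
#   >>> imgs = torch.randn(8, 3, 64, 64)    # NCHW, roughly in [-1, 1]
#   >>> out = pipe(imgs)                    # same shape as the input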
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.grid_sample` that supports
arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`,
`align_corners=False`.
"""
import warnings
import torch
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
# ----------------------------------------------------------------------------
enabled = True # Enable the custom op by setting this to true.
# ----------------------------------------------------------------------------
def grid_sample(input, grid):
if _should_use_custom_op():
return _GridSample2dForward.apply(input, grid)
return torch.nn.functional.grid_sample(
input=input,
grid=grid,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
# ----------------------------------------------------------------------------
def _should_use_custom_op():
if not enabled:
return False
if any(
torch.__version__.startswith(x)
for x in ['1.5.', '1.6.', '1.7.', '1.8.', '1.9.', '1.10.']):
return True
warnings.warn(
f'grid_sample_gradfix not supported on PyTorch {torch.__version__}.'
' Falling back to torch.nn.functional.grid_sample().')
return False
# ----------------------------------------------------------------------------
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid):
assert input.ndim == 4
assert grid.ndim == 4
output = torch.nn.functional.grid_sample(
input=input,
grid=grid,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
ctx.save_for_backward(input, grid)
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample2dBackward.apply(
grad_output, input, grid)
return grad_input, grad_grid
# ----------------------------------------------------------------------------
class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid
@staticmethod
def backward(ctx, grad2_grad_input, grad2_grad_grid):
_ = grad2_grad_grid # unused
grid, = ctx.saved_tensors
grad2_grad_output = None
grad2_input = None
grad2_grid = None
if ctx.needs_input_grad[0]:
grad2_grad_output = _GridSample2dForward.apply(
grad2_grad_input, grid)
assert not ctx.needs_input_grad[2]
return grad2_grad_output, grad2_input, grad2_grid
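# ----------------------------------------------------------------------------
# A minimal sketch of the double-backward path this module exists for
# (illustrative shapes; note the grid itself must not require grad, see the
# assert in ``_GridSample2dBackward.backward``):
#
#   >>> import torch
#   >>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
#   >>> g = torch.rand(1, 8, 8, 2) * 2 - 1
#   >>> y = grid_sample(x, g)
#   >>> (dx, ) = torch.autograd.grad(y.pow(2).sum(), x, create_graph=True)
#   >>> dx.pow(2).sum().backward()  # second-order gradients reach x
#
# On PyTorch versions outside 1.5-1.10 the call transparently falls back to
# ``torch.nn.functional.grid_sample`` (with a warning).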
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.
_constant_cache = dict()
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
value = np.asarray(value)
if shape is not None:
shape = tuple(shape)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device('cpu')
if memory_format is None:
memory_format = torch.contiguous_format
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device,
memory_format)
tensor = _constant_cache.get(key, None)
if tensor is None:
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
if shape is not None:
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
tensor = tensor.contiguous(memory_format=memory_format)
_constant_cache[key] = tensor
return tensor
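# A quick illustration of the caching behavior (the values are arbitrary):
#
#   >>> a = constant([1, 2, 3])
#   >>> b = constant([1, 2, 3])
#   >>> a is b   # True: identical requests share one cached tensor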
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops.upfirdn2d import upfirdn2d
def _parse_scaling(scaling):
if isinstance(scaling, int):
scaling = [scaling, scaling]
assert isinstance(scaling, (list, tuple))
assert all(isinstance(x, int) for x in scaling)
sx, sy = scaling
assert sx >= 1 and sy >= 1
return sx, sy
def _parse_padding(padding):
if isinstance(padding, int):
padding = [padding, padding]
assert isinstance(padding, (list, tuple))
assert all(isinstance(x, int) for x in padding)
if len(padding) == 2:
padx, pady = padding
padding = [padx, padx, pady, pady]
padx0, padx1, pady0, pady1 = padding
return padx0, padx1, pady0, pady1
def _get_filter_size(f):
if f is None:
return 1, 1
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
fw = f.shape[-1]
fh = f.shape[0]
fw = int(fw)
fh = int(fh)
assert fw >= 1 and fh >= 1
return fw, fh
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Upsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a multiple of the
input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
        up: Integer upsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 2).
        padding: Padding with respect to the output. Can be a single number
            or a list/tuple `[x, y]` or
            `[x_before, x_after, y_before, y_after]` (default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'`
(default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`
"""
upx, upy = _parse_scaling(up)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(f)
p = [
padx0 + (fw + upx - 1) // 2,
padx1 + (fw - upx) // 2,
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
gain = gain * upx * upy
f = f * (gain**(f.ndim / 2))
if flip_filter:
f = f.flip(list(range(f.ndim)))
    if f.ndim == 1:
        # separable filter: filter along x, then along y
        x = upfirdn2d(x, f.unsqueeze(0), up=(upx, 1), pad=(p[0], p[1], 0, 0))
        x = upfirdn2d(x, f.unsqueeze(1), up=(1, upy), pad=(0, 0, p[2], p[3]))
    else:
        # non-separable 2D filter: a single pass; without this branch a 2D
        # filter would fall through and `x` would be returned unchanged
        x = upfirdn2d(x, f, up=(upx, upy), pad=(p[0], p[1], p[2], p[3]))
    return x
def setup_filter(f,
device=torch.device('cpu'),
normalize=True,
flip_filter=False,
gain=1,
separable=None):
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
Args:
f: Torch tensor, numpy array, or python list of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable),
`[]` (impulse), or
`None` (identity).
device: Result device (default: cpu).
normalize: Normalize the filter so that it retains the magnitude
for constant input signal (DC)? (default: True).
flip_filter: Flip the filter? (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
separable: Return a separable filter? (default: select automatically)
Returns:
Float32 tensor of the shape
`[filter_height, filter_width]` (non-separable) or
`[filter_taps]` (separable).
"""
# Validate.
if f is None:
f = 1
f = torch.as_tensor(f, dtype=torch.float32)
assert f.ndim in [0, 1, 2]
assert f.numel() > 0
if f.ndim == 0:
f = f[np.newaxis]
# Separable?
if separable is None:
separable = (f.ndim == 1 and f.numel() >= 8)
if f.ndim == 1 and not separable:
f = f.ger(f)
assert f.ndim == (1 if separable else 2)
# Apply normalize, flip, gain, and device.
if normalize:
f /= f.sum()
if flip_filter:
f = f.flip(list(range(f.ndim)))
f = f * (gain**(f.ndim / 2))
f = f.to(device=device)
return f
def downsample2d(x,
f,
down=2,
padding=0,
flip_filter=False,
gain=1,
impl='cuda'):
r"""Downsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a fraction of the
input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
        down: Integer downsampling factor. Can be a single int or a
            list/tuple `[x, y]` (default: 2).
        padding: Padding with respect to the input. Can be a single
            number or a list/tuple `[x, y]` or
            `[x_before, x_after, y_before, y_after]` (default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'`
(default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`
"""
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(f)
p = [
padx0 + (fw - downx + 1) // 2,
padx1 + (fw - downx) // 2,
pady0 + (fh - downy + 1) // 2,
pady1 + (fh - downy) // 2,
]
    # apply the gain to the filter, mirroring `upsample2d`
    f = f * (gain**(f.ndim / 2))
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    if f.ndim == 1:
        x = upfirdn2d(
            x, f.unsqueeze(0), down=(downx, 1), pad=(p[0], p[1], 0, 0))
        x = upfirdn2d(
            x, f.unsqueeze(1), down=(1, downy), pad=(0, 0, p[2], p[3]))
    else:
        # non-separable (2D) filters are applied in a single pass
        x = upfirdn2d(x, f, down=(downx, downy), pad=(p[0], p[1], p[2], p[3]))
    return x
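# A minimal usage sketch mirroring `upsample2d` (illustrative values):
#     f = torch.tensor([1., 3., 3., 1.]) / 8.
#     x = torch.randn(1, 3, 64, 64)
#     y = downsample2d(x, f)  # -> (1, 3, 32, 32)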
# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.pggan import (EqualizedLRConvDownModule,
EqualizedLRConvModule)
from mmgen.models.architectures.stylegan.modules import Blur
from mmgen.models.builder import MODULES
from .. import MiniBatchStddevLayer
from .modules.styleganv1_modules import StyleConv
from .modules.styleganv2_modules import EqualLinearActModule
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class StyleGANv1Generator(nn.Module):
"""StyleGAN1 Generator.
    In StyleGAN1, we use a progressive growing architecture composed of a
    style mapping module and a number of convolutional style blocks. More
    details can be found in: A Style-Based Generator Architecture for
    Generative Adversarial Networks CVPR2019.
Args:
out_size (int): The output size of the StyleGAN1 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        blur_kernel (list, optional): The blur kernel. Defaults
            to [1, 2, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we adopt the mixing style mode by default. In
            evaluation, we use the 'single' style mode. `['mix', 'single']`
            are currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The style mixing mode used in
            evaluation. Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in the range of [0, 1]. Defaults to 0.9.
"""
def __init__(self,
out_size,
style_channels,
num_mlps=8,
blur_kernel=[1, 2, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9):
super().__init__()
self.out_size = out_size
self.style_channels = style_channels
self.num_mlps = num_mlps
self.lr_mlp = lr_mlp
self._default_style_mode = default_style_mode
self.default_style_mode = default_style_mode
self.eval_style_mode = eval_style_mode
self.mix_prob = mix_prob
# define style mapping layers
mapping_layers = [PixelNorm()]
for _ in range(num_mlps):
mapping_layers.append(
EqualLinearActModule(
style_channels,
style_channels,
equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.style_mapping = nn.Sequential(*mapping_layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16,
}
        # generator backbone (4x4 --> higher resolutions)
self.log_size = int(np.log2(self.out_size))
self.convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
in_channels_ = self.channels[4]
for i in range(2, self.log_size + 1):
out_channels_ = self.channels[2**i]
self.convs.append(
StyleConv(
in_channels_,
out_channels_,
3,
style_channels,
initial=(i == 2),
upsample=True,
fused=True))
self.to_rgbs.append(
EqualizedLRConvModule(out_channels_, 3, 1, act_cfg=None))
in_channels_ = out_channels_
self.num_latents = self.log_size * 2 - 2
self.num_injected_noises = self.num_latents
# register buffer for injected noises
for layer_idx in range(self.num_injected_noises):
res = (layer_idx + 4) // 2
shape = [1, 1, 2**res, 2**res]
self.register_buffer(f'injected_noise_{layer_idx}',
torch.randn(*shape))
def train(self, mode=True):
if mode:
if self.default_style_mode != self._default_style_mode:
mmcv.print_log(
f'Switch to train style mode: {self._default_style_mode}',
'mmgen')
self.default_style_mode = self._default_style_mode
else:
if self.default_style_mode != self.eval_style_mode:
mmcv.print_log(
f'Switch to evaluation style mode: {self.eval_style_mode}',
'mmgen')
self.default_style_mode = self.eval_style_mode
return super(StyleGANv1Generator, self).train(mode)
def make_injected_noise(self):
"""make noises that will be injected into feature maps.
Returns:
list[Tensor]: List of layer-wise noise tensor.
"""
device = get_module_device(self)
        noises = []
for i in range(2, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
return get_mean_latent(self, num_samples, **kwargs)
def style_mixing(self,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7,
curr_scale=-1,
transition_weight=1):
return style_mixing(
self,
n_source=n_source,
n_target=n_target,
inject_index=inject_index,
truncation=truncation,
truncation_latent=truncation_latent,
style_channels=self.style_channels,
curr_scale=curr_scale,
transition_weight=transition_weight)
def forward(self,
styles,
num_batches=-1,
return_noise=False,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
injected_noise=None,
randomize_noise=True,
transition_weight=1.,
curr_scale=-1):
"""Forward function.
This function has been integrated with the truncation trick. Please
refer to the usage of `truncation` and `truncation_latent`.
Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN1, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of samples in a batch.
                Only used when noise is sampled internally. Defaults to -1.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
inject_index (int | None, optional): The index number for mixing
style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If the value is
                less than 1., the truncation trick will be adopted.
                Defaults to 1.
truncation_latent (torch.Tensor, optional): Mean truncation latent.
Defaults to None.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
injected_noise (torch.Tensor | None, optional): Given a tensor, the
random noise will be fixed as this input injected noise.
Defaults to None.
randomize_noise (bool, optional): If `False`, images are sampled
with the buffered noise tensor injected to the style conv
block. Defaults to True.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1..
curr_scale (int, optional): The resolution scale of generated image
tensor. -1 means the max resolution scale of the StyleGAN1.
Defaults to -1.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# receive noise and conduct sanity check.
if isinstance(styles, torch.Tensor):
assert styles.shape[1] == self.style_channels
styles = [styles]
elif mmcv.is_seq_of(styles, torch.Tensor):
for t in styles:
assert t.shape[-1] == self.style_channels
# receive a noise generator and sample noise.
elif callable(styles):
device = get_module_device(self)
noise_generator = styles
assert num_batches > 0
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
noise_generator((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [noise_generator((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
# otherwise, we will adopt default noise sampler.
else:
device = get_module_device(self)
assert num_batches > 0 and not input_is_latent
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
torch.randn((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [torch.randn((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
if not input_is_latent:
noise_batch = styles
styles = [self.style_mapping(s) for s in styles]
else:
noise_batch = None
if injected_noise is None:
if randomize_noise:
injected_noise = [None] * self.num_injected_noises
else:
injected_noise = [
getattr(self, f'injected_noise_{i}')
for i in range(self.num_injected_noises)
]
# use truncation trick
if truncation < 1:
style_t = []
# calculate truncation latent on the fly
if truncation_latent is None and not hasattr(
self, 'truncation_latent'):
self.truncation_latent = self.get_mean_latent()
truncation_latent = self.truncation_latent
elif truncation_latent is None and hasattr(self,
'truncation_latent'):
truncation_latent = self.truncation_latent
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
# no style mixing
if len(styles) < 2:
inject_index = self.num_latents
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
# style mixing
else:
if inject_index is None:
inject_index = random.randint(1, self.num_latents - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.num_latents - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
curr_log_size = self.log_size if curr_scale < 0 else int(
np.log2(curr_scale))
step = curr_log_size - 2
_index = 0
out = latent
# 4x4 ---> higher resolutions
for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
if i > 0 and step > 0:
out_prev = out
out = conv(
out,
latent[:, _index],
latent[:, _index + 1],
noise1=injected_noise[2 * i],
noise2=injected_noise[2 * i + 1])
if i == step:
out = to_rgb(out)
if i > 0 and 0 <= transition_weight < 1:
skip_rgb = self.to_rgbs[i - 1](out_prev)
skip_rgb = F.interpolate(
skip_rgb, scale_factor=2, mode='nearest')
out = (1 - transition_weight
) * skip_rgb + transition_weight * out
break
_index += 2
img = out
if return_latents or return_noise:
output_dict = dict(
fake_img=img,
latent=latent,
inject_index=inject_index,
noise_batch=noise_batch)
return output_dict
return img
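# Usage sketch (illustrative arguments; any out_size in `self.channels` works):
#     g = StyleGANv1Generator(out_size=256, style_channels=512)
#     img = g(None, num_batches=2)  # -> (2, 3, 256, 256)
#     # during progressive training, intermediate scales can be sampled:
#     img_64 = g(None, num_batches=2, curr_scale=64, transition_weight=0.5)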
@MODULES.register_module()
class StyleGAN1Discriminator(nn.Module):
"""StyleGAN1 Discriminator.
The architecture of this discriminator is proposed in StyleGAN1. More
details can be found in: A Style-Based Generator Architecture for
Generative Adversarial Networks CVPR2019.
Args:
in_size (int): The input size of images.
        blur_kernel (list, optional): The blur kernel. Defaults
            to [1, 2, 1].
mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
Defaults to dict(group_size=4).
"""
def __init__(self,
in_size,
blur_kernel=[1, 2, 1],
mbstd_cfg=dict(group_size=4)):
super().__init__()
self.with_mbstd = mbstd_cfg is not None
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16,
}
log_size = int(np.log2(in_size))
self.log_size = log_size
in_channels = channels[in_size]
self.convs = nn.ModuleList()
self.from_rgb = nn.ModuleList()
for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]
self.from_rgb.append(
EqualizedLRConvModule(
3,
in_channels,
kernel_size=3,
padding=1,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.convs.append(
nn.Sequential(
EqualizedLRConvModule(
in_channels,
out_channel,
kernel_size=3,
padding=1,
bias=True,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
Blur(blur_kernel, pad=(1, 1)),
EqualizedLRConvDownModule(
out_channel,
out_channel,
kernel_size=3,
stride=2,
padding=1,
act_cfg=None),
nn.LeakyReLU(negative_slope=0.2, inplace=True)))
in_channels = out_channel
        self.from_rgb.append(
            EqualizedLRConvModule(
                3,
                in_channels,
                kernel_size=3,
                padding=1,  # keep the 4x4 resolution for the final block
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.convs.append(
nn.Sequential(
EqualizedLRConvModule(
in_channels + 1,
512,
kernel_size=3,
padding=1,
bias=True,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
EqualizedLRConvModule(
512,
512,
kernel_size=4,
padding=0,
bias=True,
norm_cfg=None,
act_cfg=None),
))
if self.with_mbstd:
self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)
self.final_linear = nn.Sequential(EqualLinearActModule(channels[4], 1))
self.n_layer = len(self.convs)
def forward(self, input, transition_weight=1., curr_scale=-1):
"""Forward function.
Args:
input (torch.Tensor): Input image tensor.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1..
curr_scale (int, optional): The resolution scale of image tensor.
-1 means the max resolution scale of the StyleGAN1.
Defaults to -1.
Returns:
            torch.Tensor: Predicted score for the input image.
"""
curr_log_size = self.log_size if curr_scale < 0 else int(
np.log2(curr_scale))
step = curr_log_size - 2
for i in range(step, -1, -1):
index = self.n_layer - i - 1
if i == step:
out = self.from_rgb[index](input)
# minibatch standard deviation
if i == 0:
out = self.mbstd_layer(out)
out = self.convs[index](out)
if i > 0:
if i == step and 0 <= transition_weight < 1:
skip_rgb = F.avg_pool2d(input, 2)
skip_rgb = self.from_rgb[index + 1](skip_rgb)
out = (1 - transition_weight
) * skip_rgb + transition_weight * out
out = out.view(out.shape[0], -1)
out = self.final_linear(out)
return out
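# Usage sketch (illustrative; `in_size` should match the generator out_size):
#     d = StyleGAN1Discriminator(in_size=256)
#     score = d(torch.randn(2, 3, 256, 256))  # -> (2, 1)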
# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmgen.core.runners.fp16_utils import auto_fp16
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .ada.augment import AugmentPipe
from .ada.misc import constant
from .modules.styleganv2_modules import (ConstantInput, ConvDownLayer,
EqualLinearActModule,
ModMBStddevLayer, ModulatedStyleConv,
ModulatedToRGB, ResBlock)
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class StyleGANv2Generator(nn.Module):
r"""StyleGAN2 Generator.
    In StyleGAN2, we use a static architecture composed of a style mapping
    module and a number of convolutional style blocks. More details can be
    found in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.
You can load pretrained model through passing information into
``pretrained`` argument. We have already offered official weights as
follows:
- stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
- stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
- stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
- stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
- stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa
    If you want to load the EMA model, you can use the following code:
.. code-block:: python
        # ckpt_http is a valid http path to the checkpoint
generator = StyleGANv2Generator(1024, 512,
pretrained=dict(
ckpt_path=ckpt_http,
prefix='generator_ema'))
    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` to a local path. If you just want to load the original
    generator (not the EMA model), please set the prefix to 'generator'.
    Note that our implementation allows generating BGR images, while the
    original StyleGAN2 outputs RGB images by default.
Args:
out_size (int): The output size of the StyleGAN2 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
        blur_kernel (list, optional): The blur kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we adopt the mixing style mode by default. In
            evaluation, we use the 'single' style mode. `['mix', 'single']`
            are currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The style mixing mode used in
            evaluation. Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in the range of [0, 1]. Defaults to ``0.9``.
num_fp16_scales (int, optional): The number of resolutions to use auto
fp16 training. Different from ``fp16_enabled``, this argument
allows users to adopt FP16 training only in several blocks.
This behaviour is much more similar to the official implementation
by Tero. Defaults to 0.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. If this flag is `True`, the whole module will be wrapped
with ``auto_fp16``. Defaults to False.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
"""
def __init__(self,
out_size,
style_channels,
num_mlps=8,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9,
num_fp16_scales=0,
fp16_enabled=False,
pretrained=None):
super().__init__()
self.out_size = out_size
self.style_channels = style_channels
self.num_mlps = num_mlps
self.channel_multiplier = channel_multiplier
self.lr_mlp = lr_mlp
self._default_style_mode = default_style_mode
self.default_style_mode = default_style_mode
self.eval_style_mode = eval_style_mode
self.mix_prob = mix_prob
self.num_fp16_scales = num_fp16_scales
self.fp16_enabled = fp16_enabled
# define style mapping layers
mapping_layers = [PixelNorm()]
for _ in range(num_mlps):
mapping_layers.append(
EqualLinearActModule(
style_channels,
style_channels,
equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
act_cfg=dict(type='fused_bias')))
self.style_mapping = nn.Sequential(*mapping_layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
# constant input layer
self.constant_input = ConstantInput(self.channels[4])
# 4x4 stage
self.conv1 = ModulatedStyleConv(
self.channels[4],
self.channels[4],
kernel_size=3,
style_channels=style_channels,
blur_kernel=blur_kernel)
self.to_rgb1 = ModulatedToRGB(
self.channels[4],
style_channels,
upsample=False,
fp16_enabled=fp16_enabled)
# generator backbone (8x8 --> higher resolutions)
self.log_size = int(np.log2(self.out_size))
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
in_channels_ = self.channels[4]
for i in range(3, self.log_size + 1):
out_channels_ = self.channels[2**i]
            # If `fp16_enabled` is True, all layers will be run in auto
            # FP16. In the case of `num_fp16_scales` > 0, only partial
            # layers will be run in fp16.
_use_fp16 = (self.log_size - i) < num_fp16_scales or fp16_enabled
self.convs.append(
ModulatedStyleConv(
in_channels_,
out_channels_,
3,
style_channels,
upsample=True,
blur_kernel=blur_kernel,
fp16_enabled=_use_fp16))
self.convs.append(
ModulatedStyleConv(
out_channels_,
out_channels_,
3,
style_channels,
upsample=False,
blur_kernel=blur_kernel,
fp16_enabled=_use_fp16))
self.to_rgbs.append(
ModulatedToRGB(
out_channels_,
style_channels,
upsample=True,
fp16_enabled=_use_fp16)) # set to global fp16
in_channels_ = out_channels_
self.num_latents = self.log_size * 2 - 2
self.num_injected_noises = self.num_latents - 1
# register buffer for injected noises
for layer_idx in range(self.num_injected_noises):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
self.register_buffer(f'injected_noise_{layer_idx}',
torch.randn(*shape))
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def train(self, mode=True):
if mode:
if self.default_style_mode != self._default_style_mode:
mmcv.print_log(
f'Switch to train style mode: {self._default_style_mode}',
'mmgen')
self.default_style_mode = self._default_style_mode
else:
if self.default_style_mode != self.eval_style_mode:
mmcv.print_log(
f'Switch to evaluation style mode: {self.eval_style_mode}',
'mmgen')
self.default_style_mode = self.eval_style_mode
return super(StyleGANv2Generator, self).train(mode)
def make_injected_noise(self):
"""make noises that will be injected into feature maps.
Returns:
list[Tensor]: List of layer-wise noise tensor.
"""
device = get_module_device(self)
noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
return get_mean_latent(self, num_samples, **kwargs)
def style_mixing(self,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7):
return style_mixing(
self,
n_source=n_source,
n_target=n_target,
inject_index=inject_index,
truncation=truncation,
truncation_latent=truncation_latent,
style_channels=self.style_channels)
@auto_fp16()
def forward(self,
styles,
num_batches=-1,
return_noise=False,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
injected_noise=None,
randomize_noise=True):
"""Forward function.
This function has been integrated with the truncation trick. Please
refer to the usage of `truncation` and `truncation_latent`.
Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of samples in a batch.
                Only used when noise is sampled internally. Defaults to -1.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
inject_index (int | None, optional): The index number for mixing
style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If the value is
                less than 1., the truncation trick will be adopted.
                Defaults to 1.
truncation_latent (torch.Tensor, optional): Mean truncation latent.
Defaults to None.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
injected_noise (torch.Tensor | None, optional): Given a tensor, the
random noise will be fixed as this input injected noise.
Defaults to None.
randomize_noise (bool, optional): If `False`, images are sampled
with the buffered noise tensor injected to the style conv
block. Defaults to True.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# receive noise and conduct sanity check.
if isinstance(styles, torch.Tensor):
assert styles.shape[1] == self.style_channels
styles = [styles]
elif mmcv.is_seq_of(styles, torch.Tensor):
for t in styles:
assert t.shape[-1] == self.style_channels
# receive a noise generator and sample noise.
elif callable(styles):
device = get_module_device(self)
noise_generator = styles
assert num_batches > 0
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
noise_generator((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [noise_generator((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
# otherwise, we will adopt default noise sampler.
else:
device = get_module_device(self)
assert num_batches > 0 and not input_is_latent
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
torch.randn((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [torch.randn((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
if not input_is_latent:
noise_batch = styles
styles = [self.style_mapping(s) for s in styles]
else:
noise_batch = None
if injected_noise is None:
if randomize_noise:
injected_noise = [None] * self.num_injected_noises
else:
injected_noise = [
getattr(self, f'injected_noise_{i}')
for i in range(self.num_injected_noises)
]
# use truncation trick
if truncation < 1:
style_t = []
# calculate truncation latent on the fly
if truncation_latent is None and not hasattr(
self, 'truncation_latent'):
self.truncation_latent = self.get_mean_latent()
truncation_latent = self.truncation_latent
elif truncation_latent is None and hasattr(self,
'truncation_latent'):
truncation_latent = self.truncation_latent
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
# no style mixing
if len(styles) < 2:
inject_index = self.num_latents
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
# style mixing
else:
if inject_index is None:
inject_index = random.randint(1, self.num_latents - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.num_latents - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
# 4x4 stage
out = self.constant_input(latent)
out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
skip = self.to_rgb1(out, latent[:, 1])
_index = 1
# 8x8 ---> higher resolutions
for up_conv, conv, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], injected_noise[1::2],
injected_noise[2::2], self.to_rgbs):
out = up_conv(out, latent[:, _index], noise=noise1)
out = conv(out, latent[:, _index + 1], noise=noise2)
skip = to_rgb(out, latent[:, _index + 2], skip)
_index += 2
        # make sure the output image is torch.float32 to avoid RuntimeError
        # in other modules
img = skip.to(torch.float32)
if return_latents or return_noise:
output_dict = dict(
fake_img=img,
latent=latent,
inject_index=inject_index,
noise_batch=noise_batch)
return output_dict
return img
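# Usage sketch (illustrative; num_latents is 2 * log2(out_size) - 2):
#     g = StyleGANv2Generator(out_size=1024, style_channels=512)
#     res = g(None, num_batches=1, return_latents=True)
#     res['fake_img'].shape  # -> (1, 3, 1024, 1024)
#     res['latent'].shape    # -> (1, 18, 512)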
@MODULES.register_module()
class StyleGAN2Discriminator(nn.Module):
"""StyleGAN2 Discriminator.
The architecture of this discriminator is proposed in StyleGAN2. More
details can be found in: Analyzing and Improving the Image Quality of
StyleGAN CVPR2020.
You can load pretrained model through passing information into
``pretrained`` argument. We have already offered official weights as
follows:
- stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
- stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
- stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
- stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
- stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa
    If you want to load the pretrained discriminator, you can use the
    following code:
    .. code-block:: python
        # ckpt_http is a valid http path to the checkpoint
        discriminator = StyleGAN2Discriminator(1024,
                                               pretrained=dict(
                                                   ckpt_path=ckpt_http,
                                                   prefix='discriminator'))
    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` to a local path.
    Note that our implementation adopts BGR images as input, while the
    original StyleGAN2 provides RGB images to the discriminator. Thus, we
    provide the ``input_bgr2rgb`` argument to convert the channel order. If
    your images follow the RGB order, please set it to ``True`` accordingly.
Args:
in_size (int): The input size of images.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
        blur_kernel (list, optional): The blur kernel. Defaults
            to [1, 3, 3, 1].
mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
Defaults to dict(group_size=4, channel_groups=1).
num_fp16_scales (int, optional): The number of resolutions to use auto
fp16 training. Defaults to 0.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
out_fp32 (bool, optional): Whether to convert the output feature map to
`torch.float32`. Defaults to `True`.
convert_input_fp32 (bool, optional): Whether to convert input type to
fp32 if not `fp16_enabled`. This argument is designed to deal with
the cases where some modules are run in FP16 and others in FP32.
Defaults to True.
        input_bgr2rgb (bool, optional): Whether to reorder the input channels
            from `bgr` to `rgb`. Since several of the converted weights expect
            `rgb` inputs, you can set this argument to True if you want to
            finetune on them. Defaults to False.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the discriminator part from the whole state
            dict. Defaults to None.
"""
def __init__(self,
in_size,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
mbstd_cfg=dict(group_size=4, channel_groups=1),
num_fp16_scales=0,
fp16_enabled=False,
out_fp32=True,
convert_input_fp32=True,
input_bgr2rgb=False,
pretrained=None):
super().__init__()
self.num_fp16_scale = num_fp16_scales
self.fp16_enabled = fp16_enabled
self.convert_input_fp32 = convert_input_fp32
self.out_fp32 = out_fp32
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
log_size = int(np.log2(in_size))
in_channels = channels[in_size]
_use_fp16 = num_fp16_scales > 0
convs = [
ConvDownLayer(3, channels[in_size], 1, fp16_enabled=_use_fp16)
]
for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]
# add fp16 training for higher resolutions
_use_fp16 = (log_size - i) < num_fp16_scales or fp16_enabled
convs.append(
ResBlock(
in_channels,
out_channel,
blur_kernel,
fp16_enabled=_use_fp16,
convert_input_fp32=convert_input_fp32))
in_channels = out_channel
self.convs = nn.Sequential(*convs)
self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinearActModule(
channels[4] * 4 * 4,
channels[4],
act_cfg=dict(type='fused_bias')),
EqualLinearActModule(channels[4], 1),
)
self.input_bgr2rgb = input_bgr2rgb
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
@auto_fp16()
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Input image tensor.
Returns:
            torch.Tensor: Predicted score for the input image.
"""
# This setting was used to finetune on converted weights
if self.input_bgr2rgb:
x = x[:, [2, 1, 0], ...]
x = self.convs(x)
x = self.mbstd_layer(x)
if not self.final_conv.fp16_enabled and self.convert_input_fp32:
x = x.to(torch.float32)
x = self.final_conv(x)
x = x.view(x.shape[0], -1)
x = self.final_linear(x)
return x
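# Usage sketch (illustrative):
#     d = StyleGAN2Discriminator(in_size=1024)
#     score = d(torch.randn(1, 3, 1024, 1024))  # -> (1, 1)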
@MODULES.register_module()
class ADAStyleGAN2Discriminator(StyleGAN2Discriminator):
def __init__(self, in_size, *args, data_aug=None, **kwargs):
"""StyleGANv2 Discriminator with adaptive augmentation.
Args:
in_size (int): The input size of images.
data_aug (dict, optional): Config for data
augmentation. Defaults to None.
"""
super().__init__(in_size, *args, **kwargs)
self.with_ada = data_aug is not None
if self.with_ada:
self.ada_aug = build_module(data_aug)
self.ada_aug.requires_grad = False
self.log_size = int(np.log2(in_size))
def forward(self, x):
"""Forward function."""
if self.with_ada:
x = self.ada_aug.aug_pipeline(x)
return super().forward(x)
@MODULES.register_module()
class ADAAug(nn.Module):
"""Data Augmentation Module for Adaptive Discriminator augmentation.
Args:
aug_pipeline (dict, optional): Config for augmentation pipeline.
Defaults to None.
update_interval (int, optional): Interval for updating
augmentation probability. Defaults to 4.
augment_initial_p (float, optional): Initial augmentation
probability. Defaults to 0..
ada_target (float, optional): ADA target. Defaults to 0.6.
ada_kimg (int, optional): ADA training duration. Defaults to 500.
"""
def __init__(self,
aug_pipeline=None,
update_interval=4,
augment_initial_p=0.,
ada_target=0.6,
ada_kimg=500):
super().__init__()
self.aug_pipeline = AugmentPipe(**aug_pipeline)
self.update_interval = update_interval
self.ada_kimg = ada_kimg
self.ada_target = ada_target
self.aug_pipeline.p.copy_(torch.tensor(augment_initial_p))
# this log buffer stores two numbers: num_scalars, sum_scalars.
self.register_buffer('log_buffer', torch.zeros((2, )))
def update(self, iteration=0, num_batches=0):
"""Update Augment probability.
Args:
iteration (int, optional): Training iteration.
Defaults to 0.
            num_batches (int, optional): The number of real samples in a
                batch. Defaults to 0.
"""
if (iteration + 1) % self.update_interval == 0:
adjust_step = float(num_batches * self.update_interval) / float(
self.ada_kimg * 1000.)
# get the mean value as the ada heuristic
ada_heuristic = self.log_buffer[1] / self.log_buffer[0]
adjust = np.sign(ada_heuristic.item() -
self.ada_target) * adjust_step
# update the augment p
# Note that p may be bigger than 1.0
self.aug_pipeline.p.copy_((self.aug_pipeline.p + adjust).max(
constant(0, device=self.log_buffer.device)))
self.log_buffer = self.log_buffer * 0.
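# Usage sketch (illustrative; it is assumed that the training loop accumulates
# `log_buffer` as [num_scalars, sum_of_scalars] of sign(real_logits)):
#     ada = ADAAug(aug_pipeline=dict(), update_interval=4, ada_kimg=500)
#     ada.log_buffer[0] += real_logits.numel()
#     ada.log_buffer[1] += real_logits.sign().sum()
#     ada.update(iteration=it, num_batches=real_logits.shape[0])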
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import mmcv
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .utils import get_mean_latent
@MODULES.register_module()
class StyleGANv3Generator(nn.Module):
"""StyleGAN3 Generator.
    In StyleGAN3, we make several changes to StyleGANv2's generator, which
    include transformed Fourier features, filtered nonlinearities and
    non-critical sampling, etc. More details can be found in: Alias-Free
    Generative Adversarial Networks NeurIPS'2021.
Ref: https://github.com/NVlabs/stylegan3
Args:
out_size (int): The output size of the StyleGAN3 generator.
style_channels (int): The number of channels for style code.
        img_channels (int): The number of output channels.
noise_size (int, optional): Size of the input noise vector.
Defaults to 512.
        rgb2bgr (bool, optional): Whether to reformat the output channels
            with order `bgr`. We provide several pre-trained StyleGAN3
            weights whose output channel order is `rgb`. You can set
            this argument to True to use those weights. Defaults to False.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
synthesis_cfg (dict, optional): Config for synthesis network. Defaults
to dict(type='SynthesisNetwork').
mapping_cfg (dict, optional): Config for mapping network. Defaults to
dict(type='MappingNetwork').
"""
def __init__(self,
out_size,
style_channels,
img_channels,
noise_size=512,
rgb2bgr=False,
pretrained=None,
synthesis_cfg=dict(type='SynthesisNetwork'),
mapping_cfg=dict(type='MappingNetwork')):
super().__init__()
self.noise_size = noise_size
self.style_channels = style_channels
self.out_size = out_size
self.img_channels = img_channels
self.rgb2bgr = rgb2bgr
self._synthesis_cfg = deepcopy(synthesis_cfg)
self._synthesis_cfg.setdefault('style_channels', style_channels)
self._synthesis_cfg.setdefault('out_size', out_size)
self._synthesis_cfg.setdefault('img_channels', img_channels)
self.synthesis = build_module(self._synthesis_cfg)
self.num_ws = self.synthesis.num_ws
self._mapping_cfg = deepcopy(mapping_cfg)
self._mapping_cfg.setdefault('noise_size', noise_size)
self._mapping_cfg.setdefault('style_channels', style_channels)
self._mapping_cfg.setdefault('num_ws', self.num_ws)
self.style_mapping = build_module(self._mapping_cfg)
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def forward(self,
noise,
num_batches=0,
input_is_latent=False,
truncation=1,
num_truncation_layer=None,
update_emas=False,
force_fp32=True,
return_noise=False,
return_latents=False):
"""Forward Function for stylegan3.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
            truncation (float, optional): Truncation factor. If the value is
                less than 1., the truncation trick will be adopted.
                Defaults to 1.
            num_truncation_layer (int, optional): Number of layers that use
                the truncated latent. Defaults to None.
            update_emas (bool, optional): Whether to update the moving
                average of the mean latent. Defaults to False.
            force_fp32 (bool, optional): Force the layers to run in fp32,
                regardless of their configured precision. Defaults to True.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# if input is latent, set noise size as the style_channels
noise_size = (
self.style_channels if input_is_latent else self.noise_size)
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, noise_size))
device = get_module_device(self)
noise_batch = noise_batch.to(device)
if input_is_latent:
ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
else:
ws = self.style_mapping(
noise_batch,
truncation=truncation,
num_truncation_layer=num_truncation_layer,
update_emas=update_emas)
out_img = self.synthesis(
ws, update_emas=update_emas, force_fp32=force_fp32)
if self.rgb2bgr:
out_img = out_img[:, [2, 1, 0], ...]
if return_noise or return_latents:
output = dict(fake_img=out_img, noise_batch=noise_batch, latent=ws)
return output
return out_img
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
if hasattr(self.style_mapping, 'w_avg'):
return self.style_mapping.w_avg
return get_mean_latent(self, num_samples, **kwargs)
def get_training_kwargs(self, phase):
"""Get training kwargs. In StyleGANv3, we enable fp16, and update
mangitude ema during training of discriminator. This function is used
to pass related arguments.
Args:
phase (str): Current training phase.
Returns:
dict: Training kwargs.
"""
if phase == 'disc':
return dict(update_emas=True, force_fp32=False)
if phase == 'gen':
return dict(force_fp32=False)
return {}
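# Usage sketch (illustrative; synthesis_cfg/mapping_cfg fall back to their
# registered defaults):
#     g = StyleGANv3Generator(out_size=256, style_channels=512, img_channels=3)
#     img = g(None, num_batches=2)  # -> (2, 3, 256, 256)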
# Copyright (c) OpenMMLab. All rights reserved.
from .styleganv2_modules import (Blur, ConstantInput, ModulatedConv2d,
ModulatedStyleConv, ModulatedToRGB,
NoiseInjection)
from .styleganv3_modules import (MappingNetwork, SynthesisInput,
SynthesisLayer, SynthesisNetwork)
__all__ = [
'Blur', 'ModulatedStyleConv', 'ModulatedToRGB', 'NoiseInjection',
'ConstantInput', 'MappingNetwork', 'SynthesisInput', 'SynthesisLayer',
'SynthesisNetwork', 'ModulatedConv2d'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmgen.models.architectures.pggan import (EqualizedLRConvModule,
EqualizedLRConvUpModule,
EqualizedLRLinearModule)
from mmgen.models.architectures.stylegan.modules import (Blur, ConstantInput,
NoiseInjection)
class AdaptiveInstanceNorm(nn.Module):
r"""Adaptive Instance Normalization Module.
Ref: https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py # noqa
Args:
        in_channel (int): The channel number of the input tensor.
style_dim (int): Style latent dimension.
"""
def __init__(self, in_channel, style_dim):
super().__init__()
self.norm = nn.InstanceNorm2d(in_channel)
self.affine = EqualizedLRLinearModule(style_dim, in_channel * 2)
self.affine.bias.data[:in_channel] = 1
self.affine.bias.data[in_channel:] = 0
def forward(self, input, style):
"""Forward function.
Args:
input (Tensor): Input tensor with shape (n, c, h, w).
style (Tensor): Input style tensor with shape (n, c).
Returns:
Tensor: Forward results.
"""
style = self.affine(style).unsqueeze(2).unsqueeze(3)
gamma, beta = style.chunk(2, 1)
out = self.norm(input)
out = gamma * out + beta
return out
class StyleConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
padding=1,
initial=False,
blur_kernel=[1, 2, 1],
upsample=False,
fused=False):
"""Convolutional style blocks composing of noise injector, AdaIN module
and convolution layers.
Args:
in_channels (int): The channel number of the input tensor.
            out_channels (int): The channel number of the output tensor.
kernel_size (int): The kernel size of convolution layers.
style_channels (int): The number of channels for style code.
padding (int, optional): Padding of convolution layers.
Defaults to 1.
initial (bool, optional): Whether this is the first StyleConv of
StyleGAN's generator. Defaults to False.
            blur_kernel (list, optional): The blur kernel.
                Defaults to [1, 2, 1].
            upsample (bool, optional): Whether to perform upsampling.
                Defaults to False.
            fused (bool, optional): Whether to use the fused up-convolution.
                Defaults to False.
"""
super().__init__()
if initial:
self.conv1 = ConstantInput(in_channels)
else:
if upsample:
if fused:
self.conv1 = nn.Sequential(
EqualizedLRConvUpModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=dict(type='LeakyReLU',
negative_slope=0.2)),
Blur(blur_kernel, pad=(1, 1)),
)
else:
self.conv1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None), Blur(blur_kernel, pad=(1, 1)))
else:
self.conv1 = EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None)
self.noise_injector1 = NoiseInjection()
self.activate1 = nn.LeakyReLU(0.2)
self.adain1 = AdaptiveInstanceNorm(out_channels, style_channels)
self.conv2 = EqualizedLRConvModule(
out_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None)
self.noise_injector2 = NoiseInjection()
self.activate2 = nn.LeakyReLU(0.2)
self.adain2 = AdaptiveInstanceNorm(out_channels, style_channels)
def forward(self,
x,
style1,
style2,
noise1=None,
noise2=None,
return_noise=False):
"""Forward function.
Args:
x (Tensor): Input tensor.
style1 (Tensor): Input style tensor with shape (n, c).
style2 (Tensor): Input style tensor with shape (n, c).
noise1 (Tensor, optional): Noise tensor with shape (n, c, h, w).
Defaults to None.
noise2 (Tensor, optional): Noise tensor with shape (n, c, h, w).
Defaults to None.
return_noise (bool, optional): If True, ``noise1`` and ``noise2``
will be returned with ``out``. Defaults to False.
Returns:
Tensor | tuple[Tensor]: Forward results.
"""
out = self.conv1(x)
if return_noise:
out, noise1 = self.noise_injector1(
out, noise=noise1, return_noise=return_noise)
else:
out = self.noise_injector1(
out, noise=noise1, return_noise=return_noise)
out = self.activate1(out)
out = self.adain1(out, style1)
out = self.conv2(out)
if return_noise:
out, noise2 = self.noise_injector2(
out, noise=noise2, return_noise=return_noise)
else:
out = self.noise_injector2(
out, noise=noise2, return_noise=return_noise)
out = self.activate2(out)
out = self.adain2(out, style2)
if return_noise:
return out, noise1, noise2
return out
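# Usage sketch (illustrative shapes):
#     conv = StyleConv(64, 32, kernel_size=3, style_channels=512,
#                      upsample=True)
#     out = conv(torch.randn(2, 64, 8, 8), torch.randn(2, 512),
#                torch.randn(2, 512))  # -> (2, 32, 16, 16)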
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.activation import build_activation_layer
from mmcv.ops.fused_bias_leakyrelu import (FusedBiasLeakyReLU,
fused_bias_leakyrelu)
from mmcv.ops.upfirdn2d import upfirdn2d
from mmcv.runner.dist_utils import get_dist_info
from mmgen.core.runners.fp16_utils import auto_fp16
from mmgen.models.architectures.pggan import (EqualizedLRConvModule,
EqualizedLRLinearModule,
equalized_lr)
from mmgen.models.common import AllGatherLayer
from mmgen.ops import conv2d, conv_transpose2d
class _FusedBiasLeakyReLU(FusedBiasLeakyReLU):
"""Wrap FusedBiasLeakyReLU to support FP16 training."""
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, ...).
Returns:
Tensor: Output feature map.
"""
return fused_bias_leakyrelu(x, self.bias.to(x.dtype),
self.negative_slope, self.scale)
class EqualLinearActModule(nn.Module):
"""Equalized LR Linear Module with Activation Layer.
    This module is modified from ``EqualizedLRLinearModule`` defined in PGGAN.
    The major feature updated in this module is the support for activation
    layers used in StyleGAN2.
Args:
equalized_lr_cfg (dict | None, optional): Config for equalized lr.
Defaults to dict(gain=1., lr_mul=1.).
bias (bool, optional): Whether to use bias item. Defaults to True.
bias_init (float, optional): The value for bias initialization.
Defaults to ``0.``.
act_cfg (dict | None, optional): Config for activation layer.
Defaults to None.
"""
def __init__(self,
*args,
equalized_lr_cfg=dict(gain=1., lr_mul=1.),
bias=True,
bias_init=0.,
act_cfg=None,
**kwargs):
super().__init__()
self.with_activation = act_cfg is not None
# w/o bias in linear layer
self.linear = EqualizedLRLinearModule(
*args, bias=False, equalized_lr_cfg=equalized_lr_cfg, **kwargs)
if equalized_lr_cfg is not None:
self.lr_mul = equalized_lr_cfg.get('lr_mul', 1.)
else:
self.lr_mul = 1.
# define bias outside linear layer
if bias:
self.bias = nn.Parameter(
torch.zeros(self.linear.out_features).fill_(bias_init))
else:
self.bias = None
if self.with_activation:
act_cfg = deepcopy(act_cfg)
if act_cfg['type'] == 'fused_bias':
self.act_type = act_cfg.pop('type')
assert self.bias is not None
self.activate = partial(fused_bias_leakyrelu, **act_cfg)
else:
self.act_type = 'normal'
self.activate = build_activation_layer(act_cfg)
else:
self.act_type = None
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, ...).
Returns:
Tensor: Output feature map.
"""
if x.ndim >= 3:
x = x.reshape(x.size(0), -1)
x = self.linear(x)
if self.with_activation and self.act_type == 'fused_bias':
x = self.activate(x, self.bias * self.lr_mul)
elif self.bias is not None and self.with_activation:
x = self.activate(x + self.bias * self.lr_mul)
elif self.bias is not None:
x = x + self.bias * self.lr_mul
elif self.with_activation:
x = self.activate(x)
return x
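# Usage sketch (illustrative; with the default act_cfg=None only the bias is
# added after the equalized-lr linear layer):
#     fc = EqualLinearActModule(512, 256)
#     w = fc(torch.randn(4, 512))  # -> (4, 256)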
def _make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
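# Sketch (illustrative): a 1D kernel is expanded to its 2D outer product and
# normalized to sum to 1.
#     k = _make_kernel([1, 3, 3, 1])  # -> shape (4, 4), k.sum() == 1.0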
class UpsampleUpFIRDn(nn.Module):
"""UpFIRDn for Upsampling.
This module is used in the ``to_rgb`` layers in StyleGAN2 for upsampling
the images.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
factor (int, optional): Upsampling factor. Defaults to 2.
"""
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = _make_kernel(kernel) * (factor**2)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
out = upfirdn2d(
x, self.kernel.to(x.dtype), up=self.factor, down=1, pad=self.pad)
return out
class DownsampleUpFIRDn(nn.Module):
"""UpFIRDn for Downsampling.
    This module is mentioned in StyleGAN2 for downsampling the feature maps.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
factor (int, optional): Downsampling factor. Defaults to 2.
"""
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = _make_kernel(kernel)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
"""Forward function.
Args:
input (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
out = upfirdn2d(
input,
self.kernel.to(input.dtype),
up=1,
down=self.factor,
pad=self.pad)
return out
class Blur(nn.Module):
"""Blur module.
    This module is applied right after the upsampling operation in StyleGAN2.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
pad (list[int]): Padding for features.
upsample_factor (int, optional): Upsampling factor. Defaults to 1.
"""
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = _make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor**2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
# In Tero's implementation, he uses fp32
return upfirdn2d(x, self.kernel.to(x.dtype), pad=self.pad)
class ModulatedConv2d(nn.Module):
r"""Modulated Conv2d in StyleGANv2.
This module implements the modulated convolution layers proposed in
StyleGAN2. Details can be found in Analyzing and Improving the Image
Quality of StyleGAN, CVPR2020.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
        blur_kernel (list[int], optional): The blur kernel.
            Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to 0..
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
            padding=None,  # self-defined padding
eps=1e-8):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.style_channels = style_channels
self.demodulate = demodulate
# sanity check for kernel size
assert isinstance(self.kernel_size,
int) and (self.kernel_size >= 1
and self.kernel_size % 2 == 1)
self.upsample = upsample
self.downsample = downsample
self.style_bias = style_bias
self.eps = eps
# build style modulation module
style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg
self.style_modulation = EqualLinearActModule(style_channels,
in_channels,
**style_mod_cfg)
# set lr_mul for conv weight
lr_mul_ = 1.
if equalized_lr_cfg is not None:
lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, kernel_size,
kernel_size).div_(lr_mul_))
# build blurry layer for upsampling
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)
# build blurry layer for downsampling
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
# add equalized_lr hook for conv weight
if equalized_lr_cfg is not None:
equalized_lr(self, **equalized_lr_cfg)
        # use `is not None` so that an explicit `padding=0` is respected
        self.padding = padding if padding is not None else (kernel_size // 2)
def forward(self, x, style, input_gain=None):
n, c, h, w = x.shape
weight = self.weight
# Pre-normalize inputs to avoid FP16 overflow.
if x.dtype == torch.float16 and self.demodulate:
weight = weight * (
1 / np.sqrt(
self.in_channels * self.kernel_size * self.kernel_size) /
weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
) # max_Ikk
style = style / style.norm(
float('inf'), dim=1, keepdim=True) # max_I
# process style code
style = self.style_modulation(style).view(n, 1, c, 1,
1) + self.style_bias
# combine weight and style
weight = weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
if input_gain is not None:
# input_gain shape [batch, in_ch]
input_gain = input_gain.expand(n, self.in_channels)
# weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)
weight = weight.view(n * self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.to(x.dtype)
if self.upsample:
x = x.reshape(1, n * c, h, w)
weight = weight.view(n, self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
self.kernel_size,
self.kernel_size)
x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
x = x.reshape(n, self.out_channels, *x.shape[-2:])
x = self.blur(x)
elif self.downsample:
x = self.blur(x)
x = x.view(1, n * self.in_channels, *x.shape[-2:])
x = conv2d(x, weight, stride=2, padding=0, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
else:
x = x.reshape(1, n * c, h, w)
x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
return x


class NoiseInjection(nn.Module):
    """Noise Injection Module.

    In StyleGAN2, they adopt this module to inject spatial random noise maps
    into the generators.

    Args:
noise_weight_init (float, optional): Initialization weight for noise
injection. Defaults to ``0.``.
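
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file):

        >>> inj = NoiseInjection()
        >>> feat = torch.randn(2, 8, 4, 4)
        >>> out, noise = inj(feat, return_noise=True)
        >>> out.shape, noise.shape  # (2, 8, 4, 4), (2, 1, 4, 4)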
"""

    def __init__(self, noise_weight_init=0.):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1).fill_(noise_weight_init))

    def forward(self, image, noise=None, return_noise=False):
        """Forward function.

        Args:
            image (Tensor): Spatial features with a shape of (N, C, H, W).
            noise (Tensor, optional): Noises from the outside.
                Defaults to None.
            return_noise (bool, optional): Whether to also return the noise
                tensor. Defaults to False.

        Returns:
            Tensor | tuple[Tensor]: Output features, together with the noise
                tensor if ``return_noise`` is True.
        """
if noise is None:
batch, _, height, width = image.shape
noise = image.new_empty(batch, 1, height, width).normal_()
noise = noise.to(image.dtype)
if return_noise:
return image + self.weight.to(image.dtype) * noise, noise
return image + self.weight.to(image.dtype) * noise


class ConstantInput(nn.Module):
    """Constant Input.

    In StyleGAN2, they substitute the original head noise input with such a
    constant input module.

    Args:
channel (int): Channels for the constant input tensor.
size (int, optional): Spatial size for the constant input.
Defaults to 4.
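
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); only the batch dimension of the input is used:

        >>> const = ConstantInput(channel=512, size=4)
        >>> w = torch.randn(3, 16)
        >>> out = const(w)  # expected shape: (3, 512, 4, 4)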
"""

    def __init__(self, channel, size=4):
super().__init__()
if isinstance(size, int):
size = [size, size]
elif mmcv.is_seq_of(size, int):
assert len(
size
) == 2, f'The length of size should be 2 but got {len(size)}'
else:
raise ValueError(f'Got invalid value in size, {size}')
self.input = nn.Parameter(torch.randn(1, channel, *size))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map of which only the batch dimension
                is used, with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
batch = x.shape[0]
out = self.input.repeat(batch, 1, 1, 1)
return out


class ModulatedPEConv2d(nn.Module):
    r"""Modulated Conv2d in StyleGANv2 with Positional Encoding (PE).

    This module is modified from the ``ModulatedConv2d`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021.

    Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
        no_pad (bool, optional): Whether to remove the padding in
convolution. Defaults to False.
deconv2conv (bool, optional): Whether to substitute the transposed conv
with (conv2d, upsampling). Defaults to False.
interp_pad (int | None, optional): The padding number of interpolation
pad. Defaults to None.
up_config (dict, optional): Upsampling config.
Defaults to dict(scale_factor=2, mode='nearest').
up_after_conv (bool, optional): Whether to adopt upsampling after
convolution. Defaults to False.
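
    Example:
        An illustrative sketch (not part of the original file) of the
        ``no_pad`` and ``deconv2conv`` options; other settings take their
        defaults:

        >>> x = torch.randn(1, 32, 8, 8)
        >>> style = torch.randn(1, 16)
        >>> conv = ModulatedPEConv2d(32, 32, 3, style_channels=16,
        ...                          no_pad=True)
        >>> out = conv(x, style)  # no padding: (1, 32, 6, 6)
        >>> up = ModulatedPEConv2d(32, 32, 3, style_channels=16,
        ...                        upsample=True, deconv2conv=True,
        ...                        up_after_conv=True)
        >>> out = up(x, style)  # nearest upsampling: (1, 32, 16, 16)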
"""

    def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
eps=1e-8,
no_pad=False,
deconv2conv=False,
interp_pad=None,
up_config=dict(scale_factor=2, mode='nearest'),
up_after_conv=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.style_channels = style_channels
self.demodulate = demodulate
# sanity check for kernel size
assert isinstance(self.kernel_size,
int) and (self.kernel_size >= 1
and self.kernel_size % 2 == 1)
self.upsample = upsample
self.downsample = downsample
self.style_bias = style_bias
self.eps = eps
self.no_pad = no_pad
self.deconv2conv = deconv2conv
self.interp_pad = interp_pad
self.with_interp_pad = interp_pad is not None
self.up_config = deepcopy(up_config)
self.up_after_conv = up_after_conv
# build style modulation module
style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg
self.style_modulation = EqualLinearActModule(style_channels,
in_channels,
**style_mod_cfg)
# set lr_mul for conv weight
lr_mul_ = 1.
if equalized_lr_cfg is not None:
lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, kernel_size,
kernel_size).div_(lr_mul_))
# build blurry layer for upsampling
if upsample and not self.deconv2conv:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)
# build blurry layer for downsampling
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
# add equalized_lr hook for conv weight
if equalized_lr_cfg is not None:
equalized_lr(self, **equalized_lr_cfg)
# if `no_pad`, remove all of the padding in conv
self.padding = kernel_size // 2 if not no_pad else 0

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).

        Returns:
            Tensor: Output feature with shape of (N, C, H, W).
        """
n, c, h, w = x.shape
# process style code
style = self.style_modulation(style).view(n, 1, c, 1,
1) + self.style_bias
# combine weight and style
weight = self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
weight = weight.view(n * self.out_channels, c, self.kernel_size,
self.kernel_size)
if self.upsample and not self.deconv2conv:
x = x.reshape(1, n * c, h, w)
weight = weight.view(n, self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
self.kernel_size,
self.kernel_size)
x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
x = x.reshape(n, self.out_channels, *x.shape[-2:])
x = self.blur(x)
elif self.upsample and self.deconv2conv:
if self.up_after_conv:
x = x.reshape(1, n * c, h, w)
x = conv2d(x, weight, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[2:4])
if self.with_interp_pad:
h_, w_ = x.shape[-2:]
up_cfg_ = deepcopy(self.up_config)
up_scale = up_cfg_.pop('scale_factor')
size_ = (h_ * up_scale + self.interp_pad,
w_ * up_scale + self.interp_pad)
x = F.interpolate(x, size=size_, **up_cfg_)
else:
x = F.interpolate(x, **self.up_config)
if not self.up_after_conv:
h_, w_ = x.shape[-2:]
x = x.view(1, n * c, h_, w_)
x = conv2d(x, weight, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[2:4])
elif self.downsample:
x = self.blur(x)
x = x.view(1, n * self.in_channels, *x.shape[-2:])
x = conv2d(x, weight, stride=2, padding=0, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
else:
x = x.view(1, n * c, h, w)
x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
return x


class ModulatedStyleConv(nn.Module):
    """Modulated Style Convolution.

    In this module, we integrate the modulated conv2d, noise injector and
    activation layers together.

    Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to ``0.``.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
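
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); the fused bias/LeakyReLU op used by this module must be
        available:

        >>> conv = ModulatedStyleConv(16, 32, 3, style_channels=8,
        ...                           upsample=True)
        >>> x = torch.randn(2, 16, 4, 4)
        >>> style = torch.randn(2, 8)
        >>> out = conv(x, style)  # expected shape: (2, 32, 8, 8)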
"""

    def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
fp16_enabled=False,
conv_clamp=256):
super().__init__()
# add support for fp16
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
self.conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=demodulate,
upsample=upsample,
blur_kernel=blur_kernel,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias)
self.noise_injector = NoiseInjection()
self.activate = _FusedBiasLeakyReLU(out_channels)
# if self.fp16_enabled:
# self.half()

    @auto_fp16(apply_to=('x', 'noise'))
    def forward(self, x, style, noise=None, return_noise=False):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to also return the noise
                tensors. Defaults to False.

        Returns:
            Tensor | tuple[Tensor]: Output features with shape of
                (N, C, H, W), together with the noise tensor if
                ``return_noise`` is True.
        """
out = self.conv(x, style)
if return_noise:
out, noise = self.noise_injector(
out, noise=noise, return_noise=return_noise)
else:
out = self.noise_injector(
out, noise=noise, return_noise=return_noise)
# TODO: FP16 in activate layers
out = self.activate(out)
if self.fp16_enabled:
out = torch.clamp(out, min=-self.conv_clamp, max=self.conv_clamp)
if return_noise:
return out, noise
return out


class ModulatedPEStyleConv(nn.Module):
    """Modulated Style Convolution with Positional Encoding.

    This module is modified from the ``ModulatedStyleConv`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021.

    Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
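
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); extra keyword arguments are forwarded to
        ``ModulatedPEConv2d``:

        >>> conv = ModulatedPEStyleConv(16, 16, 3, style_channels=8,
        ...                             upsample=True, deconv2conv=True,
        ...                             up_after_conv=True)
        >>> out = conv(torch.randn(2, 16, 4, 4), torch.randn(2, 8))
        >>> out.shape  # expected: (2, 16, 8, 8)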
"""

    def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
**kwargs):
super().__init__()
self.conv = ModulatedPEConv2d(
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=demodulate,
upsample=upsample,
blur_kernel=blur_kernel,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias,
**kwargs)
self.noise_injector = NoiseInjection()
self.activate = _FusedBiasLeakyReLU(out_channels)

    def forward(self, x, style, noise=None, return_noise=False):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to also return the noise
                tensors. Defaults to False.

        Returns:
            Tensor | tuple[Tensor]: Output features with shape of
                (N, C, H, W), together with the noise tensor if
                ``return_noise`` is True.
        """
out = self.conv(x, style)
if return_noise:
out, noise = self.noise_injector(
out, noise=noise, return_noise=return_noise)
else:
out = self.noise_injector(
out, noise=noise, return_noise=return_noise)
out = self.activate(out)
if return_noise:
return out, noise
return out


class ModulatedToRGB(nn.Module):
    """To RGB layer.

    This module is designed to output image tensor in StyleGAN2.

    Args:
in_channels (int): Input channels.
style_channels (int): Channels for the style codes.
out_channels (int, optional): Output channels. Defaults to 3.
upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to True.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
out_fp32 (bool, optional): Whether to convert the output feature map to
`torch.float32`. Defaults to `True`.
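
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); ``skip`` is upsampled to the current resolution before the
        addition:

        >>> to_rgb = ModulatedToRGB(64, style_channels=16)
        >>> feat = torch.randn(2, 64, 8, 8)
        >>> style = torch.randn(2, 16)
        >>> img = to_rgb(feat, style)  # expected shape: (2, 3, 8, 8)
        >>> skip = torch.randn(2, 3, 4, 4)  # image from the previous stage
        >>> img = to_rgb(feat, style, skip=skip)  # still (2, 3, 8, 8)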
"""

    def __init__(self,
in_channels,
style_channels,
out_channels=3,
upsample=True,
blur_kernel=[1, 3, 3, 1],
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
fp16_enabled=False,
conv_clamp=256,
out_fp32=True):
super().__init__()
if upsample:
self.upsample = UpsampleUpFIRDn(blur_kernel)
# add support for fp16
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
self.conv = ModulatedConv2d(
in_channels,
out_channels=out_channels,
kernel_size=1,
style_channels=style_channels,
demodulate=False,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias)
        # use `out_channels` here so the bias matches the conv output even
        # when a non-RGB channel number is configured
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        # enforce the output to be fp32 (follow Tero's implementation)
self.out_fp32 = out_fp32

    @auto_fp16(apply_to=('x', 'style'))
    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            skip (Tensor, optional): Tensor for skip link. Defaults to None.

        Returns:
            Tensor: Output features with shape of (N, C, H, W).
        """
out = self.conv(x, style)
out = out + self.bias.to(x.dtype)
if self.fp16_enabled:
out = torch.clamp(out, min=-self.conv_clamp, max=self.conv_clamp)
# Here, Tero adopts FP16 at `skip`.
if skip is not None:
skip = self.upsample(skip)
out = out + skip
return out


class ConvDownLayer(nn.Sequential):
    """Convolution and Downsampling layer.

    Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
bias (bool, optional): Whether to use bias parameter. Defaults to True.
act_cfg (dict, optional): Activation configs.
Defaults to dict(type='fused_bias').
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
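
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); with ``downsample=True`` the spatial size is halved:

        >>> layer = ConvDownLayer(16, 32, 3, downsample=True)
        >>> out = layer(torch.randn(2, 16, 8, 8))
        >>> out.shape  # expected: (2, 32, 4, 4)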
"""

    def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
act_cfg=dict(type='fused_bias'),
fp16_enabled=False,
conv_clamp=256.):
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
self.with_fused_bias = act_cfg is not None and act_cfg.get(
'type') == 'fused_bias'
if self.with_fused_bias:
conv_act_cfg = None
else:
conv_act_cfg = act_cfg
layers.append(
EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not self.with_fused_bias,
norm_cfg=None,
act_cfg=conv_act_cfg,
equalized_lr_cfg=dict(mode='fan_in', gain=1.)))
if self.with_fused_bias:
layers.append(_FusedBiasLeakyReLU(out_channels))
super(ConvDownLayer, self).__init__(*layers)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
x = super().forward(x)
if self.fp16_enabled:
x = torch.clamp(x, min=-self.conv_clamp, max=self.conv_clamp)
return x


class ResBlock(nn.Module):
    """Residual block used in the discriminator of StyleGAN2.

    Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
convert_input_fp32 (bool, optional): Whether to convert input type to
fp32 if not `fp16_enabled`. This argument is designed to deal with
the cases where some modules are run in FP16 and others in FP32.
Defaults to True.
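
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); the block always downsamples by a factor of 2:

        >>> block = ResBlock(16, 32)
        >>> out = block(torch.randn(2, 16, 8, 8))
        >>> out.shape  # expected: (2, 32, 4, 4)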
"""

    def __init__(self,
in_channels,
out_channels,
blur_kernel=[1, 3, 3, 1],
fp16_enabled=False,
convert_input_fp32=True):
super().__init__()
self.fp16_enabled = fp16_enabled
self.convert_input_fp32 = convert_input_fp32
self.conv1 = ConvDownLayer(
in_channels, in_channels, 3, blur_kernel=blur_kernel)
self.conv2 = ConvDownLayer(
in_channels,
out_channels,
3,
downsample=True,
blur_kernel=blur_kernel)
self.skip = ConvDownLayer(
in_channels,
out_channels,
1,
downsample=True,
act_cfg=None,
bias=False,
blur_kernel=blur_kernel)

    @auto_fp16()
    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
# TODO: study whether this explicit datatype transfer will harm the
# apex training speed
if not self.fp16_enabled and self.convert_input_fp32:
input = input.to(torch.float32)
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / np.sqrt(2)
return out


class ModMBStddevLayer(nn.Module):
    """Modified MiniBatch Stddev Layer.

    This layer is modified from ``MiniBatchStddevLayer`` used in PGGAN. In
    StyleGAN2, the authors add a new feature, `channel_groups`, into this
    layer.

    Note that to accelerate the training procedure, we also add a new feature
    of ``sync_std`` to achieve multi-nodes/machine training. This feature is
    still in beta version and we have only tested it at the 256 scale.

    Args:
group_size (int, optional): The size of groups in batch dimension.
Defaults to 4.
channel_groups (int, optional): The size of groups in channel
dimension. Defaults to 1.
sync_std (bool, optional): Whether to use synchronized std feature.
Defaults to False.
sync_groups (int | None, optional): The size of groups in node
dimension. Defaults to None.
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
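
    Example:
        A minimal usage sketch (illustrative, not part of the original
        file); one stddev channel per channel group is appended:

        >>> layer = ModMBStddevLayer(group_size=2)
        >>> out = layer(torch.randn(4, 8, 4, 4))
        >>> out.shape  # expected: (4, 9, 4, 4)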
"""

    def __init__(self,
group_size=4,
channel_groups=1,
sync_std=False,
sync_groups=None,
eps=1e-8):
super().__init__()
self.group_size = group_size
self.eps = eps
self.channel_groups = channel_groups
self.sync_std = sync_std
self.sync_groups = group_size if sync_groups is None else sync_groups
if self.sync_std:
assert torch.distributed.is_initialized(
), 'Only in distributed training can the sync_std be activated.'
mmcv.print_log('Adopt synced minibatch stddev layer', 'mmgen')

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map with shape of
                (N, C + ``channel_groups``, H, W).
        """
if self.sync_std:
# concatenate all features
all_features = torch.cat(AllGatherLayer.apply(x), dim=0)
# get the exact features we need in calculating std-dev
rank, ws = get_dist_info()
local_bs = all_features.shape[0] // ws
start_idx = local_bs * rank
            # avoid the case where the start idx is near the tail of features
if start_idx + self.sync_groups > all_features.shape[0]:
start_idx = all_features.shape[0] - self.sync_groups
end_idx = min(local_bs * rank + self.sync_groups,
all_features.shape[0])
x = all_features[start_idx:end_idx]
# batch size should be smaller than or equal to group size. Otherwise,
# batch size should be divisible by the group size.
        assert x.shape[
            0] <= self.group_size or x.shape[0] % self.group_size == 0, (
                'Batch size should be smaller than or equal to group size. '
                'Otherwise, batch size should be divisible by the group '
                f'size. But got batch size {x.shape[0]}, '
                f'group size {self.group_size}')
        assert x.shape[1] % self.channel_groups == 0, (
            'The number of feature channels must be divisible by '
            f'"channel_groups". channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')
n, c, h, w = x.shape
group_size = min(n, self.group_size)
# [G, M, Gc, C', H, W]
y = torch.reshape(x, (group_size, -1, self.channel_groups,
c // self.channel_groups, h, w))
y = torch.var(y, dim=0, unbiased=False)
y = torch.sqrt(y + self.eps)
        # [M, Gc, 1, 1]
y = y.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)
y = y.repeat(group_size, 1, h, w)
return torch.cat([x, y], dim=1)