Commit 1401de15 authored by dongchy920

stylegan2_mmcv
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from .pretrained_networks import vgg16
def normalize_tensor(in_feat, eps=1e-10):
"""L2 normalization.
Args:
in_feat (Tensor): Tensor with shape [N, C, H, W].
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-10.
Returns:
Tensor: Tensor after L2 normalization per-instance.
"""
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
def spatial_average(in_tens, keepdim=True):
"""Returns the mean value of each row of the input tensor in the spatial
dimension.
Args:
in_tens (Tensor): Tensor with shape [N, C, H, W].
keepdim (bool, optional): If keepdim is True, the output tensor is of
the shape [N, C, 1, 1]. Otherwise, the output will have shape
[N, C]. Defaults to True.
Returns:
Tensor: Tensor after average pooling to 1x1 with shape [N, C, 1, 1] or
[N, C].
"""
return in_tens.mean([2, 3], keepdim=keepdim)
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
"""Upsamples the input to the given size.
Args:
in_tens (Tensor): Tensor with shape [N, C, H, W].
out_H (int, optional): Output spatial size. Defaults to 64.
Returns:
Tensor: Output Tensor.
"""
in_H = in_tens.shape[2]
scale_factor = 1. * out_H / in_H
return nn.Upsample(
scale_factor=scale_factor, mode='bilinear', align_corners=False)(
in_tens)
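# Illustrative shape sketch for the three helpers above (assumes only torch):
#     >>> feat = torch.randn(2, 64, 32, 32)
#     >>> normalize_tensor(feat).shape    # unit L2 norm along channels
#     torch.Size([2, 64, 32, 32])
#     >>> spatial_average(feat).shape     # global average pooling to 1x1
#     torch.Size([2, 64, 1, 1])
#     >>> upsample(feat, out_H=64).shape  # bilinear resize to 64x64
#     torch.Size([2, 64, 64, 64])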
# Learned perceptual metric
class PNetLin(nn.Module):
r"""
Ref: https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py # noqa
"""
def __init__(self,
pnet_rand=False,
pnet_tune=False,
use_dropout=True,
spatial=False,
version='0.1',
lpips=True):
super().__init__()
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips
self.version = version
self.scaling_layer = ScalingLayer()
self.channels = [64, 128, 256, 512, 512]
self.L = len(self.channels)
self.net = vgg16(
pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
self.lin0 = NetLinLayer(self.channels[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.channels[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.channels[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.channels[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.channels[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug where the input was not scaled
in0_input, in1_input = (
self.scaling_layer(in0),
self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(
outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk])**2
if self.lpips:
if self.spatial:
res = [
upsample(
self.lins[kk].model(diffs[kk]), out_H=in0.shape[2])
for kk in range(self.L)
]
else:
res = [
spatial_average(
self.lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(self.L)
]
else:
if self.spatial:
res = [
upsample(
diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2])
for kk in range(self.L)
]
else:
res = [
spatial_average(
diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
for kk in range(self.L)
]
val = sum(res)
if retPerLayer:
return (val, res)
return val
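# Hedged usage sketch: with ``pnet_rand=True`` (randomly initialized VGG, no
# weight download) PNetLin already returns one LPIPS-style distance per pair:
#     >>> net = PNetLin(pnet_rand=True)
#     >>> im0 = torch.rand(1, 3, 64, 64) * 2 - 1  # inputs live in [-1, 1]
#     >>> im1 = torch.rand(1, 3, 64, 64) * 2 - 1
#     >>> net(im0, im1).shape  # [N, 1, 1, 1] when spatial=False
#     torch.Size([1, 1, 1, 1])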
class ScalingLayer(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
'shift',
torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer(
'scale',
torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv."""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super().__init__()
layers = [
nn.Dropout(),
] if (use_dropout) else []
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class Dist2LogitLayer(nn.Module):
"""takes 2 distances, puts through fc layers, spits out value between [0,
1] (if use_sigmoid is True)"""
def __init__(self, chn_mid=32, use_sigmoid=True):
super().__init__()
layers = [
nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
]
layers += [
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
]
layers += [
nn.LeakyReLU(0.2, True),
]
layers += [
nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
]
if use_sigmoid:
layers += [
nn.Sigmoid(),
]
self.model = nn.Sequential(*layers)
def forward(self, d0, d1, eps=0.1):
return self.model.forward(
torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)),
dim=1))
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super().__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid)
# self.parameters = list(self.net.parameters())
self.loss = torch.nn.BCELoss()
def forward(self, d0, d1, judge):
per = (judge + 1.) / 2.
self.logit = self.net.forward(d0, d1)
return self.loss(self.logit, per)
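# Hedged sanity-check sketch (not part of the original file): rank two random
# distance maps against a human judgement in [-1, 1]; shapes follow the
# reference LPIPS repo, i.e. [N, 1, H, W] distance maps.
#     >>> d0, d1 = torch.rand(4, 1, 1, 1), torch.rand(4, 1, 1, 1)
#     >>> judge = torch.rand(4, 1, 1, 1) * 2 - 1
#     >>> BCERankingLoss()(d0, d1, judge).shape  # scalar BCE loss
#     torch.Size([])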
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.utils.model_zoo import load_url
from .networks_basic import PNetLin
LPIPS_WEIGHTS_URL = 'https://download.openmmlab.com/mmgen/evaluation/lpips/weights/v0.1/vgg.pth' # noqa
class PerceptualLoss(torch.nn.Module):
r"""LPIPS metric with VGG using our perceptually-learned weights.
Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/__init__.py # noqa
"""
def __init__(self,
spatial=False,
use_gpu=True,
gpu_ids=[0],
pretrained=True):
super().__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
print('...[pnet-lin, vgg16] initializing')
self.init_net(pretrained=pretrained)
print('...Done')
def forward(self, pred, target, normalize=False):
if normalize:
target = 2 * target - 1
pred = 2 * pred - 1
return self.net(target, pred)
def init_net(self,
pnet_rand=False,
pnet_tune=False,
pretrained=True,
version='0.1'):
self.net = PNetLin(
pnet_rand=pnet_rand,
pnet_tune=pnet_tune,
use_dropout=True,
spatial=self.spatial,
version=version,
lpips=True)
if pretrained:
print('Loading model from: %s' % LPIPS_WEIGHTS_URL)
self.net.load_state_dict(
load_url(LPIPS_WEIGHTS_URL, map_location='cpu', progress=True),
strict=False)
        # NOTE: this assignment shadows ``nn.Module.parameters``; it is kept
        # for parity with the reference implementation.
        self.parameters = list(self.net.parameters())
self.net.eval()
if self.use_gpu:
self.net.to(self.gpu_ids[0])
self.net = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
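# Hedged usage sketch (not part of the original file): measure LPIPS between
# two image batches. ``pretrained=False`` skips the learned LPIPS weights
# (torchvision may still fetch VGG16 weights) and ``use_gpu=False`` keeps the
# computation on CPU, which is enough for a quick shape check:
#     >>> loss_fn = PerceptualLoss(use_gpu=False, pretrained=False)
#     >>> pred, target = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
#     >>> loss_fn(pred, target, normalize=True).shape  # [0, 1] -> [-1, 1]
#     torch.Size([2, 1, 1, 1])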
# Copyright (c) OpenMMLab. All rights reserved.
from collections import namedtuple
import torch
from torchvision import models as tv
class vgg16(torch.nn.Module):
r"""VGG16 feature extractor for LPIPS metric.
    Ref: https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py # noqa
"""
def __init__(self, requires_grad=False, pretrained=True):
super().__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
'VggOutputs',
['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3,
h_relu5_3)
return out
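# Illustrative sketch (not part of the original file): extract the five ReLU
# feature maps consumed by LPIPS; ``pretrained=False`` avoids the torchvision
# weight download for a pure shape check:
#     >>> extractor = vgg16(pretrained=False)
#     >>> outs = extractor(torch.rand(1, 3, 64, 64))
#     >>> [tuple(feat.shape) for feat in outs]
#     [(1, 64, 64, 64), (1, 128, 32, 32), (1, 256, 16, 16), (1, 512, 8, 8), (1, 512, 4, 4)]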
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import LSGANDiscriminator, LSGANGenerator
__all__ = ['LSGANDiscriminator', 'LSGANGenerator']
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_activation_layer
from mmgen.models.builder import MODULES
from ..common import get_module_device
@MODULES.register_module()
class LSGANGenerator(nn.Module):
"""Generator for LSGAN.
Implementation Details for LSGAN architecture:
#. Adopt transposed convolution in the generator;
#. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator except for the final output layer;
    #. Keep channels of feature maps unchanged in the convolution backbone;
    #. Insert one extra 3x3 conv after every upsampling step in the
       convolution backbone.
    We follow the implementation details of the original paper:
Least Squares Generative Adversarial Networks
https://arxiv.org/pdf/1611.04076.pdf
Args:
output_scale (int, optional): Output scale for the generated image.
Defaults to 128.
out_channels (int, optional): The channel number of the output feature.
Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to 256.
input_scale (int, optional): The scale of the input 2D feature map.
Defaults to 8.
noise_size (int, optional): Size of the input noise
vector. Defaults to 1024.
conv_cfg (dict, optional): Config for the convolution module used in
this generator. Defaults to dict(type='ConvTranspose2d').
        default_norm_cfg (dict, optional): Norm config for all layers except
            the final output layer. Defaults to dict(type='BN').
        default_act_cfg (dict, optional): Activation config for all layers
            except the final output layer. Defaults to dict(type='ReLU').
out_act_cfg (dict, optional): Activation config for the final output
layer. Defaults to dict(type='Tanh').
"""
def __init__(self,
output_scale=128,
out_channels=3,
base_channels=256,
input_scale=8,
noise_size=1024,
conv_cfg=dict(type='ConvTranspose2d'),
default_norm_cfg=dict(type='BN'),
default_act_cfg=dict(type='ReLU'),
out_act_cfg=dict(type='Tanh')):
super().__init__()
assert output_scale % input_scale == 0
assert output_scale // input_scale >= 4
self.output_scale = output_scale
self.base_channels = base_channels
self.input_scale = input_scale
self.noise_size = noise_size
self.noise2feat_head = nn.Sequential(
nn.Linear(noise_size, input_scale * input_scale * base_channels))
self.noise2feat_tail = nn.Sequential(nn.BatchNorm2d(base_channels))
if default_act_cfg is not None:
self.noise2feat_tail.add_module(
'act', build_activation_layer(default_act_cfg))
        # the number of upsampling blocks in the backbone (two more stride-2
        # blocks follow in the output layers, hence the ``- 2``)
        self.num_upsamples = int(np.log2(output_scale // input_scale)) - 2
# build up convolution backbone (excluding the output layer)
self.conv_blocks = nn.ModuleList()
for _ in range(self.num_upsamples):
self.conv_blocks.append(
ConvModule(
base_channels,
base_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
base_channels,
base_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
# output blocks
self.conv_blocks.append(
ConvModule(
base_channels,
int(base_channels // 2),
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
int(base_channels // 2),
int(base_channels // 4),
kernel_size=3,
stride=2,
padding=1,
conv_cfg=dict(conv_cfg, output_padding=1),
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
self.conv_blocks.append(
ConvModule(
int(base_channels // 4),
out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=out_act_cfg))
def forward(self, noise, num_batches=0, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of samples in a batch
                (batch size). Defaults to 0.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img`` and ``noise_batch`` will be returned.
"""
# receive noise and conduct sanity check.
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
if noise.ndim == 2:
noise_batch = noise
else:
                raise ValueError('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
# noise2feat
x = self.noise2feat_head(noise_batch)
x = x.reshape(
(-1, self.base_channels, self.input_scale, self.input_scale))
x = self.noise2feat_tail(x)
# conv module
for conv in self.conv_blocks:
x = conv(x)
if return_noise:
return dict(fake_img=x, noise_batch=noise_batch)
return x
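# Hedged usage sketch (assumes mmcv/mmgen are importable): sample two fake
# images from the default noise sampler:
#     >>> g = LSGANGenerator(output_scale=128, input_scale=8)
#     >>> g(None, num_batches=2).shape
#     torch.Size([2, 3, 128, 128])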
@MODULES.register_module()
class LSGANDiscriminator(nn.Module):
"""Discriminator for LSGAN.
Implementation Details for LSGAN architecture:
#. Adopt convolution in the discriminator;
#. Use batchnorm in the discriminator except for the input and final \
output layer;
    #. Use LeakyReLU in the discriminator except for the output layer;
    #. Use a fully connected layer in the output layer;
    #. Use 5x5 convs rather than the 4x4 convs used in DCGAN.
Args:
input_scale (int, optional): The scale of the input image. Defaults to
128.
output_scale (int, optional): The final scale of the convolutional
feature. Defaults to 8.
out_channels (int, optional): The channel number of the final output
layer. Defaults to 1.
in_channels (int, optional): The channel number of the input image.
Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contain channels based on this
            number. Defaults to 64.
conv_cfg (dict, optional): Config for the convolution module used in
this discriminator. Defaults to dict(type='Conv2d').
        default_norm_cfg (dict, optional): Norm config for all layers except
            the final output layer. Defaults to ``dict(type='BN')``.
        default_act_cfg (dict, optional): Activation config for all layers
            except the final output layer. Defaults to
            ``dict(type='LeakyReLU', negative_slope=0.2)``.
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to None.
"""
def __init__(self,
input_scale=128,
output_scale=8,
out_channels=1,
in_channels=3,
base_channels=64,
conv_cfg=dict(type='Conv2d'),
default_norm_cfg=dict(type='BN'),
default_act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
out_act_cfg=None):
super().__init__()
assert input_scale % output_scale == 0
assert input_scale // output_scale >= 2
self.input_scale = input_scale
self.output_scale = output_scale
self.out_channels = out_channels
self.base_channels = base_channels
self.with_out_activation = out_act_cfg is not None
self.conv_blocks = nn.ModuleList()
self.conv_blocks.append(
ConvModule(
in_channels,
base_channels,
kernel_size=5,
stride=2,
padding=2,
conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=default_act_cfg))
        # the number of downsampling blocks after the input conv, which
        # already downsamples by 2 (hence the ``- 1``)
        self.num_downsamples = int(np.log2(input_scale // output_scale)) - 1
# build up downsampling backbone (excluding the output layer)
curr_channels = base_channels
for _ in range(self.num_downsamples):
self.conv_blocks.append(
ConvModule(
curr_channels,
curr_channels * 2,
kernel_size=5,
stride=2,
padding=2,
conv_cfg=conv_cfg,
norm_cfg=default_norm_cfg,
act_cfg=default_act_cfg))
curr_channels = curr_channels * 2
# output layer
self.decision = nn.Sequential(
nn.Linear(output_scale * output_scale * curr_channels,
out_channels))
if self.with_out_activation:
self.out_activation = build_activation_layer(out_act_cfg)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Fake or real image tensor.
Returns:
            torch.Tensor: Realness score predicted for the input image.
"""
n = x.shape[0]
for conv in self.conv_blocks:
x = conv(x)
x = x.reshape(n, -1)
x = self.decision(x)
if self.with_out_activation:
x = self.out_activation(x)
return x
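# Hedged shape sketch (assumes mmcv/mmgen are importable): a 128x128 RGB batch
# collapses to one realness score per image with the default configuration:
#     >>> d = LSGANDiscriminator(input_scale=128, output_scale=8)
#     >>> d(torch.randn(2, 3, 128, 128)).shape
#     torch.Size([2, 1])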
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import PGGANDiscriminator, PGGANGenerator
from .modules import (EqualizedLR, EqualizedLRConvDownModule,
EqualizedLRConvModule, EqualizedLRConvUpModule,
EqualizedLRLinearModule, MiniBatchStddevLayer,
PGGANNoiseTo2DFeat, PixelNorm, equalized_lr)
__all__ = [
'EqualizedLR', 'equalized_lr', 'EqualizedLRConvModule',
'EqualizedLRLinearModule', 'EqualizedLRConvUpModule',
'EqualizedLRConvDownModule', 'PixelNorm', 'MiniBatchStddevLayer',
'PGGANNoiseTo2DFeat', 'PGGANGenerator', 'PGGANDiscriminator'
]
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.upsample import build_upsample_layer
from mmgen.models.builder import MODULES
from ..common import get_module_device
from .modules import (EqualizedLRConvDownModule, EqualizedLRConvModule,
EqualizedLRConvUpModule, MiniBatchStddevLayer,
PGGANDecisionHead, PGGANNoiseTo2DFeat)
@MODULES.register_module()
class PGGANGenerator(nn.Module):
"""Generator for PGGAN.
Args:
noise_size (int): Size of the input noise vector.
out_scale (int): Output scale for the generated image.
label_size (int, optional): Size of the label vector.
Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to 8192.
channel_decay (float, optional): Decay for channels of feature maps.
Defaults to 1.0.
max_channels (int, optional): Maximum channels for the feature
maps in the generator block. Defaults to 512.
        fused_upconv (bool, optional): Whether to use fused upconv.
            Defaults to True.
conv_module_cfg (dict, optional): Config for the convolution
module used in this generator. Defaults to None.
fused_upconv_cfg (dict, optional): Config for the fused upconv
module used in this generator. Defaults to None.
upsample_cfg (dict, optional): Config for the upsampling operation.
Defaults to None.
"""
_default_fused_upconv_cfg = dict(
conv_cfg=dict(type='deconv'),
kernel_size=3,
stride=2,
padding=1,
bias=True,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=dict(type='PixelNorm'),
order=('conv', 'act', 'norm'))
_default_conv_module_cfg = dict(
conv_cfg=None,
kernel_size=3,
stride=1,
padding=1,
bias=True,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=dict(type='PixelNorm'),
order=('conv', 'act', 'norm'))
_default_upsample_cfg = dict(type='nearest', scale_factor=2)
def __init__(self,
noise_size,
out_scale,
label_size=0,
base_channels=8192,
channel_decay=1.,
max_channels=512,
fused_upconv=True,
conv_module_cfg=None,
fused_upconv_cfg=None,
upsample_cfg=None):
super().__init__()
self.noise_size = noise_size if noise_size else min(
base_channels, max_channels)
self.out_scale = out_scale
self.out_log2_scale = int(np.log2(out_scale))
# sanity check for the output scale
assert out_scale == 2**self.out_log2_scale and out_scale >= 4
self.label_size = label_size
self.base_channels = base_channels
self.channel_decay = channel_decay
self.max_channels = max_channels
self.fused_upconv = fused_upconv
# set conv cfg
self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
# update with customized config
if conv_module_cfg:
self.conv_module_cfg.update(conv_module_cfg)
if self.fused_upconv:
self.fused_upconv_cfg = deepcopy(self._default_fused_upconv_cfg)
# update with customized config
if fused_upconv_cfg:
self.fused_upconv_cfg.update(fused_upconv_cfg)
self.upsample_cfg = deepcopy(self._default_upsample_cfg)
if upsample_cfg is not None:
self.upsample_cfg.update(upsample_cfg)
self.noise2feat = PGGANNoiseTo2DFeat(noise_size + label_size,
self._num_out_channels(1))
self.torgb_layers = nn.ModuleList()
self.conv_blocks = nn.ModuleList()
for s in range(2, self.out_log2_scale + 1):
in_ch = self._num_out_channels(
s - 1) if s == 2 else self._num_out_channels(s - 2)
# setup torgb layers
self.torgb_layers.append(
self._get_torgb_layer(self._num_out_channels(s - 1)))
# setup upconv or conv blocks
self.conv_blocks.extend(self._get_upconv_block(in_ch, s))
# build upsample layer for residual path
self.upsample_layer = build_upsample_layer(self.upsample_cfg)
def _get_torgb_layer(self, in_channels):
return EqualizedLRConvModule(
in_channels,
3,
kernel_size=1,
stride=1,
equalized_lr_cfg=dict(gain=1),
bias=True,
norm_cfg=None,
act_cfg=None)
def _num_out_channels(self, log_scale):
return min(
int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
self.max_channels)
def _get_upconv_block(self, in_channels, log_scale):
modules = []
# start 4x4 scale
if log_scale == 2:
modules.append(
EqualizedLRConvModule(in_channels,
self._num_out_channels(log_scale - 1),
**self.conv_module_cfg))
# 8x8 --> 1024x1024 scales
else:
if self.fused_upconv:
cfg_ = dict(upsample=dict(type='fused_nn'))
cfg_.update(self.fused_upconv_cfg)
else:
cfg_ = dict(upsample=self.upsample_cfg)
cfg_.update(self.conv_module_cfg)
# up + conv
modules.append(
EqualizedLRConvUpModule(in_channels,
self._num_out_channels(log_scale - 1),
**cfg_))
# refine conv
modules.append(
EqualizedLRConvModule(
self._num_out_channels(log_scale - 1),
self._num_out_channels(log_scale - 1),
**self.conv_module_cfg))
return modules
def forward(self,
noise,
label=None,
num_batches=0,
return_noise=False,
transition_weight=1.,
curr_scale=-1):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
label (Tensor, optional): Label vector with shape [N, C]. Defaults
to None.
            num_batches (int, optional): The number of samples in a batch
                (batch size). Defaults to 0.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1.0.
curr_scale (int, optional): The scale for the current inference or
training. Defaults to -1.
        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img`` and ``noise_batch`` will be returned.
"""
# receive noise and conduct sanity check.
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
# TODO: check pggan default noise type
noise_batch = torch.randn((num_batches, self.noise_size))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
if label is not None:
noise_batch = torch.cat(
[noise_batch, label.to(noise_batch)], dim=1)
# noise vector to 2D feature
x = self.noise2feat(noise_batch)
# build current computational graph
curr_log2_scale = self.out_log2_scale if curr_scale < 0 else int(
np.log2(curr_scale))
# 4x4 scale
x = self.conv_blocks[0](x)
if curr_log2_scale <= 3:
out_img = last_img = self.torgb_layers[0](x)
# 8x8 and larger scales
for s in range(3, curr_log2_scale + 1):
x = self.conv_blocks[2 * s - 5](x)
x = self.conv_blocks[2 * s - 4](x)
if s + 1 == curr_log2_scale:
last_img = self.torgb_layers[s - 2](x)
elif s == curr_log2_scale:
out_img = self.torgb_layers[s - 2](x)
residual_img = self.upsample_layer(last_img)
out_img = residual_img + transition_weight * (
out_img - residual_img)
if return_noise:
output = dict(
fake_img=out_img, noise_batch=noise_batch, label=label)
return output
return out_img
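# Hedged usage sketch (assumes mmcv/mmgen are importable): at full resolution
# (``curr_scale=-1``) the generator emits ``out_scale``-sized images:
#     >>> g = PGGANGenerator(noise_size=512, out_scale=128)
#     >>> g(None, num_batches=1).shape
#     torch.Size([1, 3, 128, 128])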
@MODULES.register_module()
class PGGANDiscriminator(nn.Module):
"""Discriminator for PGGAN.
Args:
in_scale (int): The scale of the input image.
label_size (int, optional): Size of the label vector. Defaults to
0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contain channels based on this
            number. Defaults to 8192.
max_channels (int, optional): Maximum channels for the feature
maps in the discriminator block. Defaults to 512.
in_channels (int, optional): Number of channels in input images.
Defaults to 3.
channel_decay (float, optional): Decay for channels of feature
maps. Defaults to 1.0.
mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
Defaults to dict(group_size=4).
        fused_convdown (bool, optional): Whether to use fused downconv.
            Defaults to True.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this discriminator. Defaults to None.
fused_convdown_cfg (dict, optional): Config for the fused downconv
module used in this discriminator. Defaults to None.
fromrgb_layer_cfg (dict, optional): Config for the fromrgb layer.
Defaults to None.
downsample_cfg (dict, optional): Config for the downsampling
operation. Defaults to None.
"""
_default_fromrgb_cfg = dict(
conv_cfg=None,
kernel_size=1,
stride=1,
padding=0,
bias=True,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=None,
order=('conv', 'act', 'norm'))
_default_conv_module_cfg = dict(
kernel_size=3,
padding=1,
stride=1,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
_default_convdown_cfg = dict(
kernel_size=3,
padding=1,
stride=2,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
def __init__(self,
in_scale,
label_size=0,
base_channels=8192,
max_channels=512,
in_channels=3,
channel_decay=1.0,
mbstd_cfg=dict(group_size=4),
fused_convdown=True,
conv_module_cfg=None,
fused_convdown_cfg=None,
fromrgb_layer_cfg=None,
downsample_cfg=None):
super().__init__()
self.in_scale = in_scale
self.in_log2_scale = int(np.log2(self.in_scale))
self.label_size = label_size
self.base_channels = base_channels
self.max_channels = max_channels
self.in_channels = in_channels
self.channel_decay = channel_decay
self.with_mbstd = mbstd_cfg is not None
self.fused_convdown = fused_convdown
self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
if conv_module_cfg is not None:
self.conv_module_cfg.update(conv_module_cfg)
if self.fused_convdown:
self.fused_convdown_cfg = deepcopy(self._default_convdown_cfg)
if fused_convdown_cfg is not None:
self.fused_convdown_cfg.update(fused_convdown_cfg)
self.fromrgb_layer_cfg = deepcopy(self._default_fromrgb_cfg)
if fromrgb_layer_cfg:
self.fromrgb_layer_cfg.update(fromrgb_layer_cfg)
        # setup the downsample layer first, since building the conv blocks
        # references ``self.downsample`` when ``fused_convdown=False``
        self.downsample_cfg = deepcopy(downsample_cfg)
        if self.downsample_cfg is None or self.downsample_cfg.get(
                'type', None) == 'avgpool':
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        elif self.downsample_cfg.get('type', None) in ['nearest', 'bilinear']:
            self.downsample = partial(
                F.interpolate,
                mode=self.downsample_cfg.pop('type'),
                **self.downsample_cfg)
        else:
            raise NotImplementedError(
                'Downsampling with config '
                f'{downsample_cfg} is not supported.')
        # setup conv blocks
        self.conv_blocks = nn.ModuleList()
        self.fromrgb_layers = nn.ModuleList()
        for s in range(2, self.in_log2_scale + 1):
            self.fromrgb_layers.append(
                self._get_fromrgb_layer(self.in_channels, s))
            self.conv_blocks.extend(
                self._get_convdown_block(self._num_out_channels(s - 1), s))
# setup minibatch stddev layer
if self.with_mbstd:
self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)
# minibatch stddev layer will concatenate an additional feature map
# in channel dimension.
decision_in_channels = self._num_out_channels(1) * 16 + 16
else:
decision_in_channels = self._num_out_channels(1) * 16
# setup decision layer
self.decision = PGGANDecisionHead(decision_in_channels,
self._num_out_channels(0),
1 + self.label_size)
def _num_out_channels(self, log_scale):
return min(
int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
self.max_channels)
def _get_fromrgb_layer(self, in_channels, log2_scale):
return EqualizedLRConvModule(in_channels,
self._num_out_channels(log2_scale - 1),
**self.fromrgb_layer_cfg)
def _get_convdown_block(self, in_channels, log2_scale):
modules = []
if log2_scale == 2:
modules.append(
EqualizedLRConvModule(in_channels,
self._num_out_channels(log2_scale - 1),
**self.conv_module_cfg))
else:
modules.append(
EqualizedLRConvModule(in_channels,
self._num_out_channels(log2_scale - 1),
**self.conv_module_cfg))
if self.fused_convdown:
cfg_ = dict(downsample=dict(type='fused_pool'))
cfg_.update(self.fused_convdown_cfg)
else:
cfg_ = dict(downsample=self.downsample)
cfg_.update(self.conv_module_cfg)
modules.append(
EqualizedLRConvDownModule(
self._num_out_channels(log2_scale - 1),
self._num_out_channels(log2_scale - 2), **cfg_))
return modules
def forward(self, x, transition_weight=1., curr_scale=-1):
"""Forward function.
Args:
x (torch.Tensor): Input image tensor.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1.0.
curr_scale (int, optional): The scale for the current inference or
training. Defaults to -1.
Returns:
            torch.Tensor: Prediction score for the input image.
"""
curr_log2_scale = self.in_log2_scale if curr_scale < 4 else int(
np.log2(curr_scale))
original_img = x
x = self.fromrgb_layers[curr_log2_scale - 2](x)
for s in range(curr_log2_scale, 2, -1):
x = self.conv_blocks[2 * s - 5](x)
x = self.conv_blocks[2 * s - 4](x)
if s == curr_log2_scale:
img_down = self.downsample(original_img)
y = self.fromrgb_layers[curr_log2_scale - 3](img_down)
x = y + transition_weight * (x - y)
if self.with_mbstd:
x = self.mbstd_layer(x)
x = self.decision(x)
if self.label_size > 0:
return x[:, :1], x[:, 1:]
return x
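# Hedged progressive-training sketch (assumes mmcv/mmgen are importable):
# score an 8x8 batch during the 4x4 -> 8x8 transition; the batch size of 4
# matches the default minibatch-stddev group size:
#     >>> d = PGGANDiscriminator(in_scale=128)
#     >>> d(torch.randn(4, 3, 8, 8), transition_weight=0.5, curr_scale=8).shape
#     torch.Size([4, 1])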
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import (NORM_LAYERS, PLUGIN_LAYERS, ConvModule,
build_activation_layer, build_norm_layer,
build_upsample_layer)
from mmcv.cnn.utils import normal_init
from torch.nn.init import _calculate_correct_fan
from mmgen.models.builder import MODULES
from mmgen.models.common import AllGatherLayer
class EqualizedLR:
r"""Equalized Learning Rate.
This trick is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
The general idea is to dynamically rescale the weight in training instead
of in initializing so that the variance of the responses in each layer is
guaranteed with some statistical properties.
Note that this function is always combined with a convolution module which
is initialized with :math:`\mathcal{N}(0, 1)`.
    Args:
        name (str, optional): The name of weights. Defaults to 'weight'.
        gain (float, optional): The gain used for rescaling the weight.
            Defaults to ``2**0.5``.
        mode (str, optional): The mode of computing ``fan`` which is the
            same as ``kaiming_init`` in pytorch. You can choose one from
            ['fan_in', 'fan_out']. Defaults to 'fan_in'.
        lr_mul (float, optional): The learning rate multiplier applied to
            the weight. Defaults to 1.0.
"""
def __init__(self, name='weight', gain=2**0.5, mode='fan_in', lr_mul=1.0):
self.name = name
self.mode = mode
self.gain = gain
self.lr_mul = lr_mul
def compute_weight(self, module):
"""Compute weight with equalized learning rate.
Args:
module (nn.Module): A module that is wrapped with equalized lr.
Returns:
torch.Tensor: Updated weight.
"""
weight = getattr(module, self.name + '_orig')
if weight.ndim == 5:
# weight in shape of [b, out, in, k, k]
fan = _calculate_correct_fan(weight[0], self.mode)
else:
assert weight.ndim <= 4
fan = _calculate_correct_fan(weight, self.mode)
weight = weight * torch.tensor(
self.gain, device=weight.device) * torch.sqrt(
torch.tensor(1. / fan, device=weight.device)) * self.lr_mul
return weight
def __call__(self, module, inputs):
"""Standard interface for forward pre hooks."""
setattr(module, self.name, self.compute_weight(module))
@staticmethod
def apply(module, name, gain=2**0.5, mode='fan_in', lr_mul=1.):
"""Apply function.
This function is to register an equalized learning rate hook in an
``nn.Module``.
        Args:
            module (nn.Module): Module to be wrapped.
            name (str, optional): The name of weights. Defaults to 'weight'.
            gain (float, optional): The gain used for rescaling the weight.
                Defaults to ``2**0.5``.
            mode (str, optional): The mode of computing ``fan`` which is the
                same as ``kaiming_init`` in pytorch. You can choose one from
                ['fan_in', 'fan_out']. Defaults to 'fan_in'.
            lr_mul (float, optional): The learning rate multiplier applied
                to the weight. Defaults to 1.0.
Returns:
nn.Module: Module that is registered with equalized lr hook.
"""
# sanity check for duplicated hooks.
for _, hook in module._forward_pre_hooks.items():
if isinstance(hook, EqualizedLR):
raise RuntimeError(
'Cannot register two equalized_lr hooks on the same '
f'parameter {name} in {module} module.')
fn = EqualizedLR(name, gain=gain, mode=mode, lr_mul=lr_mul)
weight = module._parameters[name]
delattr(module, name)
module.register_parameter(name + '_orig', weight)
# We still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
# However, we can't directly assign as it could be an nn.Parameter and
# gets added as a parameter. Instead, we register weight.data as a
# plain attribute.
setattr(module, name, weight.data)
module.register_forward_pre_hook(fn)
# TODO: register load state dict hook
return fn
def equalized_lr(module, name='weight', gain=2**0.5, mode='fan_in', lr_mul=1.):
r"""Equalized Learning Rate.
This trick is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
The general idea is to dynamically rescale the weight in training instead
of in initializing so that the variance of the responses in each layer is
guaranteed with some statistical properties.
Note that this function is always combined with a convolution module which
is initialized with :math:`\mathcal{N}(0, 1)`.
    Args:
        module (nn.Module): Module to be wrapped.
        name (str, optional): The name of weights. Defaults to 'weight'.
        gain (float, optional): The gain used for rescaling the weight.
            Defaults to ``2**0.5``.
        mode (str, optional): The mode of computing ``fan`` which is the
            same as ``kaiming_init`` in pytorch. You can choose one from
            ['fan_in', 'fan_out']. Defaults to 'fan_in'.
        lr_mul (float, optional): The learning rate multiplier applied to
            the weight. Defaults to 1.0.
Returns:
nn.Module: Module that is registered with equalized lr hook.
"""
EqualizedLR.apply(module, name, gain=gain, mode=mode, lr_mul=lr_mul)
return module
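# Hedged usage sketch: wrapping a module keeps a raw ``weight_orig`` parameter
# and rescales the effective weight by ``gain / sqrt(fan)`` on every forward:
#     >>> conv = equalized_lr(nn.Conv2d(16, 16, 3, padding=1))
#     >>> hasattr(conv, 'weight_orig')
#     True
#     >>> conv(torch.randn(1, 16, 8, 8)).shape
#     torch.Size([1, 16, 8, 8])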
def pixel_norm(x, eps=1e-6):
"""Pixel Normalization.
This normalization is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
Args:
x (torch.Tensor): Tensor to be normalized.
eps (float, optional): Epsilon to avoid dividing zero.
Defaults to 1e-6.
Returns:
torch.Tensor: Normalized tensor.
"""
    if torch.__version__ >= '1.7.0':
        norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True)
    # support older pytorch versions; note the string comparison above
    # misorders versions like '1.10.0', but this fallback is also correct
    else:
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
norm = norm / torch.sqrt(torch.tensor(x.shape[1]).to(x))
return x / (norm + eps)
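# Worked check: after ``pixel_norm`` each spatial position has (roughly) unit
# mean square over the channel dimension:
#     >>> x = torch.randn(2, 512, 4, 4)
#     >>> (pixel_norm(x) ** 2).mean(dim=1).allclose(torch.ones(2, 4, 4), atol=1e-3)
#     True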
@MODULES.register_module()
@NORM_LAYERS.register_module()
class PixelNorm(nn.Module):
"""Pixel Normalization.
This module is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
Args:
eps (float, optional): Epsilon value. Defaults to 1e-6.
"""
_abbr_ = 'pn'
    def __init__(self, in_channels=None, eps=1e-6):
        # ``in_channels`` is unused; it is accepted only to keep the
        # signature compatible with norm layers built through mmcv
        super().__init__()
        self.eps = eps
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Tensor to be normalized.
Returns:
torch.Tensor: Normalized tensor.
"""
return pixel_norm(x, self.eps)
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvModule(ConvModule):
r"""Equalized LR ConvModule.
In this module, we inherit default ``mmcv.cnn.ConvModule`` and adopt
equalized lr in convolution. The equalized learning rate is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
Note that, the initialization of ``self.conv`` will be overwritten as
:math:`\mathcal{N}(0, 1)`.
Args:
equalized_lr_cfg (dict | None, optional): Config for ``EqualizedLR``.
If ``None``, equalized learning rate is ignored. Defaults to
dict(mode='fan_in').
"""
def __init__(self, *args, equalized_lr_cfg=dict(mode='fan_in'), **kwargs):
super().__init__(*args, **kwargs)
self.with_equalized_lr = equalized_lr_cfg is not None
if self.with_equalized_lr:
self.conv = equalized_lr(self.conv, **equalized_lr_cfg)
# initialize the conv weight with standard Gaussian noise.
self._init_conv_weights()
def _init_conv_weights(self):
"""Initialize conv weights as described in PGGAN."""
normal_init(self.conv)
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvUpModule(EqualizedLRConvModule):
r"""Equalized LR (Upsample + Conv) Module.
In this module, we inherit ``EqualizedLRConvModule`` and adopt
upsampling before convolution. As for upsampling, in addition to the
sampling layer in MMCV, we also offer the "fused_nn" type. "fused_nn"
denotes fusing upsampling and convolution. The fusion is modified from
the official Tensorflow implementation in:
https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L86
Args:
        upsample (dict | None, optional): Config for the upsampling
            operation. If ``None``, upsampling is ignored. If you need the
            faster fused version used by the official PGGAN in Tensorflow,
            set it to ``dict(type='fused_nn')``. Defaults to
            ``dict(type='nearest', scale_factor=2)``.
"""
def __init__(self,
*args,
upsample=dict(type='nearest', scale_factor=2),
**kwargs):
super().__init__(*args, **kwargs)
self.with_upsample = upsample is not None
if self.with_upsample:
if upsample.get('type') == 'fused_nn':
assert isinstance(self.conv, nn.ConvTranspose2d)
self.conv.register_forward_pre_hook(
EqualizedLRConvUpModule.fused_nn_hook)
else:
self.upsample_layer = build_upsample_layer(upsample)
def forward(self, x, **kwargs):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if hasattr(self, 'upsample_layer'):
x = self.upsample_layer(x)
return super().forward(x, **kwargs)
@staticmethod
def fused_nn_hook(module, inputs):
"""Standard interface for forward pre hooks."""
weight = module.weight
# pad the last two dimensions
weight = F.pad(weight, (1, 1, 1, 1))
weight = weight[..., 1:, 1:] + weight[..., 1:, :-1] + weight[
..., :-1, 1:] + weight[..., :-1, :-1]
module.weight = weight
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvDownModule(EqualizedLRConvModule):
r"""Equalized LR (Conv + Downsample) Module.
In this module, we inherit ``EqualizedLRConvModule`` and adopt
downsampling after convolution. As for downsampling, we provide two modes
of "avgpool" and "fused_pool". "avgpool" denotes the commonly used average
pooling operation, while "fused_pool" represents fusing downsampling and
convolution. The fusion is modified from the official Tensorflow
implementation in:
https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L109
Args:
downsample (dict | None, optional): Config for downsampling operation.
If ``None``, downsampling is ignored. Currently, we support the
types of ["avgpool", "fused_pool"]. Defaults to
dict(type='fused_pool').
"""
def __init__(self, *args, downsample=dict(type='fused_pool'), **kwargs):
super().__init__(*args, **kwargs)
        downsample_cfg = deepcopy(downsample)
        self.with_downsample = downsample is not None
        if self.with_downsample:
            # check the callable case first; a ready-made downsampling
            # callable (e.g. ``nn.AvgPool2d`` or a ``partial`` of
            # ``F.interpolate``) has no ``type`` key to pop
            if callable(downsample):
                self.downsample = downsample
            else:
                type_ = downsample_cfg.pop('type')
                if type_ == 'avgpool':
                    self.downsample = nn.AvgPool2d(2, 2)
                elif type_ == 'fused_pool':
                    self.conv.register_forward_pre_hook(
                        EqualizedLRConvDownModule.fused_avgpool_hook)
                else:
                    raise NotImplementedError(
                        'Currently, we only support ["avgpool", '
                        '"fused_pool"] as the type of downsample, but got '
                        f'{type_} instead.')
def forward(self, x, **kwargs):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
        Returns:
            torch.Tensor: Forward results.
"""
x = super().forward(x, **kwargs)
if hasattr(self, 'downsample'):
x = self.downsample(x)
return x
@staticmethod
def fused_avgpool_hook(module, inputs):
"""Standard interface for forward pre hooks."""
weight = module.weight
# pad the last two dimensions
weight = F.pad(weight, (1, 1, 1, 1))
weight = (weight[..., 1:, 1:] + weight[..., 1:, :-1] +
weight[..., :-1, 1:] + weight[..., :-1, :-1]) * 0.25
module.weight = weight
@PLUGIN_LAYERS.register_module()
class EqualizedLRLinearModule(nn.Linear):
r"""Equalized LR LinearModule.
In this module, we adopt equalized lr in ``nn.Linear``. The equalized
learning rate is proposed in:
Progressive Growing of GANs for Improved Quality, Stability, and Variation
Note that, the initialization of ``self.weight`` will be overwritten as
:math:`\mathcal{N}(0, 1)`.
Args:
equalized_lr_cfg (dict | None, optional): Config for ``EqualizedLR``.
If ``None``, equalized learning rate is ignored. Defaults to
dict(mode='fan_in').
"""
def __init__(self, *args, equalized_lr_cfg=dict(mode='fan_in'), **kwargs):
super().__init__(*args, **kwargs)
self.with_equalized_lr = equalized_lr_cfg is not None
if self.with_equalized_lr:
self.lr_mul = equalized_lr_cfg.get('lr_mul', 1.)
else:
# In fact, lr_mul will only be used in EqualizedLR for
# initialization
self.lr_mul = 1.
if self.with_equalized_lr:
equalized_lr(self, **equalized_lr_cfg)
self._init_linear_weights()
def _init_linear_weights(self):
"""Initialize linear weights as described in PGGAN."""
nn.init.normal_(self.weight, 0, 1. / self.lr_mul)
if self.bias is not None:
nn.init.constant_(self.bias, 0.)
@MODULES.register_module()
class PGGANNoiseTo2DFeat(nn.Module):
def __init__(self,
noise_size,
out_channels,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=dict(type='PixelNorm'),
normalize_latent=True,
order=('linear', 'act', 'norm')):
super().__init__()
self.noise_size = noise_size
self.out_channels = out_channels
self.normalize_latent = normalize_latent
self.with_activation = act_cfg is not None
self.with_norm = norm_cfg is not None
self.order = order
assert len(order) == 3 and set(order) == set(['linear', 'act', 'norm'])
# w/o bias, because the bias is added after reshaping the tensor to
# 2D feature
self.linear = EqualizedLRLinearModule(
noise_size,
out_channels * 16,
equalized_lr_cfg=dict(gain=np.sqrt(2) / 4),
bias=False)
if self.with_activation:
self.activation = build_activation_layer(act_cfg)
# add bias for reshaped 2D feature.
self.register_parameter(
'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))
if self.with_norm:
_, self.norm = build_norm_layer(norm_cfg, out_channels)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input noise tensor with shape (n, c).
Returns:
Tensor: Forward results with shape (n, c, 4, 4).
"""
assert x.ndim == 2
if self.normalize_latent:
x = pixel_norm(x)
for order in self.order:
if order == 'linear':
x = self.linear(x)
# [n, c, 4, 4]
x = torch.reshape(x, (-1, self.out_channels, 4, 4))
x = x + self.bias
elif order == 'act' and self.with_activation:
x = self.activation(x)
elif order == 'norm' and self.with_norm:
x = self.norm(x)
return x
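# Hedged shape sketch: a flat noise vector becomes a (pixel-normalized) 4x4
# feature map:
#     >>> head = PGGANNoiseTo2DFeat(512, 512)
#     >>> head(torch.randn(2, 512)).shape
#     torch.Size([2, 512, 4, 4])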
class PGGANDecisionHead(nn.Module):
def __init__(self,
in_channels,
mid_channels,
out_channels,
bias=True,
equalized_lr_cfg=dict(gain=1),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
out_act=None):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.out_channels = out_channels
self.with_activation = act_cfg is not None
self.with_out_activation = out_act is not None
# setup linear layers
# dirty code for supporting default mode in PGGAN
if equalized_lr_cfg:
equalized_lr_cfg_ = dict(gain=2**0.5)
else:
equalized_lr_cfg_ = None
self.linear0 = EqualizedLRLinearModule(
self.in_channels,
self.mid_channels,
bias=bias,
equalized_lr_cfg=equalized_lr_cfg_)
self.linear1 = EqualizedLRLinearModule(
self.mid_channels,
self.out_channels,
bias=bias,
equalized_lr_cfg=equalized_lr_cfg)
# setup activation layers
if self.with_activation:
self.activation = build_activation_layer(act_cfg)
if self.with_out_activation:
self.out_activation = build_activation_layer(out_act)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if x.ndim > 2:
x = torch.reshape(x, (x.shape[0], -1))
x = self.linear0(x)
if self.with_activation:
x = self.activation(x)
x = self.linear1(x)
if self.with_out_activation:
x = self.out_activation(x)
return x
@MODULES.register_module()
@PLUGIN_LAYERS.register_module()
class MiniBatchStddevLayer(nn.Module):
"""Minibatch standard deviation.
Args:
group_size (int, optional): The size of groups in batch dimension.
Defaults to 4.
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
gather_all_batch (bool, optional): Whether gather batch from all GPUs.
Defaults to False.
"""
def __init__(self, group_size=4, eps=1e-8, gather_all_batch=False):
super().__init__()
self.group_size = group_size
self.eps = eps
self.gather_all_batch = gather_all_batch
if self.gather_all_batch:
assert torch.distributed.is_initialized(
), 'Only in distributed training can the tensors be all gathered.'
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.gather_all_batch:
x = torch.cat(AllGatherLayer.apply(x), dim=0)
        # batch size should be smaller than or equal to the group size;
        # otherwise, it should be divisible by the group size
        assert x.shape[
            0] <= self.group_size or x.shape[0] % self.group_size == 0, (
                'Batch size should be smaller than or equal to group size, '
                'or be divisible by the group size. '
                f'But got batch size {x.shape[0]} and '
                f'group size {self.group_size}.')
n, c, h, w = x.shape
group_size = min(n, self.group_size)
# [G, M, C, H, W]
y = torch.reshape(x, (group_size, -1, c, h, w))
# [G, M, C, H, W]
y = y - y.mean(dim=0, keepdim=True)
# In pt>=1.7, you can just use `.square()` function.
# [M, C, H, W]
y = y.pow(2).mean(dim=0, keepdim=False)
y = torch.sqrt(y + self.eps)
# [M, 1, 1, 1]
y = y.mean(dim=(1, 2, 3), keepdim=True)
y = y.repeat(group_size, 1, h, w)
return torch.cat([x, y], dim=1)
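# Hedged shape sketch: the layer appends one stddev statistic channel that is
# shared within each group:
#     >>> layer = MiniBatchStddevLayer(group_size=4)
#     >>> layer(torch.randn(4, 512, 4, 4)).shape
#     torch.Size([4, 513, 4, 4])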
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import PatchDiscriminator, UnetGenerator
from .modules import UnetSkipConnectionBlock, generation_init_weights
__all__ = [
'PatchDiscriminator', 'UnetGenerator', 'UnetSkipConnectionBlock',
'generation_init_weights'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import load_checkpoint
from mmgen.models.builder import MODULES
from mmgen.utils import get_root_logger
from .modules import UnetSkipConnectionBlock, generation_init_weights
@MODULES.register_module()
class UnetGenerator(nn.Module):
"""Construct the Unet-based generator from the innermost layer to the
outermost layer, which is a recursive process.
Args:
in_channels (int): Number of channels in input images.
out_channels (int): Number of channels in output images.
num_down (int): Number of downsamplings in Unet. If `num_down` is 8,
the image with size 256x256 will become 1x1 at the bottleneck.
Default: 8.
base_channels (int): Number of channels at the last conv layer.
Default: 64.
norm_cfg (dict): Config dict to build norm layer. Default:
`dict(type='BN')`.
use_dropout (bool): Whether to use dropout layers. Default: False.
init_cfg (dict): Config dict for initialization.
`type`: The name of our initialization method. Default: 'normal'.
`gain`: Scaling factor for normal, xavier and orthogonal.
Default: 0.02.
"""
def __init__(self,
in_channels,
out_channels,
num_down=8,
base_channels=64,
norm_cfg=dict(type='BN'),
use_dropout=False,
init_cfg=dict(type='normal', gain=0.02)):
super().__init__()
# We use norm layers in the unet generator.
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
# add the innermost layer
unet_block = UnetSkipConnectionBlock(
base_channels * 8,
base_channels * 8,
in_channels=None,
submodule=None,
norm_cfg=norm_cfg,
is_innermost=True)
# add intermediate layers with base_channels * 8 filters
for _ in range(num_down - 5):
unet_block = UnetSkipConnectionBlock(
base_channels * 8,
base_channels * 8,
in_channels=None,
submodule=unet_block,
norm_cfg=norm_cfg,
use_dropout=use_dropout)
# gradually reduce the number of filters
# from base_channels * 8 to base_channels
unet_block = UnetSkipConnectionBlock(
base_channels * 4,
base_channels * 8,
in_channels=None,
submodule=unet_block,
norm_cfg=norm_cfg)
unet_block = UnetSkipConnectionBlock(
base_channels * 2,
base_channels * 4,
in_channels=None,
submodule=unet_block,
norm_cfg=norm_cfg)
unet_block = UnetSkipConnectionBlock(
base_channels,
base_channels * 2,
in_channels=None,
submodule=unet_block,
norm_cfg=norm_cfg)
# add the outermost layer
self.model = UnetSkipConnectionBlock(
out_channels,
base_channels,
in_channels=in_channels,
submodule=unet_block,
is_outermost=True,
norm_cfg=norm_cfg)
self.init_type = 'normal' if init_cfg is None else init_cfg.get(
'type', 'normal')
self.init_gain = 0.02 if init_cfg is None else init_cfg.get(
'gain', 0.02)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
return self.model(x)
def init_weights(self, pretrained=None, strict=True):
"""Initialize weights for the model.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Default: None.
            strict (bool, optional): Whether to strictly enforce that the
                keys of the checkpoint match the keys of the model.
                Default: True.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif pretrained is None:
generation_init_weights(
self, init_type=self.init_type, init_gain=self.init_gain)
else:
raise TypeError("'pretrained' must be a str or None. "
f'But received {type(pretrained)}.')
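# Hedged usage sketch (assumes mmcv and mmgen are importable): with the
# default ``num_down=8`` a 256x256 input shrinks to 1x1 at the bottleneck and
# is decoded back to the input resolution:
#     >>> import torch
#     >>> g = UnetGenerator(in_channels=3, out_channels=3, num_down=8)
#     >>> g(torch.randn(2, 3, 256, 256)).shape
#     torch.Size([2, 3, 256, 256])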
@MODULES.register_module()
class PatchDiscriminator(nn.Module):
"""A PatchGAN discriminator.
Args:
in_channels (int): Number of channels in input images.
base_channels (int): Number of channels at the first conv layer.
Default: 64.
num_conv (int): Number of stacked intermediate convs (excluding input
and output conv). Default: 3.
norm_cfg (dict): Config dict to build norm layer. Default:
`dict(type='BN')`.
init_cfg (dict): Config dict for initialization.
`type`: The name of our initialization method. Default: 'normal'.
`gain`: Scaling factor for normal, xavier and orthogonal.
Default: 0.02.
"""
def __init__(self,
in_channels,
base_channels=64,
num_conv=3,
norm_cfg=dict(type='BN'),
init_cfg=dict(type='normal', gain=0.02)):
super().__init__()
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
# We use norm layers in the patch discriminator.
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_cfg['type'] == 'IN'
kernel_size = 4
padding = 1
# input layer
sequence = [
ConvModule(
in_channels=in_channels,
out_channels=base_channels,
kernel_size=kernel_size,
stride=2,
padding=padding,
bias=True,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
]
# stacked intermediate layers,
# gradually increasing the number of filters
multiple_now = 1
multiple_prev = 1
for n in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2**n, 8)
sequence += [
ConvModule(
in_channels=base_channels * multiple_prev,
out_channels=base_channels * multiple_now,
kernel_size=kernel_size,
stride=2,
padding=padding,
bias=use_bias,
norm_cfg=norm_cfg,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
]
multiple_prev = multiple_now
multiple_now = min(2**num_conv, 8)
sequence += [
ConvModule(
in_channels=base_channels * multiple_prev,
out_channels=base_channels * multiple_now,
kernel_size=kernel_size,
stride=1,
padding=padding,
bias=use_bias,
norm_cfg=norm_cfg,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
]
# output one-channel prediction map
sequence += [
build_conv_layer(
dict(type='Conv2d'),
base_channels * multiple_now,
1,
kernel_size=kernel_size,
stride=1,
padding=padding)
]
self.model = nn.Sequential(*sequence)
self.init_type = 'normal' if init_cfg is None else init_cfg.get(
'type', 'normal')
self.init_gain = 0.02 if init_cfg is None else init_cfg.get(
'gain', 0.02)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
return self.model(x)
def init_weights(self, pretrained=None):
"""Initialize weights for the model.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Default: None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
generation_init_weights(
self, init_type=self.init_type, init_gain=self.init_gain)
else:
raise TypeError("'pretrained' must be a str or None. "
f'But received {type(pretrained)}.')
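# Hedged shape sketch (assumes mmcv/mmgen are importable): the default
# PatchGAN discriminator scores overlapping 70x70 patches, so a 256x256 input
# yields a 30x30 one-channel prediction map:
#     >>> import torch
#     >>> d = PatchDiscriminator(in_channels=3)
#     >>> d(torch.randn(2, 3, 256, 256)).shape
#     torch.Size([2, 1, 30, 30])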
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
from torch.nn import init
def generation_init_weights(module, init_type='normal', init_gain=0.02):
"""Default initialization of network weights for image generation.
By default, we use normal init, but xavier and kaiming might work
better for some applications.
Args:
module (nn.Module): Module to be initialized.
init_type (str): The name of an initialization method:
normal | xavier | kaiming | orthogonal.
init_gain (float): Scaling factor for normal, xavier and
orthogonal.
"""
def init_func(m):
"""Initialization function.
Args:
m (nn.Module): Module to be initialized.
"""
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1
or classname.find('Linear') != -1):
if init_type == 'normal':
normal_init(m, 0.0, init_gain)
elif init_type == 'xavier':
xavier_init(m, gain=init_gain, distribution='normal')
elif init_type == 'kaiming':
kaiming_init(
m,
a=0,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='normal')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight, gain=init_gain)
init.constant_(m.bias.data, 0.0)
else:
raise NotImplementedError(
f"Initialization method '{init_type}' is not implemented")
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
module.apply(init_func)
class UnetSkipConnectionBlock(nn.Module):
"""Construct a Unet submodule with skip connections, with the following.
structure: downsampling - `submodule` - upsampling.
Args:
outer_channels (int): Number of channels at the outer conv layer.
inner_channels (int): Number of channels at the inner conv layer.
in_channels (int): Number of channels in input images/features. If is
None, equals to `outer_channels`. Default: None.
submodule (UnetSkipConnectionBlock): Previously constructed submodule.
Default: None.
is_outermost (bool): Whether this module is the outermost module.
Default: False.
is_innermost (bool): Whether this module is the innermost module.
Default: False.
norm_cfg (dict): Config dict to build norm layer. Default:
`dict(type='BN')`.
use_dropout (bool): Whether to use dropout layers. Default: False.
"""
def __init__(self,
outer_channels,
inner_channels,
in_channels=None,
submodule=None,
is_outermost=False,
is_innermost=False,
norm_cfg=dict(type='BN'),
use_dropout=False):
super().__init__()
# cannot be both outermost and innermost
        assert not (is_outermost and is_innermost), (
            "'is_outermost' and 'is_innermost' cannot be True "
            'at the same time.')
self.is_outermost = is_outermost
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
# We use norm layers in the unet skip connection block.
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_cfg['type'] == 'IN'
kernel_size = 4
stride = 2
padding = 1
if in_channels is None:
in_channels = outer_channels
down_conv_cfg = dict(type='Conv2d')
down_norm_cfg = norm_cfg
down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
up_conv_cfg = dict(type='deconv')
up_norm_cfg = norm_cfg
up_act_cfg = dict(type='ReLU')
up_in_channels = inner_channels * 2
up_bias = use_bias
middle = [submodule]
upper = []
if is_outermost:
down_act_cfg = None
down_norm_cfg = None
up_bias = True
up_norm_cfg = None
upper = [nn.Tanh()]
elif is_innermost:
down_norm_cfg = None
up_in_channels = inner_channels
middle = []
else:
upper = [nn.Dropout(0.5)] if use_dropout else []
down = [
ConvModule(
in_channels=in_channels,
out_channels=inner_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=use_bias,
conv_cfg=down_conv_cfg,
norm_cfg=down_norm_cfg,
act_cfg=down_act_cfg,
order=('act', 'conv', 'norm'))
]
up = [
ConvModule(
in_channels=up_in_channels,
out_channels=outer_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=up_bias,
conv_cfg=up_conv_cfg,
norm_cfg=up_norm_cfg,
act_cfg=up_act_cfg,
order=('act', 'conv', 'norm'))
]
model = down + middle + up + upper
self.model = nn.Sequential(*model)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.is_outermost:
return self.model(x)
# add skip connections
return torch.cat([x, self.model(x)], 1)
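# A minimal usage sketch, assuming mmcv is installed: nest the blocks from
# innermost to outermost to build a tiny 3-level U-Net, initialize it with
# `generation_init_weights`, and run a dummy forward pass. The channel and
# image sizes below are illustrative only.
if __name__ == '__main__':
    inner = UnetSkipConnectionBlock(128, 256, is_innermost=True)
    middle = UnetSkipConnectionBlock(64, 128, submodule=inner)
    unet = UnetSkipConnectionBlock(
        3, 64, in_channels=3, submodule=middle, is_outermost=True)
    generation_init_weights(unet, init_type='normal', init_gain=0.02)
    out = unet(torch.randn(1, 3, 32, 32))
    # the outermost block returns the image directly, without concatenation
    print(out.shape)  # torch.Size([1, 3, 32, 32])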
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmgen.models.builder import MODULES
@MODULES.register_module('SPE')
@MODULES.register_module('SPE2d')
class SinusoidalPositionalEmbedding(nn.Module):
"""Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).
    This module is modified from:
    https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa
    Based on the original SPE in a single dimension, we implement a 2D
    sinusoidal positional encoding (SPE2d), as introduced in Positional
    Encoding as Spatial Inductive Bias in GANs, CVPR'2021.
Args:
embedding_dim (int): The number of dimensions for the positional
encoding.
padding_idx (int | list[int]): The index for the padding contents. The
padding positions will obtain an encoding vector filling in zeros.
init_size (int, optional): The initial size of the positional buffer.
Defaults to 1024.
div_half_dim (bool, optional): If true, the embedding will be divided
by :math:`d/2`. Otherwise, it will be divided by
:math:`(d/2 -1)`. Defaults to False.
center_shift (int | None, optional): Shift the center point to some
index. Defaults to None.
"""
def __init__(self,
embedding_dim,
padding_idx,
init_size=1024,
div_half_dim=False,
center_shift=None):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.div_half_dim = div_half_dim
self.center_shift = center_shift
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size, embedding_dim, padding_idx, self.div_half_dim)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
self.max_positions = int(1e5)
@staticmethod
def get_embedding(num_embeddings,
embedding_dim,
padding_idx=None,
div_half_dim=False):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
        assert embedding_dim % 2 == 0, (
            'In this version, we require embedding_dim to be '
            f'divisible by 2, but got {embedding_dim}')
# there is a little difference from the original paper.
half_dim = embedding_dim // 2
if not div_half_dim:
emb = np.log(10000) / (half_dim - 1)
else:
emb = np.log(1e4) / half_dim
# compute exp(-log10000 / d * i)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(
num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)],
dim=1).view(num_embeddings, -1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, **kwargs):
"""Input is expected to be of size [bsz x seqlen].
Returned tensor is expected to be of size [bsz x seq_len x emb_dim]
"""
assert input.dim() == 2 or input.dim(
) == 4, 'Input dimension should be 2 (1D) or 4(2D)'
if input.dim() == 4:
return self.make_grid2d_like(input, **kwargs)
b, seq_len = input.shape
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embedding if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos, self.embedding_dim, self.padding_idx)
self.weights = self.weights.to(self._float_tensor)
positions = self.make_positions(input, self.padding_idx).to(
self._float_tensor.device)
return self.weights.index_select(0, positions.view(-1)).view(
b, seq_len, self.embedding_dim).detach()
    def make_positions(self, input, padding_idx):
        """Replace non-padding symbols with their position numbers, starting
        from ``padding_idx + 1``; padding positions keep ``padding_idx``."""
mask = input.ne(padding_idx).int()
return (torch.cumsum(mask, dim=1).type_as(mask) *
mask).long() + padding_idx
    def make_grid2d(self, height, width, num_batches=1, center_shift=None):
        """Make a 2D positional grid with shape (b, 2 x emb_dim, h, w)."""
h, w = height, width
# if `center_shift` is not given from the outside, use
# `self.center_shift`
if center_shift is None:
center_shift = self.center_shift
h_shift = 0
w_shift = 0
# center shift to the input grid
if center_shift is not None:
# if h/w is even, the left center should be aligned with
# center shift
if h % 2 == 0:
h_left_center = h // 2
h_shift = center_shift - h_left_center
else:
h_center = h // 2 + 1
h_shift = center_shift - h_center
if w % 2 == 0:
w_left_center = w // 2
w_shift = center_shift - w_left_center
else:
w_center = w // 2 + 1
w_shift = center_shift - w_center
        # Note that the index starts from 1, since zero is the padding idx.
# axis -- (b, h or w)
x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches,
1) + w_shift
y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches,
1) + h_shift
# emb -- (b, emb_dim, h or w)
x_emb = self(x_axis).transpose(1, 2)
y_emb = self(y_axis).transpose(1, 2)
# make grid for x/y axis
        # Note that repeat will copy data. If using a learned embedding,
        # expand may be better.
x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1)
y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w)
# cat grid -- (b, 2 x emb_dim, h, w)
grid = torch.cat([x_grid, y_grid], dim=1)
return grid.detach()
def make_grid2d_like(self, x, center_shift=None):
"""Input tensor with shape of (b, ..., h, w) Return tensor with shape
of (b, 2 x emb_dim, h, w)
Note that the positional embedding highly depends on the the function,
``make_positions``.
"""
h, w = x.shape[-2:]
grid = self.make_grid2d(h, w, x.size(0), center_shift)
return grid.to(x)
@MODULES.register_module('CSG2d')
@MODULES.register_module('CSG')
@MODULES.register_module()
class CatersianGrid(nn.Module):
"""Catersian Grid for 2d tensor.
The Catersian Grid is a common-used positional encoding in deep learning.
In this implementation, we follow the convention of ``grid_sample`` in
PyTorch. In other words, ``[-1, -1]`` denotes the left-top corner while
``[1, 1]`` denotes the right-botton corner.
"""
def forward(self, x, **kwargs):
assert x.dim() == 4
return self.make_grid2d_like(x, **kwargs)
def make_grid2d(self, height, width, num_batches=1, requires_grad=False):
h, w = height, width
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
grid_x = 2 * grid_x / max(float(w) - 1., 1.) - 1.
grid_y = 2 * grid_y / max(float(h) - 1., 1.) - 1.
grid = torch.stack((grid_x, grid_y), 0)
grid.requires_grad = requires_grad
grid = torch.unsqueeze(grid, 0)
grid = grid.repeat(num_batches, 1, 1, 1)
return grid
def make_grid2d_like(self, x, requires_grad=False):
h, w = x.shape[-2:]
grid = self.make_grid2d(h, w, x.size(0), requires_grad=requires_grad)
return grid.to(x)
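# A minimal usage sketch, assuming mmgen/mmcv are installed: both encodings
# accept a 4D feature map and return a positional grid matching its spatial
# size. The tensor sizes below are illustrative only.
if __name__ == '__main__':
    x = torch.randn(1, 3, 4, 4)
    spe = SinusoidalPositionalEmbedding(embedding_dim=8, padding_idx=0)
    spe_grid = spe(x)            # 4D input is routed to `make_grid2d_like`
    print(spe_grid.shape)        # torch.Size([1, 16, 4, 4]), i.e. 2 x emb_dim
    csg = CatersianGrid()
    csg_grid = csg(x)
    print(csg_grid.shape)        # torch.Size([1, 2, 4, 4])
    print(csg_grid[0, :, 0, 0])  # tensor([-1., -1.]), the left-top corner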
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import (SinGANMultiScaleDiscriminator,
SinGANMultiScaleGenerator)
from .positional_encoding import SinGANMSGeneratorPE
__all__ = [
'SinGANMultiScaleDiscriminator', 'SinGANMultiScaleGenerator',
'SinGANMSGeneratorPE'
]
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_state_dict
from mmcv.utils import print_log
from mmgen.models.builder import MODULES
from mmgen.utils import get_root_logger
from .modules import DiscriminatorBlock, GeneratorBlock
@MODULES.register_module()
class SinGANMultiScaleGenerator(nn.Module):
"""Multi-Scale Generator used in SinGAN.
    More details can be found in: SinGAN: Learning a Generative Model from a
    Single Natural Image, ICCV'19.
Notes:
- In this version, we adopt the interpolation function from the official
PyTorch APIs, which is different from the original implementation by the
authors. However, in our experiments, this influence can be ignored.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
num_scales (int): The number of scales/stages in generator. Note
that this number is counted from zero, which is the same as the
original paper.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
num_layers (int, optional): The number of convolutional layers in each
generator block. Defaults to 5.
base_channels (int, optional): The basic channels for convolutional
layers in the generator block. Defaults to 32.
min_feat_channels (int, optional): Minimum channels for the feature
maps in the generator block. Defaults to 32.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
"""
def __init__(self,
in_channels,
out_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
out_act_cfg=dict(type='Tanh'),
**kwargs):
super().__init__()
self.pad_head = int((kernel_size - 1) / 2 * num_layers)
self.blocks = nn.ModuleList()
self.upsample = partial(
F.interpolate, mode='bicubic', align_corners=True)
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
self.blocks.append(
GeneratorBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
out_act_cfg=out_act_cfg,
**kwargs))
self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
def forward(self,
input_sample,
fixed_noises,
noise_weights,
rand_mode,
curr_scale,
num_batches=1,
get_prev_res=False,
return_noise=False):
"""Forward function.
Args:
input_sample (Tensor | None): The input for generator. In the
original implementation, a tensor filled with zeros is adopted.
If None is given, we will construct it from the first fixed
noises.
fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
noise_weights (list[float]): List of the weights for random noises.
rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
it will sample from random noises. Otherwise, the
reconstruction for the single image will be returned.
curr_scale (int): The scale for the current inference or training.
num_batches (int, optional): The number of batches. Defaults to 1.
get_prev_res (bool, optional): Whether to return results from
previous stages. Defaults to False.
return_noise (bool, optional): Whether to return noises tensor.
Defaults to False.
Returns:
Tensor | dict: Generated image tensor or dictionary containing \
more data.
"""
if get_prev_res or return_noise:
prev_res_list = []
noise_list = []
if input_sample is None:
input_sample = torch.zeros(
(num_batches, 3, fixed_noises[0].shape[-2],
fixed_noises[0].shape[-1])).to(fixed_noises[0])
g_res = input_sample
for stage in range(curr_scale + 1):
if rand_mode == 'recon':
noise_ = fixed_noises[stage]
else:
noise_ = torch.randn(num_batches,
*fixed_noises[stage].shape[1:]).to(g_res)
if return_noise:
noise_list.append(noise_)
# add padding at head
pad_ = (self.pad_head, ) * 4
noise_ = F.pad(noise_, pad_)
g_res_pad = F.pad(g_res, pad_)
noise = noise_ * noise_weights[stage] + g_res_pad
g_res = self.blocks[stage](noise.detach(), g_res)
if get_prev_res and stage != curr_scale:
prev_res_list.append(g_res)
# upsample, here we use interpolation from PyTorch
if stage != curr_scale:
h_next, w_next = fixed_noises[stage + 1].shape[-2:]
g_res = self.upsample(g_res, (h_next, w_next))
if get_prev_res or return_noise:
output_dict = dict(
fake_img=g_res,
prev_res_list=prev_res_list,
noise_batch=noise_list)
return output_dict
return g_res
def check_and_load_prev_weight(self, curr_scale):
if curr_scale == 0:
return
prev_ch = self.blocks[curr_scale - 1].base_channels
curr_ch = self.blocks[curr_scale].base_channels
prev_in_ch = self.blocks[curr_scale - 1].in_channels
curr_in_ch = self.blocks[curr_scale].in_channels
if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
load_state_dict(
self.blocks[curr_scale],
self.blocks[curr_scale - 1].state_dict(),
logger=get_root_logger())
            print_log('Successfully loaded pretrained model from last scale.')
else:
print_log(
'Cannot load pretrained model from last scale since'
f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
@MODULES.register_module()
class SinGANMultiScaleDiscriminator(nn.Module):
"""Multi-Scale Discriminator used in SinGAN.
    More details can be found in: SinGAN: Learning a Generative Model from a
    Single Natural Image, ICCV'19.
Args:
in_channels (int): Input channels.
        num_scales (int): The number of scales/stages in the discriminator.
            Note that this number is counted from zero, which is the same as
            the original paper.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            discriminator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the discriminator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the discriminator block. Defaults to 32.
"""
def __init__(self,
in_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
**kwargs):
super().__init__()
self.blocks = nn.ModuleList()
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
self.blocks.append(
DiscriminatorBlock(
in_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
**kwargs))
def forward(self, x, curr_scale):
"""Forward function.
Args:
x (Tensor): Input feature map.
curr_scale (int): Current scale for discriminator. If in testing,
you need to set it to the last scale.
Returns:
Tensor: Discriminative results.
"""
out = self.blocks[curr_scale](x)
return out
def check_and_load_prev_weight(self, curr_scale):
if curr_scale == 0:
return
prev_ch = self.blocks[curr_scale - 1].base_channels
curr_ch = self.blocks[curr_scale].base_channels
if prev_ch == curr_ch:
self.blocks[curr_scale].load_state_dict(
self.blocks[curr_scale - 1].state_dict())
            print_log('Successfully loaded pretrained model from last scale.')
else:
print_log('Cannot load pretrained model from last scale since'
f' prev_ch({prev_ch}) != curr_ch({curr_ch})')
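# A minimal usage sketch, assuming mmgen/mmcv are installed: run a 3-scale
# generator in 'rand' mode on a toy pyramid of fixed noises. The resolutions
# and noise weights below are illustrative only.
if __name__ == '__main__':
    gen = SinGANMultiScaleGenerator(
        in_channels=3, out_channels=3, num_scales=2)
    # one fixed noise per scale, doubling the resolution at each scale
    fixed_noises = [torch.randn(1, 3, 8 * 2**s, 8 * 2**s) for s in range(3)]
    fake = gen(None, fixed_noises, [1.0, 0.5, 0.5], 'rand', curr_scale=2)
    print(fake.shape)  # torch.Size([1, 3, 32, 32])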
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmgen.utils import get_root_logger
class GeneratorBlock(nn.Module):
"""Generator block used in SinGAN.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
num_layers (int, optional): The number of convolutional layers in each
generator block. Defaults to 5.
base_channels (int, optional): The basic channels for convolutional
layers in the generator block. Defaults to 32.
min_feat_channels (int, optional): Minimum channels for the feature
maps in the generator block. Defaults to 32.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
stride (int, optional): Same as :obj:`nn.Conv2d`. Defaults to 1.
allow_no_residual (bool, optional): Whether to allow no residual link
in this block. Defaults to False.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding,
num_layers,
base_channels,
min_feat_channels,
out_act_cfg=dict(type='Tanh'),
stride=1,
allow_no_residual=False,
**kwargs):
super().__init__()
self.in_channels = in_channels
self.base_channels = base_channels
self.kernel_size = kernel_size
self.num_layers = num_layers
self.allow_no_residual = allow_no_residual
self.head = ConvModule(
in_channels=in_channels,
out_channels=base_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
**kwargs)
self.body = nn.Sequential()
for i in range(num_layers - 2):
feat_channels_ = int(base_channels / pow(2, (i + 1)))
block = ConvModule(
max(2 * feat_channels_, min_feat_channels),
max(feat_channels_, min_feat_channels),
kernel_size=kernel_size,
padding=padding,
stride=stride,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
**kwargs)
self.body.add_module(f'block{i+1}', block)
self.tail = ConvModule(
max(feat_channels_, min_feat_channels),
out_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=None,
act_cfg=out_act_cfg,
**kwargs)
self.init_weights()
def forward(self, x, prev):
"""Forward function.
Args:
x (Tensor): Input feature map.
prev (Tensor): Previous feature map.
Returns:
Tensor: Output feature map with the shape of (N, C, H, W).
"""
x = self.head(x)
x = self.body(x)
x = self.tail(x)
        # if prev and x do not have the same number of channels
if self.allow_no_residual and x.shape[1] != prev.shape[1]:
return x
return x + prev
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, 0, 0.02)
elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None but'
f' got {type(pretrained)} instead.')
class DiscriminatorBlock(nn.Module):
"""Discriminator Block used in SinGAN.
Args:
in_channels (int): Input channels.
base_channels (int): Base channels for this block.
min_feat_channels (int): The minimum channels for feature map.
kernel_size (int): Size of convolutional kernel, same as
:obj:`nn.Conv2d`.
padding (int): Padding for convolutional layer, same as
:obj:`nn.Conv2d`.
num_layers (int): The number of convolutional layers in this block.
norm_cfg (dict | None, optional): Config for the normalization layer.
Defaults to dict(type='BN').
act_cfg (dict | None, optional): Config for the activation layer.
Defaults to dict(type='LeakyReLU', negative_slope=0.2).
stride (int, optional): The stride for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 1.
"""
def __init__(self,
in_channels,
base_channels,
min_feat_channels,
kernel_size,
padding,
num_layers,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
stride=1,
**kwargs):
super().__init__()
self.base_channels = base_channels
self.stride = stride
self.head = ConvModule(
in_channels,
base_channels,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.body = nn.Sequential()
for i in range(num_layers - 2):
feat_channels_ = int(base_channels / pow(2, (i + 1)))
block = ConvModule(
max(2 * feat_channels_, min_feat_channels),
max(feat_channels_, min_feat_channels),
kernel_size=kernel_size,
padding=padding,
stride=stride,
conv_cfg=None,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.body.add_module(f'block{i+1}', block)
self.tail = ConvModule(
max(feat_channels_, min_feat_channels),
1,
kernel_size=kernel_size,
padding=padding,
stride=1,
norm_cfg=None,
act_cfg=None,
**kwargs)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
x = self.head(x)
x = self.body(x)
x = self.tail(x)
return x
# TODO: study the effects of init functions
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, 0, 0.02)
elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None but'
f' got {type(pretrained)} instead.')
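# A minimal usage sketch, assuming mmgen/mmcv are installed: with padding=0
# each of the `num_layers` convs shrinks the map by 2 px, so the input must
# be pre-padded by (kernel_size - 1) / 2 * num_layers = 5 px per side for the
# residual `x + prev` to match. The sizes below are illustrative only.
if __name__ == '__main__':
    import torch
    block = GeneratorBlock(
        in_channels=3, out_channels=3, kernel_size=3, padding=0,
        num_layers=5, base_channels=32, min_feat_channels=32)
    prev = torch.randn(1, 3, 32, 32)
    x = nn.ZeroPad2d(5)(prev)    # (1, 3, 42, 42)
    print(block(x, prev).shape)  # torch.Size([1, 3, 32, 32])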
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation for Positional Encoding as Spatial Inductive Bias in GANs.
In this module, we provide necessary components to conduct experiments
mentioned in the paper: Positional Encoding as Spatial Inductive Bias in GANs.
More details can be found in: https://arxiv.org/pdf/2012.05217.pdf
"""
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.builder import MODULES, build_module
from .generator_discriminator import SinGANMultiScaleGenerator
from .modules import GeneratorBlock
@MODULES.register_module()
class SinGANMSGeneratorPE(SinGANMultiScaleGenerator):
"""Multi-Scale Generator used in SinGAN with positional encoding.
    More details can be found in: Positional Encoding as Spatial Inductive Bias
in GANs, CVPR'2021.
Notes:
- In this version, we adopt the interpolation function from the official
PyTorch APIs, which is different from the original implementation by the
authors. However, in our experiments, this influence can be ignored.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
num_scales (int): The number of scales/stages in generator. Note
that this number is counted from zero, which is the same as the
original paper.
kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
Defaults to 3.
padding (int, optional): Padding for the convolutional layer, same as
:obj:`nn.Conv2d`. Defaults to 0.
num_layers (int, optional): The number of convolutional layers in each
generator block. Defaults to 5.
base_channels (int, optional): The basic channels for convolutional
layers in the generator block. Defaults to 32.
min_feat_channels (int, optional): Minimum channels for the feature
maps in the generator block. Defaults to 32.
out_act_cfg (dict | None, optional): Configs for output activation
layer. Defaults to dict(type='Tanh').
        padding_mode (str, optional): The padding mode for the convolutional
            and padding layers. Supported choices are 'zero' and 'reflect'.
            Defaults to 'zero'.
pad_at_head (bool, optional): Whether to add padding at head.
Defaults to True.
        interp_pad (bool, optional): Whether to pad feature maps by
            interpolation (resizing) instead of explicit padding layers.
            Defaults to False.
noise_with_pad (bool, optional): Whether the input fixed noises are
with explicit padding. Defaults to False.
positional_encoding (dict | None, optional): Configs for the positional
encoding. Defaults to None.
first_stage_in_channels (int | None, optional): The input channel of
the first generator block. If None, the first stage will adopt the
same input channels as other stages. Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
num_scales,
kernel_size=3,
padding=0,
num_layers=5,
base_channels=32,
min_feat_channels=32,
out_act_cfg=dict(type='Tanh'),
padding_mode='zero',
pad_at_head=True,
interp_pad=False,
noise_with_pad=False,
positional_encoding=None,
first_stage_in_channels=None,
**kwargs):
        # Call `nn.Module.__init__` directly, skipping the parent's
        # `__init__`, because all blocks are rebuilt below with PE options.
        super(SinGANMultiScaleGenerator, self).__init__()
self.pad_at_head = pad_at_head
self.interp_pad = interp_pad
self.noise_with_pad = noise_with_pad
self.with_positional_encode = positional_encoding is not None
if self.with_positional_encode:
self.head_position_encode = build_module(positional_encoding)
self.pad_head = int((kernel_size - 1) / 2 * num_layers)
self.blocks = nn.ModuleList()
self.upsample = partial(
F.interpolate, mode='bicubic', align_corners=True)
for scale in range(num_scales + 1):
base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
128)
min_feat_ch = min(
min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)
if scale == 0:
in_ch = (
first_stage_in_channels
if first_stage_in_channels else in_channels)
else:
in_ch = in_channels
self.blocks.append(
GeneratorBlock(
in_channels=in_ch,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
num_layers=num_layers,
base_channels=base_ch,
min_feat_channels=min_feat_ch,
out_act_cfg=out_act_cfg,
padding_mode=padding_mode,
**kwargs))
if padding_mode == 'zero':
self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
elif padding_mode == 'reflect':
self.noise_padding_layer = nn.ReflectionPad2d(self.pad_head)
self.img_padding_layer = nn.ReflectionPad2d(self.pad_head)
self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
mmcv.print_log('Using Reflection padding', 'mmgen')
else:
raise NotImplementedError(
f'Padding mode {padding_mode} is not supported')
def forward(self,
input_sample,
fixed_noises,
noise_weights,
rand_mode,
curr_scale,
num_batches=1,
get_prev_res=False,
return_noise=False):
"""Forward function.
Args:
input_sample (Tensor | None): The input for generator. In the
original implementation, a tensor filled with zeros is adopted.
If None is given, we will construct it from the first fixed
noises.
fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
noise_weights (list[float]): List of the weights for random noises.
rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
it will sample from random noises. Otherwise, the
reconstruction for the single image will be returned.
curr_scale (int): The scale for the current inference or training.
num_batches (int, optional): The number of batches. Defaults to 1.
get_prev_res (bool, optional): Whether to return results from
previous stages. Defaults to False.
return_noise (bool, optional): Whether to return noises tensor.
Defaults to False.
Returns:
Tensor | dict: Generated image tensor or dictionary containing \
more data.
"""
if get_prev_res or return_noise:
prev_res_list = []
noise_list = []
if input_sample is None:
input_sample = torch.zeros(
(num_batches, 3, fixed_noises[0].shape[-2],
fixed_noises[0].shape[-1])).to(fixed_noises[0])
g_res = input_sample
for stage in range(curr_scale + 1):
if rand_mode == 'recon':
noise_ = fixed_noises[stage]
else:
noise_ = torch.randn(num_batches,
*fixed_noises[stage].shape[1:]).to(g_res)
if return_noise:
noise_list.append(noise_)
if self.with_positional_encode and stage == 0:
head_grid = self.head_position_encode(fixed_noises[0])
noise_ = noise_ + head_grid
# add padding at head
if self.pad_at_head:
if self.interp_pad:
if self.noise_with_pad:
size = noise_.shape[-2:]
else:
size = (noise_.size(2) + 2 * self.pad_head,
noise_.size(3) + 2 * self.pad_head)
noise_ = self.upsample(noise_, size)
g_res_pad = self.upsample(g_res, size)
else:
if not self.noise_with_pad:
noise_ = self.noise_padding_layer(noise_)
g_res_pad = self.img_padding_layer(g_res)
else:
g_res_pad = g_res
if stage == 0 and self.with_positional_encode:
noise = noise_ * noise_weights[stage]
else:
noise = noise_ * noise_weights[stage] + g_res_pad
g_res = self.blocks[stage](noise.detach(), g_res)
if get_prev_res and stage != curr_scale:
prev_res_list.append(g_res)
# upsample, here we use interpolation from PyTorch
if stage != curr_scale:
h_next, w_next = fixed_noises[stage + 1].shape[-2:]
if self.noise_with_pad:
# remove the additional padding if noise with pad
h_next -= 2 * self.pad_head
w_next -= 2 * self.pad_head
g_res = self.upsample(g_res, (h_next, w_next))
if get_prev_res or return_noise:
output_dict = dict(
fake_img=g_res,
prev_res_list=prev_res_list,
noise_batch=noise_list)
return output_dict
return g_res
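# A minimal usage sketch, assuming mmgen/mmcv are installed. This is not the
# official PE-SinGAN configuration: it only illustrates pairing the
# 2-channel 'CSG' grid with `first_stage_in_channels=2` on a toy 2-scale
# pyramid, where the first fixed noise also has 2 channels.
if __name__ == '__main__':
    gen = SinGANMSGeneratorPE(
        in_channels=3, out_channels=3, num_scales=1,
        first_stage_in_channels=2, positional_encoding=dict(type='CSG'))
    fixed_noises = [torch.randn(1, 2, 8, 8), torch.randn(1, 3, 16, 16)]
    fake = gen(None, fixed_noises, [1.0, 0.5], 'rand', curr_scale=1)
    print(fake.shape)  # torch.Size([1, 3, 16, 16])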
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import ProjDiscriminator, SNGANGenerator
from .modules import SNGANDiscHeadResBlock, SNGANDiscResBlock, SNGANGenResBlock
__all__ = [
'ProjDiscriminator', 'SNGANGenerator', 'SNGANGenResBlock',
'SNGANDiscResBlock', 'SNGANDiscHeadResBlock'
]
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
xavier_init)
from mmcv.runner import load_checkpoint
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmcv.utils import is_list_of
from torch.nn.init import xavier_uniform_
from torch.nn.utils import spectral_norm
from mmgen.models.builder import MODULES, build_module
from mmgen.utils import check_dist_init
from mmgen.utils.logger import get_root_logger
from ..common import get_module_device
@MODULES.register_module('SAGANGenerator')
@MODULES.register_module()
class SNGANGenerator(nn.Module):
r"""Generator for SNGAN / Proj-GAN. The implementation refers to
https://github.com/pfnet-research/sngan_projection/tree/master/gen_models
    In our implementation, we have two notable designs, namely
    ``channels_cfg`` and ``blocks_cfg``.
    ``channels_cfg``: In the default config of SNGAN / Proj-GAN, the number of
        ResBlocks and the channels of those blocks correspond to the
        resolution of the output image. Therefore, we allow users to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.
    ``blocks_cfg``: In the reference code, the generator consists of a group
        of ResBlocks. However, in our implementation, to make this model more
        general, we support defining ``blocks_cfg`` by users and loading the
        blocks by calling the ``build_module`` method.
Args:
output_scale (int): Output scale for the generated image.
        num_classes (int, optional): The number of classes you would like to
            generate. This argument influences the structure of the
            intermediate blocks and the label sampling operation in
            ``forward`` (e.g. if ``num_classes=0``, ConditionalNormalization
            layers degrade to unconditional ones). This argument is passed
            to intermediate blocks by overwriting their config. Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this number.
            Default to 64.
out_channels (int, optional): Channels of the output images.
Default to 3.
input_scale (int, optional): Input scale for the features.
Defaults to 4.
noise_size (int, optional): Size of the input noise vector.
Default to 128.
attention_cfg (dict, optional): Config for the self-attention block.
Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): The index of
            the *ConvBlock* after which a self-attention block is added. If
            an ``int`` is passed, only one attention block is added. If a
            ``list`` is passed, self-attention blocks are added after
            multiple ConvBlocks. Note that if the input is smaller than
            ``1``, the self-attention block corresponding to this index is
            ignored. Default to 0.
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermediate blocks. If a list is passed, each element
            defines the input channels of the corresponding block as a
            multiple of ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]``. If a dict is provided, its keys should be
            output scales and each value should be a list defining the
            channels. Default: Please refer to ``_default_channels_cfg``.
        blocks_cfg (dict, optional): Config for the intermediate blocks.
            Defaults to ``dict(type='SNGANGenResBlock')``.
act_cfg (dict, optional): Activation config for the final output
layer. Defaults to ``dict(type='ReLU')``.
        use_cbn (bool, optional): Whether to use conditional normalization.
            This argument is passed to norm layers. Defaults to True.
        auto_sync_bn (bool, optional): Whether to convert BatchNorm to
            SyncBatchNorm when distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            conv blocks or not. Default to False.
        with_embedding_spectral_norm (bool, optional): Whether to use spectral
            norm for embedding layers in normalization blocks or not. If not
            specified (set as ``None``), ``with_embedding_spectral_norm``
            would be set to the same value as ``with_spectral_norm``.
            Defaults to None.
sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, the implementation by ajbrock
            (https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
norm_eps (float, optional): eps for Normalization layers (both
conditional and non-conditional ones). Default to `1e-4`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
"""
# default channel factors
_default_channels_cfg = {
32: [1, 1, 1],
64: [16, 8, 4, 2],
128: [16, 16, 8, 4, 2]
}
def __init__(self,
output_scale,
num_classes=0,
base_channels=64,
out_channels=3,
input_scale=4,
noise_size=128,
attention_cfg=dict(type='SelfAttentionBlock'),
attention_after_nth_block=0,
channels_cfg=None,
blocks_cfg=dict(type='SNGANGenResBlock'),
act_cfg=dict(type='ReLU'),
use_cbn=True,
auto_sync_bn=True,
with_spectral_norm=False,
with_embedding_spectral_norm=None,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN'),
pretrained=None):
super().__init__()
self.input_scale = input_scale
self.output_scale = output_scale
self.noise_size = noise_size
self.num_classes = num_classes
self.init_type = init_cfg.get('type', None)
self.blocks_cfg = deepcopy(blocks_cfg)
self.blocks_cfg.setdefault('num_classes', num_classes)
self.blocks_cfg.setdefault('act_cfg', act_cfg)
self.blocks_cfg.setdefault('use_cbn', use_cbn)
self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        # set `with_embedding_spectral_norm` as `with_spectral_norm` if not
        # defined
with_embedding_spectral_norm = with_embedding_spectral_norm \
if with_embedding_spectral_norm is not None else with_spectral_norm
self.blocks_cfg.setdefault('with_embedding_spectral_norm',
with_embedding_spectral_norm)
self.blocks_cfg.setdefault('init_cfg', init_cfg)
self.blocks_cfg.setdefault('sn_style', sn_style)
self.blocks_cfg.setdefault('norm_eps', norm_eps)
self.blocks_cfg.setdefault('sn_eps', sn_eps)
channels_cfg = deepcopy(self._default_channels_cfg) \
if channels_cfg is None else deepcopy(channels_cfg)
if isinstance(channels_cfg, dict):
if output_scale not in channels_cfg:
                raise KeyError(f'`output_scale={output_scale}` is not found '
                               'in `channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
self.channel_factor_list = channels_cfg[output_scale]
elif isinstance(channels_cfg, list):
self.channel_factor_list = channels_cfg
else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'receive {type(channels_cfg)}')
self.noise2feat = nn.Linear(
noise_size,
input_scale**2 * base_channels * self.channel_factor_list[0])
if with_spectral_norm:
self.noise2feat = spectral_norm(self.noise2feat)
# check `attention_after_nth_block`
if not isinstance(attention_after_nth_block, list):
attention_after_nth_block = [attention_after_nth_block]
if not is_list_of(attention_after_nth_block, int):
raise ValueError('`attention_after_nth_block` only support int or '
'a list of int. Please check your input type.')
self.conv_blocks = nn.ModuleList()
self.attention_block_idx = []
for idx in range(len(self.channel_factor_list)):
factor_input = self.channel_factor_list[idx]
factor_output = self.channel_factor_list[idx+1] \
if idx < len(self.channel_factor_list)-1 else 1
# get block-specific config
block_cfg_ = deepcopy(self.blocks_cfg)
block_cfg_['in_channels'] = factor_input * base_channels
block_cfg_['out_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(block_cfg_))
# build self-attention block
            # `idx` starts from 0; add 1 to get the index
if idx + 1 in attention_after_nth_block:
self.attention_block_idx.append(len(self.conv_blocks))
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = factor_output * base_channels
attn_cfg_['sn_style'] = sn_style
self.conv_blocks.append(build_module(attn_cfg_))
to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
if check_dist_init() and auto_sync_bn:
to_rgb_norm_cfg['type'] = 'SyncBN'
self.to_rgb = ConvModule(
factor_output * base_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
norm_cfg=to_rgb_norm_cfg,
act_cfg=act_cfg,
order=('norm', 'act', 'conv'),
with_spectral_norm=with_spectral_norm)
self.final_act = build_activation_layer(dict(type='Tanh'))
self.init_weights(pretrained)
def forward(self, noise, num_batches=0, label=None, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The batch size to sample.
                Defaults to 0.
label (torch.Tensor | callable | None): You can directly give a
batch of label through a ``torch.Tensor`` or offer a callable
function to sample a batch of label data. Otherwise, the
``None`` indicates to use the default label sampler.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img``, ``noise_batch`` and ``label``
                would be returned.
"""
        if isinstance(noise, torch.Tensor):
            # check ndim first so a 1D tensor fails with a clear message
            assert noise.ndim == 2, ('The noise should be in shape of (n, c),'
                                     f' but got {noise.shape}')
            assert noise.shape[1] == self.noise_size
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
if isinstance(label, torch.Tensor):
            assert label.ndim == 1, ('The label should be in shape of (n, ),'
                                     f' but got {label.shape}.')
label_batch = label
elif callable(label):
label_generator = label
assert num_batches > 0
label_batch = label_generator(num_batches)
elif self.num_classes == 0:
label_batch = None
else:
assert num_batches > 0
label_batch = torch.randint(0, self.num_classes, (num_batches, ))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
if label_batch is not None:
label_batch = label_batch.to(get_module_device(self))
x = self.noise2feat(noise_batch)
x = x.reshape(x.size(0), -1, self.input_scale, self.input_scale)
for idx, conv_block in enumerate(self.conv_blocks):
if idx in self.attention_block_idx:
x = conv_block(x)
else:
x = conv_block(x, label_batch)
out_feat = self.to_rgb(x)
out_img = self.final_act(out_feat)
if return_noise:
return dict(
fake_img=out_img, noise_batch=noise_batch, label=label_batch)
return out_img
def init_weights(self, pretrained=None, strict=True):
"""Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None``,
weight initialization would follow the ``INIT_TYPE`` in
``init_cfg=dict(type=INIT_TYPE)``.
For SNGAN-Proj,
(``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
we follow the initialization method in the official Chainer's
implementation (https://github.com/pfnet-research/sngan_projection).
For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
initialization method in official tensorflow's implementation
(https://github.com/brain-research/self-attention-gan).
Besides the reimplementation of the official code's initialization, we
provide BigGAN's and Pytorch-StudioGAN's style initialization
(``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.
Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif isinstance(pretrained, dict):
ckpt_path = pretrained.get('ckpt_path', None)
assert ckpt_path is not None
prefix = pretrained.get('prefix', '')
map_location = pretrained.get('map_location', 'cpu')
strict = pretrained.get('strict', True)
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
elif pretrained is None:
            if self.init_type.upper() == 'STUDIO':
# initialization method from Pytorch-StudioGAN
# * weight: orthogonal_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
nn.init.orthogonal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
# initialization method from BigGAN-pytorch
# * weight: xavier_init gain=1
# * bias : default
                for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_uniform_(m.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
# initialization method from official tensorflow code
# * weight : xavier_init gain=1
# * bias : 0
# * weight_embedding: 1
# * bias_embedding : 0
for n, m in self.named_modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, gain=1, distribution='uniform')
if isinstance(m, nn.Embedding):
# To be noted that here we initialize the embedding
# layer in cBN with specific prefix. If you implement
# your own cBN and want to use this initialization
# method, please make sure the embedding layers in
# your implementation have the same prefix as ours.
if 'weight' in n:
constant_init(m, 1)
if 'bias' in n:
constant_init(m, 0)
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
# initialization method from the official chainer code
# * conv.weight : xavier_init gain=sqrt(2)
# * shortcut.weight : xavier_init gain=1
# * bias : 0
# * weight_embedding: 1
# * bias_embedding : 0
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'shortcut' in n or 'to_rgb' in n:
xavier_init(m, gain=1, distribution='uniform')
else:
xavier_init(
m, gain=np.sqrt(2), distribution='uniform')
if isinstance(m, nn.Linear):
xavier_init(m, gain=1, distribution='uniform')
if isinstance(m, nn.Embedding):
# To be noted that here we initialize the embedding
# layer in cBN with specific prefix. If you implement
# your own cBN and want to use this initialization
# method, please make sure the embedding layers in
# your implementation have the same prefix as ours.
if 'weight' in n:
constant_init(m, 1)
if 'bias' in n:
constant_init(m, 0)
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
else:
raise TypeError("'pretrined' must be a str or None. "
f'But receive {type(pretrained)}.')
@MODULES.register_module('SAGANDiscriminator')
@MODULES.register_module()
class ProjDiscriminator(nn.Module):
r"""Discriminator for SNGAN / Proj-GAN. The implementation is refer to
https://github.com/pfnet-research/sngan_projection/tree/master/dis_models
The overall structure of the projection discriminator can be split into a
``from_rgb`` layer, a group of ResBlocks, a linear decision layer, and a
projection layer. To support defining custom layers, we introduce
``from_rgb_cfg`` and ``blocks_cfg``.
    The design of the model structure corresponds closely to the input
    resolution. Therefore, we provide `channels_cfg` and `downsample_cfg` to
    control the input channels and the downsample behavior of the
    intermediate blocks.
    ``downsample_cfg``: In the default config of SNGAN / Proj-GAN, whether to
        apply downsampling in each intermediate block is quite flexible and
        corresponds to the resolution of the input image. Therefore, we
        allow users to define ``downsample_cfg`` by themselves to control
        the structure of the discriminator.
    ``channels_cfg``: In the default config of SNGAN / Proj-GAN, the number
        of ResBlocks and the channels of those blocks correspond to the
        resolution of the input image. Therefore, we allow users to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the input
        resolution.
Args:
input_scale (int): The scale of the input image.
        num_classes (int, optional): The number of classes you would like to
            generate. If ``num_classes=0``, no label projection would be used.
            Default to 0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contain channels based on this
            number. Defaults to 128.
input_channels (int, optional): Channels of the input image.
Defaults to 3.
attention_cfg (dict, optional): Config for the self-attention block.
Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): The index of
            the *ConvBlock* (including the head block) after which a
            self-attention block is added. If an ``int`` is passed, only one
            attention block is added. If a ``list`` is passed, self-attention
            blocks are added after multiple ConvBlocks. Note that if the
            input is smaller than ``1``, the self-attention block
            corresponding to this index is ignored. Default to 0.
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermediate blocks. If a list is passed, each element
            defines the input channels of the corresponding block as a
            multiple of ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]``. If a dict is provided, its keys should be
            input scales and each value should be a list defining the
            channels. Default: Please refer to ``_defualt_channels_cfg``.
        downsample_cfg (list[bool] | dict[list], optional): Config for the
            downsample behavior of the intermediate layers. If a list is
            passed, ``downsample_cfg[idx] == True`` means downsampling is
            applied in the idx-th block, and vice versa. If a dict is
            provided, its keys should be input scales of the image and each
            value should be a list defining the downsample behavior. Default:
            Please refer to ``_defualt_downsample_cfg``.
from_rgb_cfg (dict, optional): Config for the first layer to convert
rgb image to feature map. Defaults to
``dict(type='SNGANDiscHeadResBlock')``.
        blocks_cfg (dict, optional): Config for the intermediate blocks.
            Defaults to ``dict(type='SNGANDiscResBlock')``.
act_cfg (dict, optional): Activation config for the final output
layer. Defaults to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            all conv blocks or not. Default to True.
sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, the implementation by ajbrock
            (https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Default to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
the generator part from the whole state dict. Defaults to None.
"""
# default channel factors
_defualt_channels_cfg = {
32: [1, 1, 1],
64: [2, 4, 8, 16],
128: [2, 4, 8, 16, 16],
}
# default downsample behavior
_defualt_downsample_cfg = {
32: [True, False, False],
64: [True, True, True, True],
128: [True, True, True, True, False]
}
def __init__(self,
input_scale,
num_classes=0,
base_channels=128,
input_channels=3,
attention_cfg=dict(type='SelfAttentionBlock'),
attention_after_nth_block=-1,
channels_cfg=None,
downsample_cfg=None,
from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
blocks_cfg=dict(type='SNGANDiscResBlock'),
act_cfg=dict(type='ReLU'),
with_spectral_norm=True,
sn_style='torch',
sn_eps=1e-12,
init_cfg=dict(type='BigGAN'),
pretrained=None):
super().__init__()
self.init_type = init_cfg.get('type', None)
# add SN options and activation function options to cfg
self.from_rgb_cfg = deepcopy(from_rgb_cfg)
self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
self.from_rgb_cfg.setdefault('sn_style', sn_style)
self.from_rgb_cfg.setdefault('init_cfg', init_cfg)
# add SN options and activation function options to cfg
self.blocks_cfg = deepcopy(blocks_cfg)
self.blocks_cfg.setdefault('act_cfg', act_cfg)
self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
self.blocks_cfg.setdefault('sn_style', sn_style)
self.blocks_cfg.setdefault('sn_eps', sn_eps)
self.blocks_cfg.setdefault('init_cfg', init_cfg)
channels_cfg = deepcopy(self._defualt_channels_cfg) \
if channels_cfg is None else deepcopy(channels_cfg)
if isinstance(channels_cfg, dict):
if input_scale not in channels_cfg:
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
self.channel_factor_list = channels_cfg[input_scale]
elif isinstance(channels_cfg, list):
self.channel_factor_list = channels_cfg
else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'receive {type(channels_cfg)}')
downsample_cfg = deepcopy(self._defualt_downsample_cfg) \
if downsample_cfg is None else deepcopy(downsample_cfg)
if isinstance(downsample_cfg, dict):
if input_scale not in downsample_cfg:
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`downsample_cfg`, only support configs for '
                               f'{list(downsample_cfg.keys())}')
self.downsample_list = downsample_cfg[input_scale]
elif isinstance(downsample_cfg, list):
self.downsample_list = downsample_cfg
else:
            raise ValueError('Only support list or dict for `downsample_cfg`,'
                             f' receive {type(downsample_cfg)}')
if len(self.downsample_list) != len(self.channel_factor_list):
            raise ValueError('`downsample_cfg` should have the same length '
                             'as `channels_cfg`, but receive '
                             f'{len(self.downsample_list)} and '
                             f'{len(self.channel_factor_list)}.')
# check `attention_after_nth_block`
if not isinstance(attention_after_nth_block, list):
attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
raise ValueError('`attention_after_nth_block` only support int or '
'a list of int. Please check your input type.')
self.from_rgb = build_module(
self.from_rgb_cfg,
dict(in_channels=input_channels, out_channels=base_channels))
self.conv_blocks = nn.ModuleList()
# add self-attention block after the first block
if 1 in attention_after_nth_block:
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = base_channels
attn_cfg_['sn_style'] = sn_style
self.conv_blocks.append(build_module(attn_cfg_))
for idx in range(len(self.downsample_list)):
factor_input = 1 if idx == 0 else self.channel_factor_list[idx - 1]
factor_output = self.channel_factor_list[idx]
# get block-specific config
block_cfg_ = deepcopy(self.blocks_cfg)
block_cfg_['downsample'] = self.downsample_list[idx]
block_cfg_['in_channels'] = factor_input * base_channels
block_cfg_['out_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(block_cfg_))
# build self-attention block
# the first ConvBlock is `from_rgb` block,
# add 2 to get the index of the ConvBlocks
if idx + 2 in attention_after_nth_block:
attn_cfg_ = deepcopy(attention_cfg)
attn_cfg_['in_channels'] = factor_output * base_channels
self.conv_blocks.append(build_module(attn_cfg_))
self.decision = nn.Linear(factor_output * base_channels, 1)
if with_spectral_norm:
self.decision = spectral_norm(self.decision)
self.num_classes = num_classes
# In this case, discriminator is designed for conditional synthesis.
if num_classes > 0:
self.proj_y = nn.Embedding(num_classes,
factor_output * base_channels)
if with_spectral_norm:
self.proj_y = spectral_norm(self.proj_y)
self.activate = build_activation_layer(act_cfg)
self.init_weights(pretrained)
def forward(self, x, label=None):
"""Forward function. If `self.num_classes` is larger than 0, label
projection would be used.
Args:
x (torch.Tensor): Fake or real image tensor.
label (torch.Tensor, options): Label correspond to the input image.
Noted that, if `self.num_classed` is larger than 0,
`label` should not be None. Default to None.
Returns:
torch.Tensor: Prediction for the reality of the input image.
"""
h = self.from_rgb(x)
for conv_block in self.conv_blocks:
h = conv_block(h)
h = self.activate(h)
h = torch.sum(h, dim=[2, 3])
out = self.decision(h)
if self.num_classes > 0:
w_y = self.proj_y(label)
out = out + torch.sum(w_y * h, dim=1, keepdim=True)
return out.view(out.size(0), -1)
def init_weights(self, pretrained=None, strict=True):
"""Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None`` and
weight initialization would follow the ``INIT_TYPE`` in
``init_cfg=dict(type=INIT_TYPE)``.
For SNGAN-Proj
(``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
we follow the initialization method in the official Chainer's
implementation (https://github.com/pfnet-research/sngan_projection).
For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
initialization method in official tensorflow's implementation
(https://github.com/brain-research/self-attention-gan).
Besides the reimplementation of the official code's initialization, we
provide BigGAN's and Pytorch-StudioGAN's style initialization
(``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.
Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif isinstance(pretrained, dict):
ckpt_path = pretrained.get('ckpt_path', None)
assert ckpt_path is not None
prefix = pretrained.get('prefix', '')
map_location = pretrained.get('map_location', 'cpu')
strict = pretrained.get('strict', True)
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
elif pretrained is None:
if self.init_type.upper() == 'STUDIO':
# initialization method from Pytorch-StudioGAN
# * weight: orthogonal_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
nn.init.orthogonal_(m.weight, gain=1)
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
# initialization method from BigGAN-pytorch
# * weight: xavier_init gain=1
# * bias : default
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_uniform_(m.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
# initialization method from official tensorflow code
# * weight: xavier_init gain=1
# * bias : 0
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
xavier_init(m, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
# initialization method from the official chainer code
# * embedding.weight: xavier_init gain=1
# * conv.weight : xavier_init gain=sqrt(2)
# * shortcut.weight : xavier_init gain=1
# * bias : 0
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'shortcut' in n:
xavier_init(m, gain=1, distribution='uniform')
else:
xavier_init(
m, gain=np.sqrt(2), distribution='uniform')
if isinstance(m, (nn.Linear, nn.Embedding)):
xavier_init(m, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
else:
raise TypeError("'pretrained' must by a str or None. "
f'But receive {type(pretrained)}.')
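# A hedged usage sketch of the ``pretrained`` dict branch of ``init_weights``
# above; the instance name, checkpoint path and prefix below are hypothetical:
#   disc.init_weights(pretrained=dict(
#       ckpt_path='work_dirs/sngan_proj/latest.pth',  # hypothetical path
#       prefix='discriminator',  # loads only keys under this prefix
#       map_location='cpu',
#       strict=True))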
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch.nn as nn
from mmcv.cnn import (build_activation_layer, build_norm_layer,
build_upsample_layer, constant_init, xavier_init)
from torch.nn.init import xavier_uniform_
from torch.nn.utils import spectral_norm
from mmgen.models.architectures.biggan.biggan_snmodule import SNEmbedding
from mmgen.models.architectures.biggan.modules import SNConvModule
from mmgen.models.builder import MODULES
from mmgen.utils import check_dist_init
@MODULES.register_module()
class SNGANGenResBlock(nn.Module):
"""ResBlock used in Generator of SNGAN / Proj-GAN.
    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        hidden_channels (int, optional): Input channels of the second conv
            layer of the block. If ``None`` is given, it will be set to
            ``out_channels``. Defaults to None.
        num_classes (int, optional): Number of classes to generate. This
            argument is passed to the norm layers and influences the
            structure and behavior of the normalization process.
            Defaults to 0.
        use_cbn (bool, optional): Whether to use conditional normalization.
            This argument is passed to the norm layers. Defaults to True.
        use_norm_affine (bool, optional): Whether to use learnable affine
            parameters in the norm operation when CBN is off.
            Defaults to False.
        act_cfg (dict, optional): Config for the activation function.
            Defaults to ``dict(type='ReLU')``.
        norm_cfg (dict, optional): Config for the normalization method.
            Defaults to ``dict(type='BN')``.
        upsample_cfg (dict, optional): Config for the upsample method.
            Defaults to ``dict(type='nearest', scale_factor=2)``.
        upsample (bool, optional): Whether to apply the upsample operation in
            this module. Defaults to True.
        auto_sync_bn (bool, optional): Whether to convert Batch Norm to
            Synchronized Batch Norm when distributed training is on.
            Defaults to True.
        conv_cfg (dict | None): Config for the conv blocks of this module. If
            ``None`` is passed, ``_default_conv_cfg`` will be used.
            Defaults to None.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            the conv blocks and norm layers. Defaults to False.
        with_embedding_spectral_norm (bool, optional): Whether to use spectral
            norm for the embedding layers in the normalization blocks. If not
            specified (``None``), it will take the same value as
            ``with_spectral_norm``. Defaults to None.
        sn_style (str, optional): The style of spectral normalization. If set
            to ``ajbrock``, the implementation from
            https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py
            will be adopted. If set to ``torch``, PyTorch's implementation
            will be adopted. Defaults to ``torch``.
        norm_eps (float, optional): Eps for the normalization layers (both
            conditional and non-conditional ones). Defaults to 1e-4.
        sn_eps (float, optional): Eps for the spectral normalization
            operation. Defaults to 1e-12.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
num_classes=0,
use_cbn=True,
use_norm_affine=False,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
upsample_cfg=dict(type='nearest', scale_factor=2),
upsample=True,
auto_sync_bn=True,
conv_cfg=None,
with_spectral_norm=False,
with_embedding_spectral_norm=None,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
self.learnable_sc = in_channels != out_channels or upsample
self.with_upsample = upsample
self.init_type = init_cfg.get('type', None)
self.activate = build_activation_layer(act_cfg)
hidden_channels = out_channels if hidden_channels is None \
else hidden_channels
if self.with_upsample:
self.upsample = build_upsample_layer(upsample_cfg)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
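        # e.g. passing ``conv_cfg=dict(kernel_size=1, padding=0)`` overrides
        # only those keys; unspecified keys keep the defaults in
        # ``_default_conv_cfg``.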
        # set `with_embedding_spectral_norm` as `with_spectral_norm` if not
        # defined
with_embedding_spectral_norm = with_embedding_spectral_norm \
if with_embedding_spectral_norm is not None else with_spectral_norm
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
hidden_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
hidden_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.norm_1 = SNConditionNorm(in_channels, num_classes, use_cbn,
norm_cfg, use_norm_affine, auto_sync_bn,
with_embedding_spectral_norm, sn_style,
norm_eps, sn_eps, init_cfg)
self.norm_2 = SNConditionNorm(hidden_channels, num_classes, use_cbn,
norm_cfg, use_norm_affine, auto_sync_bn,
with_embedding_spectral_norm, sn_style,
norm_eps, sn_eps, init_cfg)
if self.learnable_sc:
            # use a shortcut with fixed hyperparameters (1x1 conv, no act)
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x, y=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
            y (Tensor, optional): Input label with shape (n, ).
                Defaults to None.
Returns:
Tensor: Forward results.
"""
out = self.norm_1(x, y)
out = self.activate(out)
if self.with_upsample:
out = self.upsample(out)
out = self.conv_1(out)
out = self.norm_2(out, y)
out = self.activate(out)
out = self.conv_2(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = x
if self.learnable_sc:
if self.with_upsample:
out = self.upsample(out)
out = self.shortcut(out)
return out
def init_weights(self):
"""Initialize weights for the model."""
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.conv_1.conv.weight)
nn.init.orthogonal_(self.conv_2.conv.weight)
self.conv_1.conv.bias.data.fill_(0.)
self.conv_2.conv.bias.data.fill_(0.)
if self.learnable_sc:
nn.init.orthogonal_(self.shortcut.conv.weight)
self.shortcut.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
if self.learnable_sc:
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
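# A minimal usage sketch of ``SNGANGenResBlock`` (shapes assume the default
# 2x nearest upsample; channel sizes here are illustrative):
#   >>> import torch
#   >>> block = SNGANGenResBlock(256, 128, num_classes=10)
#   >>> x = torch.randn(4, 256, 16, 16)
#   >>> y = torch.randint(0, 10, (4, ))
#   >>> block(x, y).shape
#   torch.Size([4, 128, 32, 32])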
@MODULES.register_module()
class SNGANDiscResBlock(nn.Module):
"""resblock used in discriminator of sngan / proj-gan.
args:
in_channels (int): input channels.
out_channels (int): output channels.
hidden_channels (int, optional): input channels of the second conv
layer of the block. if ``none`` is given, would be set as
``out_channels``. Defaults to none.
downsample (bool, optional): whether apply downsample operation in this
module. Defaults to false.
act_cfg (dict, optional): config for activate function. default
to ``dict(type='relu')``.
conv_cfg (dict | none): config for conv blocks of this module. if pass
``none``, would use ``_default_conv_cfg``. default to ``none``.
with_spectral_norm (bool, optional): whether use spectral norm for
conv blocks and norm layers. Defaults to true.
sn_eps (float, optional): eps for spectral normalization operation.
Default to `1e-12`.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
init_cfg (dict, optional): Config for weight initialization.
Defaults to ``dict(type='BigGAN')``.
"""
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
hidden_channels=None,
downsample=False,
act_cfg=dict(type='ReLU'),
conv_cfg=None,
with_spectral_norm=True,
sn_style='torch',
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
hidden_channels = out_channels if hidden_channels is None \
else hidden_channels
self.with_downsample = downsample
self.init_type = init_cfg.get('type', None)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
self.activate = build_activation_layer(act_cfg)
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
hidden_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
hidden_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
if self.with_downsample:
self.downsample = nn.AvgPool2d(2, 2)
self.learnable_sc = in_channels != out_channels or downsample
if self.learnable_sc:
            # use a shortcut with fixed hyperparameters (1x1 conv, no act)
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.activate(x)
out = self.conv_1(out)
out = self.activate(out)
out = self.conv_2(out)
if self.with_downsample:
out = self.downsample(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = x
if self.learnable_sc:
out = self.shortcut(out)
if self.with_downsample:
out = self.downsample(out)
return out
    def init_weights(self):
        """Initialize weights for the model."""
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.conv_1.conv.weight)
nn.init.orthogonal_(self.conv_2.conv.weight)
self.conv_1.conv.bias.data.fill_(0.)
self.conv_2.conv.bias.data.fill_(0.)
if self.learnable_sc:
nn.init.orthogonal_(self.shortcut.conv.weight)
self.shortcut.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
if self.learnable_sc:
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
if self.learnable_sc:
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
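# A minimal usage sketch of ``SNGANDiscResBlock`` (illustrative shapes; the
# 3x3 convs keep the spatial size and ``downsample=True`` halves it):
#   >>> import torch
#   >>> block = SNGANDiscResBlock(128, 256, downsample=True)
#   >>> x = torch.randn(4, 128, 32, 32)
#   >>> block(x).shape
#   torch.Size([4, 256, 16, 16])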
@MODULES.register_module()
class SNGANDiscHeadResBlock(nn.Module):
"""The first ResBlock used in discriminator of sngan / proj-gan. Compared
to ``SNGANDisResBlock``, this module has a different forward order.
args:
in_channels (int): Input channels.
out_channels (int): Output channels.
downsample (bool, optional): whether apply downsample operation in this
module. default to false.
conv_cfg (dict | none): config for conv blocks of this module. if pass
``none``, would use ``_default_conv_cfg``. default to ``none``.
act_cfg (dict, optional): config for activate function. default
to ``dict(type='relu')``.
with_spectral_norm (bool, optional): whether use spectral norm for
conv blocks and norm layers. default to true.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Default to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Default to ``dict(type='BigGAN')``.
"""
_default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)
def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
act_cfg=dict(type='ReLU'),
with_spectral_norm=True,
sn_eps=1e-12,
sn_style='torch',
init_cfg=dict(type='BigGAN')):
super().__init__()
self.init_type = init_cfg.get('type', None)
self.conv_cfg = deepcopy(self._default_conv_cfg)
if conv_cfg is not None:
self.conv_cfg.update(conv_cfg)
self.activate = build_activation_layer(act_cfg)
sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
self.conv_1 = SNConvModule(
in_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.conv_2 = SNConvModule(
out_channels,
out_channels,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg,
**self.conv_cfg)
self.downsample = nn.AvgPool2d(2, 2)
        # use a shortcut with fixed hyperparameters (1x1 conv, no act)
self.shortcut = SNConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
act_cfg=None,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=sn_cfg)
self.init_weights()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
out = self.conv_1(x)
out = self.activate(out)
out = self.conv_2(out)
out = self.downsample(out)
shortcut = self.forward_shortcut(x)
return out + shortcut
def forward_shortcut(self, x):
out = self.downsample(x)
out = self.shortcut(out)
return out
    def init_weights(self):
        """Initialize weights for the model."""
if self.init_type.upper() == 'STUDIO':
for m in [self.conv_1, self.conv_2, self.shortcut]:
nn.init.orthogonal_(m.conv.weight)
m.conv.bias.data.fill_(0.)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.conv_1.conv.weight, gain=1)
xavier_uniform_(self.conv_2.conv.weight, gain=1)
xavier_uniform_(self.shortcut.conv.weight, gain=1)
elif self.init_type.upper() == 'SAGAN':
xavier_init(self.conv_1, gain=1, distribution='uniform')
xavier_init(self.conv_2, gain=1, distribution='uniform')
xavier_init(self.shortcut, gain=1, distribution='uniform')
elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
xavier_init(self.shortcut, gain=1, distribution='uniform')
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
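# A minimal usage sketch of ``SNGANDiscHeadResBlock``, which always halves
# the spatial size (illustrative shapes):
#   >>> import torch
#   >>> head = SNGANDiscHeadResBlock(3, 64)
#   >>> x = torch.randn(4, 3, 64, 64)
#   >>> head(x).shape
#   torch.Size([4, 64, 32, 32])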
@MODULES.register_module()
class SNConditionNorm(nn.Module):
"""Conditional Normalization for SNGAN / Proj-GAN. The implementation
refers to.
https://github.com/pfnet-research/sngan_projection/blob/master/source/links/conditional_batch_normalization.py # noda
and
https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/utils/model_ops.py # noqa
Args:
in_channels (int): Number of the channels of the input feature map.
num_classes (int): Number of the classes in the dataset. If ``use_cbn``
is True, ``num_classes`` must larger than 0.
use_cbn (bool, optional): Whether use conditional normalization. If
``use_cbn`` is True, two embedding layers would be used to mapping
label to weight and bias used in normalization process.
norm_cfg (dict, optional): Config for normalization method. Defaults
to ``dict(type='BN')``.
cbn_norm_affine (bool): Whether set ``affine=True`` when use conditional batch norm.
This argument only work when ``use_cbn`` is True. Defaults to False.
auto_sync_bn (bool, optional): Whether convert Batch Norm to
Synchronized ones when Distributed training is on. Defaults to True.
with_spectral_norm (bool, optional): whether use spectral norm for
conv blocks and norm layers. Defaults to true.
norm_eps (float, optional): eps for Normalization layers (both
conditional and non-conditional ones). Defaults to `1e-4`.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `torch`.
sn_eps (float, optional): eps for spectral normalization operation.
Defaults to `1e-12`.
init_cfg (dict, optional): Config for weight initialization.
Defaults to ``dict(type='BigGAN')``.
"""
def __init__(self,
in_channels,
num_classes,
use_cbn=True,
norm_cfg=dict(type='BN'),
cbn_norm_affine=False,
auto_sync_bn=True,
with_spectral_norm=False,
sn_style='torch',
norm_eps=1e-4,
sn_eps=1e-12,
init_cfg=dict(type='BigGAN')):
super().__init__()
self.use_cbn = use_cbn
self.init_type = init_cfg.get('type', None)
norm_cfg = deepcopy(norm_cfg)
norm_type = norm_cfg['type']
if norm_type not in ['IN', 'BN', 'SyncBN']:
            raise ValueError('Only support `IN` (InstanceNorm), '
                             '`BN` (BatchNorm) and `SyncBN` for '
                             'class-conditional normalization. '
                             f'Received norm_type: {norm_type}')
if self.use_cbn:
norm_cfg.setdefault('affine', cbn_norm_affine)
norm_cfg.setdefault('eps', norm_eps)
if check_dist_init() and auto_sync_bn and norm_type == 'BN':
norm_cfg['type'] = 'SyncBN'
_, self.norm = build_norm_layer(norm_cfg, in_channels)
if self.use_cbn:
if num_classes <= 0:
raise ValueError('`num_classes` must be larger '
'than 0 with `use_cbn=True`')
self.reweight_embedding = (
self.init_type.upper() == 'BIGGAN'
or self.init_type.upper() == 'STUDIO')
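            # BigGAN / StudioGAN parameterize the conditional gain as
            # ``1 + embedding(y)``, so a near-zero-initialized embedding
            # starts close to an identity scale; see ``forward`` below.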
if with_spectral_norm:
if sn_style == 'torch':
self.weight_embedding = spectral_norm(
nn.Embedding(num_classes, in_channels), eps=sn_eps)
self.bias_embedding = spectral_norm(
nn.Embedding(num_classes, in_channels), eps=sn_eps)
elif sn_style == 'ajbrock':
self.weight_embedding = SNEmbedding(
num_classes, in_channels, eps=sn_eps)
self.bias_embedding = SNEmbedding(
num_classes, in_channels, eps=sn_eps)
else:
raise NotImplementedError(
f'{sn_style} style spectral Norm is not '
'supported yet')
else:
self.weight_embedding = nn.Embedding(num_classes, in_channels)
self.bias_embedding = nn.Embedding(num_classes, in_channels)
self.init_weights()
def forward(self, x, y=None):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
y (Tensor, optional): Input label with shape (n, ).
                Defaults to None.
Returns:
Tensor: Forward results.
"""
out = self.norm(x)
if self.use_cbn:
weight = self.weight_embedding(y)[:, :, None, None]
bias = self.bias_embedding(y)[:, :, None, None]
if self.reweight_embedding:
weight = weight + 1.
out = out * weight + bias
return out
def init_weights(self):
if self.use_cbn:
if self.init_type.upper() == 'STUDIO':
nn.init.orthogonal_(self.weight_embedding.weight)
nn.init.orthogonal_(self.bias_embedding.weight)
elif self.init_type.upper() == 'BIGGAN':
xavier_uniform_(self.weight_embedding.weight, gain=1)
xavier_uniform_(self.bias_embedding.weight, gain=1)
elif self.init_type.upper() in [
'SNGAN', 'SNGAN-PROJ', 'GAN-PROJ', 'SAGAN'
]:
constant_init(self.weight_embedding, 1)
constant_init(self.bias_embedding, 0)
else:
raise NotImplementedError('Unknown initialization method: '
f'\'{self.init_type}\'')
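# A minimal usage sketch of ``SNConditionNorm`` (illustrative shapes; the
# output shape matches the input):
#   >>> import torch
#   >>> cbn = SNConditionNorm(64, num_classes=10)
#   >>> x = torch.randn(4, 64, 8, 8)
#   >>> y = torch.randint(0, 10, (4, ))
#   >>> cbn(x, y).shape
#   torch.Size([4, 64, 8, 8])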
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator_v1 import (StyleGAN1Discriminator,
StyleGANv1Generator)
from .generator_discriminator_v2 import (StyleGAN2Discriminator,
StyleGANv2Generator)
from .generator_discriminator_v3 import StyleGANv3Generator
from .mspie import MSStyleGAN2Discriminator, MSStyleGANv2Generator
__all__ = [
'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator',
'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator',
'MSStyleGANv2Generator', 'StyleGANv3Generator'
]