# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.pggan import (EqualizedLRConvDownModule,
EqualizedLRConvModule)
from mmgen.models.architectures.stylegan.modules import Blur
from mmgen.models.builder import MODULES
from .. import MiniBatchStddevLayer
from .modules.styleganv1_modules import StyleConv
from .modules.styleganv2_modules import EqualLinearActModule
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class StyleGANv1Generator(nn.Module):
"""StyleGAN1 Generator.
In StyleGAN1, we use a progressive growing architecture composed of a
style mapping module and a number of convolutional style blocks. More
details can be found in: A Style-Based Generator Architecture for
Generative Adversarial Networks, CVPR 2019.
Args:
out_size (int): The output size of the StyleGAN1 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
blur_kernel (list, optional): The blur kernel. Defaults
to [1, 2, 1].
lr_mlp (float, optional): The learning rate for the style mapping
layer. Defaults to 0.01.
default_style_mode (str, optional): The default mode of style mixing.
In training, the 'mix' style mode is adopted by default, while the
'single' style mode is used in evaluation. `['mix', 'single']` are
currently supported. Defaults to 'mix'.
eval_style_mode (str, optional): The evaluation mode of style mixing.
Defaults to 'single'.
mix_prob (float, optional): Mixing probability. The value should be
in range of [0, 1]. Defaults to 0.9.
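A minimal sampling sketch (illustrative only; the ``curr_scale`` and
``transition_weight`` values below are assumptions for a progressive
training stage, not prescribed settings):
.. code-block:: python
generator = StyleGANv1Generator(out_size=256, style_channels=512)
# sample two images at the full 256x256 resolution
fake_imgs = generator(None, num_batches=2)
# sample at an intermediate 64x64 stage with a blending weight
fake_imgs_64 = generator(None, num_batches=2, curr_scale=64, transition_weight=0.5)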
"""
def __init__(self,
out_size,
style_channels,
num_mlps=8,
blur_kernel=[1, 2, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9):
super().__init__()
self.out_size = out_size
self.style_channels = style_channels
self.num_mlps = num_mlps
self.lr_mlp = lr_mlp
self._default_style_mode = default_style_mode
self.default_style_mode = default_style_mode
self.eval_style_mode = eval_style_mode
self.mix_prob = mix_prob
# define style mapping layers
mapping_layers = [PixelNorm()]
for _ in range(num_mlps):
mapping_layers.append(
EqualLinearActModule(
style_channels,
style_channels,
equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.style_mapping = nn.Sequential(*mapping_layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16,
}
# generator backbone (4x4 --> higher resolutions)
self.log_size = int(np.log2(self.out_size))
self.convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
in_channels_ = self.channels[4]
for i in range(2, self.log_size + 1):
out_channels_ = self.channels[2**i]
self.convs.append(
StyleConv(
in_channels_,
out_channels_,
3,
style_channels,
initial=(i == 2),
upsample=True,
fused=True))
self.to_rgbs.append(
EqualizedLRConvModule(out_channels_, 3, 1, act_cfg=None))
in_channels_ = out_channels_
self.num_latents = self.log_size * 2 - 2
self.num_injected_noises = self.num_latents
# register buffer for injected noises
for layer_idx in range(self.num_injected_noises):
res = (layer_idx + 4) // 2
shape = [1, 1, 2**res, 2**res]
self.register_buffer(f'injected_noise_{layer_idx}',
torch.randn(*shape))
def train(self, mode=True):
if mode:
if self.default_style_mode != self._default_style_mode:
mmcv.print_log(
f'Switch to train style mode: {self._default_style_mode}',
'mmgen')
self.default_style_mode = self._default_style_mode
else:
if self.default_style_mode != self.eval_style_mode:
mmcv.print_log(
f'Switch to evaluation style mode: {self.eval_style_mode}',
'mmgen')
self.default_style_mode = self.eval_style_mode
return super(StyleGANv1Generator, self).train(mode)
def make_injected_noise(self):
"""make noises that will be injected into feature maps.
Returns:
list[Tensor]: List of layer-wise noise tensor.
"""
device = get_module_device(self)
noises = []
for i in range(2, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
return get_mean_latent(self, num_samples, **kwargs)
def style_mixing(self,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7,
curr_scale=-1,
transition_weight=1):
return style_mixing(
self,
n_source=n_source,
n_target=n_target,
inject_index=inject_index,
truncation=truncation,
truncation_latent=truncation_latent,
style_channels=self.style_channels,
curr_scale=curr_scale,
transition_weight=transition_weight)
def forward(self,
styles,
num_batches=-1,
return_noise=False,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
injected_noise=None,
randomize_noise=True,
transition_weight=1.,
curr_scale=-1):
"""Forward function.
This function has been integrated with the truncation trick. Please
refer to the usage of `truncation` and `truncation_latent`.
Args:
styles (torch.Tensor | list[torch.Tensor] | callable | None): In
StyleGAN1, you can provide a noise tensor or a latent tensor. Given
a list containing more than one noise or latent tensor, the style
mixing trick will be used in training. You can also directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, ``None``
indicates using the default noise sampler.
num_batches (int, optional): The batch size of sampled data.
Defaults to -1.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
inject_index (int | None, optional): The index number for mixing
style codes. Defaults to None.
truncation (float, optional): Truncation factor. If a value less
than 1. is given, the truncation trick will be adopted. Defaults to 1.
truncation_latent (torch.Tensor, optional): Mean truncation latent.
Defaults to None.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
injected_noise (list[torch.Tensor] | None, optional): Given a list
of tensors, the random noise will be fixed as these input injected
noises. Defaults to None.
randomize_noise (bool, optional): If `False`, images are sampled
with the buffered noise tensor injected to the style conv
block. Defaults to True.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1..
curr_scale (int, optional): The resolution scale of generated image
tensor. -1 means the max resolution scale of the StyleGAN1.
Defaults to -1.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# receive noise and conduct sanity check.
if isinstance(styles, torch.Tensor):
assert styles.shape[1] == self.style_channels
styles = [styles]
elif mmcv.is_seq_of(styles, torch.Tensor):
for t in styles:
assert t.shape[-1] == self.style_channels
# receive a noise generator and sample noise.
elif callable(styles):
device = get_module_device(self)
noise_generator = styles
assert num_batches > 0
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
noise_generator((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [noise_generator((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
# otherwise, we will adopt default noise sampler.
else:
device = get_module_device(self)
assert num_batches > 0 and not input_is_latent
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
torch.randn((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [torch.randn((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
if not input_is_latent:
noise_batch = styles
styles = [self.style_mapping(s) for s in styles]
else:
noise_batch = None
if injected_noise is None:
if randomize_noise:
injected_noise = [None] * self.num_injected_noises
else:
injected_noise = [
getattr(self, f'injected_noise_{i}')
for i in range(self.num_injected_noises)
]
# use truncation trick
if truncation < 1:
style_t = []
# calculate truncation latent on the fly
if truncation_latent is None and not hasattr(
self, 'truncation_latent'):
self.truncation_latent = self.get_mean_latent()
truncation_latent = self.truncation_latent
elif truncation_latent is None and hasattr(self,
'truncation_latent'):
truncation_latent = self.truncation_latent
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
# no style mixing
if len(styles) < 2:
inject_index = self.num_latents
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
# style mixing
else:
if inject_index is None:
inject_index = random.randint(1, self.num_latents - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.num_latents - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
curr_log_size = self.log_size if curr_scale < 0 else int(
np.log2(curr_scale))
step = curr_log_size - 2
_index = 0
out = latent
# 4x4 ---> higher resolutions
for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
if i > 0 and step > 0:
out_prev = out
out = conv(
out,
latent[:, _index],
latent[:, _index + 1],
noise1=injected_noise[2 * i],
noise2=injected_noise[2 * i + 1])
if i == step:
out = to_rgb(out)
if i > 0 and 0 <= transition_weight < 1:
skip_rgb = self.to_rgbs[i - 1](out_prev)
skip_rgb = F.interpolate(
skip_rgb, scale_factor=2, mode='nearest')
out = (1 - transition_weight
) * skip_rgb + transition_weight * out
break
_index += 2
img = out
if return_latents or return_noise:
output_dict = dict(
fake_img=img,
latent=latent,
inject_index=inject_index,
noise_batch=noise_batch)
return output_dict
return img
@MODULES.register_module()
class StyleGAN1Discriminator(nn.Module):
"""StyleGAN1 Discriminator.
The architecture of this discriminator is proposed in StyleGAN1. More
details can be found in: A Style-Based Generator Architecture for
Generative Adversarial Networks CVPR2019.
Args:
in_size (int): The input size of images.
blur_kernel (list, optional): The blur kernel. Defaults
to [1, 2, 1].
mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
Defaults to dict(group_size=4).
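A minimal scoring sketch (illustrative only; a batch of random
256x256 images stands in for real data):
.. code-block:: python
import torch
disc = StyleGAN1Discriminator(in_size=256)
img = torch.randn(2, 3, 256, 256)
score = disc(img)  # tensor of shape (2, 1)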
"""
def __init__(self,
in_size,
blur_kernel=[1, 2, 1],
mbstd_cfg=dict(group_size=4)):
super().__init__()
self.with_mbstd = mbstd_cfg is not None
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16,
}
log_size = int(np.log2(in_size))
self.log_size = log_size
in_channels = channels[in_size]
self.convs = nn.ModuleList()
self.from_rgb = nn.ModuleList()
for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]
self.from_rgb.append(
EqualizedLRConvModule(
3,
in_channels,
kernel_size=3,
padding=1,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.convs.append(
nn.Sequential(
EqualizedLRConvModule(
in_channels,
out_channel,
kernel_size=3,
padding=1,
bias=True,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
Blur(blur_kernel, pad=(1, 1)),
EqualizedLRConvDownModule(
out_channel,
out_channel,
kernel_size=3,
stride=2,
padding=1,
act_cfg=None),
nn.LeakyReLU(negative_slope=0.2, inplace=True)))
in_channels = out_channel
self.from_rgb.append(
EqualizedLRConvModule(
3,
in_channels,
kernel_size=3,
padding=0,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
self.convs.append(
nn.Sequential(
EqualizedLRConvModule(
in_channels + 1,
512,
kernel_size=3,
padding=1,
bias=True,
norm_cfg=None,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
EqualizedLRConvModule(
512,
512,
kernel_size=4,
padding=0,
bias=True,
norm_cfg=None,
act_cfg=None),
))
if self.with_mbstd:
self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)
self.final_linear = nn.Sequential(EqualLinearActModule(channels[4], 1))
self.n_layer = len(self.convs)
def forward(self, input, transition_weight=1., curr_scale=-1):
"""Forward function.
Args:
input (torch.Tensor): Input image tensor.
transition_weight (float, optional): The weight used in resolution
transition. Defaults to 1..
curr_scale (int, optional): The resolution scale of image tensor.
-1 means the max resolution scale of the StyleGAN1.
Defaults to -1.
Returns:
torch.Tensor: Predict score for the input image.
"""
curr_log_size = self.log_size if curr_scale < 0 else int(
np.log2(curr_scale))
step = curr_log_size - 2
for i in range(step, -1, -1):
index = self.n_layer - i - 1
if i == step:
out = self.from_rgb[index](input)
# minibatch standard deviation
if i == 0:
out = self.mbstd_layer(out)
out = self.convs[index](out)
if i > 0:
if i == step and 0 <= transition_weight < 1:
skip_rgb = F.avg_pool2d(input, 2)
skip_rgb = self.from_rgb[index + 1](skip_rgb)
out = (1 - transition_weight
) * skip_rgb + transition_weight * out
out = out.view(out.shape[0], -1)
out = self.final_linear(out)
return out
# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmgen.core.runners.fp16_utils import auto_fp16
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .ada.augment import AugmentPipe
from .ada.misc import constant
from .modules.styleganv2_modules import (ConstantInput, ConvDownLayer,
EqualLinearActModule,
ModMBStddevLayer, ModulatedStyleConv,
ModulatedToRGB, ResBlock)
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class StyleGANv2Generator(nn.Module):
r"""StyleGAN2 Generator.
In StyleGAN2, we use a static architecture composed of a style mapping
module and a number of convolutional style blocks. More details can be
found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.
You can load pretrained model through passing information into
``pretrained`` argument. We have already offered official weights as
follows:
- stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
- stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
- stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
- stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
- stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa
If you want to load the EMA model, you can use the following code:
.. code-block:: python
# ckpt_http is one of the valid paths from an http source
generator = StyleGANv2Generator(1024, 512,
pretrained=dict(
ckpt_path=ckpt_http,
prefix='generator_ema'))
Of course, you can also download the checkpoint in advance and set
``ckpt_path`` with local path. If you just want to load the original
generator (not the ema model), please set the prefix with 'generator'.
Note that our implementation allows generating BGR images, while the
original StyleGAN2 outputs RGB images by default. Thus, we provide the
``bgr2rgb`` argument to convert the image space.
Args:
out_size (int): The output size of the StyleGAN2 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
blur_kernel (list, optional): The blur kernel. Defaults
to [1, 3, 3, 1].
lr_mlp (float, optional): The learning rate for the style mapping
layer. Defaults to 0.01.
default_style_mode (str, optional): The default mode of style mixing.
In training, the 'mix' style mode is adopted by default, while the
'single' style mode is used in evaluation. `['mix', 'single']` are
currently supported. Defaults to 'mix'.
eval_style_mode (str, optional): The evaluation mode of style mixing.
Defaults to 'single'.
mix_prob (float, optional): Mixing probability. The value should be
in range of [0, 1]. Defaults to ``0.9``.
num_fp16_scales (int, optional): The number of resolutions to use auto
fp16 training. Different from ``fp16_enabled``, this argument
allows users to adopt FP16 training only in several blocks.
This behaviour is much more similar to the official implementation
by Tero. Defaults to 0.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. If this flag is `True`, the whole module will be wrapped
with ``auto_fp16``. Defaults to False.
pretrained (dict | None, optional): Information for pretrained models.
The necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
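A hedged sampling sketch (the values below are illustrative, not
recommended settings):
.. code-block:: python
generator = StyleGANv2Generator(1024, 512)
# default noise sampler with the truncation trick
fake_imgs = generator(None, num_batches=2, truncation=0.7)
# also return the w-space latents used for synthesis
out_dict = generator(None, num_batches=2, return_latents=True)
fake_imgs, ws = out_dict['fake_img'], out_dict['latent']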
"""
def __init__(self,
out_size,
style_channels,
num_mlps=8,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9,
num_fp16_scales=0,
fp16_enabled=False,
pretrained=None):
super().__init__()
self.out_size = out_size
self.style_channels = style_channels
self.num_mlps = num_mlps
self.channel_multiplier = channel_multiplier
self.lr_mlp = lr_mlp
self._default_style_mode = default_style_mode
self.default_style_mode = default_style_mode
self.eval_style_mode = eval_style_mode
self.mix_prob = mix_prob
self.num_fp16_scales = num_fp16_scales
self.fp16_enabled = fp16_enabled
# define style mapping layers
mapping_layers = [PixelNorm()]
for _ in range(num_mlps):
mapping_layers.append(
EqualLinearActModule(
style_channels,
style_channels,
equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
act_cfg=dict(type='fused_bias')))
self.style_mapping = nn.Sequential(*mapping_layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
# constant input layer
self.constant_input = ConstantInput(self.channels[4])
# 4x4 stage
self.conv1 = ModulatedStyleConv(
self.channels[4],
self.channels[4],
kernel_size=3,
style_channels=style_channels,
blur_kernel=blur_kernel)
self.to_rgb1 = ModulatedToRGB(
self.channels[4],
style_channels,
upsample=False,
fp16_enabled=fp16_enabled)
# generator backbone (8x8 --> higher resolutions)
self.log_size = int(np.log2(self.out_size))
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
in_channels_ = self.channels[4]
for i in range(3, self.log_size + 1):
out_channels_ = self.channels[2**i]
# If `fp16_enabled` is True, all layers will be run in auto
# FP16. In the case of `num_fp16_scales` > 0, only partial
# layers will be run in fp16.
_use_fp16 = (self.log_size - i) < num_fp16_scales or fp16_enabled
self.convs.append(
ModulatedStyleConv(
in_channels_,
out_channels_,
3,
style_channels,
upsample=True,
blur_kernel=blur_kernel,
fp16_enabled=_use_fp16))
self.convs.append(
ModulatedStyleConv(
out_channels_,
out_channels_,
3,
style_channels,
upsample=False,
blur_kernel=blur_kernel,
fp16_enabled=_use_fp16))
self.to_rgbs.append(
ModulatedToRGB(
out_channels_,
style_channels,
upsample=True,
fp16_enabled=_use_fp16)) # set to global fp16
in_channels_ = out_channels_
self.num_latents = self.log_size * 2 - 2
self.num_injected_noises = self.num_latents - 1
# register buffer for injected noises
for layer_idx in range(self.num_injected_noises):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
self.register_buffer(f'injected_noise_{layer_idx}',
torch.randn(*shape))
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def train(self, mode=True):
if mode:
if self.default_style_mode != self._default_style_mode:
mmcv.print_log(
f'Switch to train style mode: {self._default_style_mode}',
'mmgen')
self.default_style_mode = self._default_style_mode
else:
if self.default_style_mode != self.eval_style_mode:
mmcv.print_log(
f'Switch to evaluation style mode: {self.eval_style_mode}',
'mmgen')
self.default_style_mode = self.eval_style_mode
return super(StyleGANv2Generator, self).train(mode)
def make_injected_noise(self):
"""make noises that will be injected into feature maps.
Returns:
list[Tensor]: List of layer-wise noise tensor.
"""
device = get_module_device(self)
noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
return get_mean_latent(self, num_samples, **kwargs)
def style_mixing(self,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7):
return style_mixing(
self,
n_source=n_source,
n_target=n_target,
inject_index=inject_index,
truncation=truncation,
truncation_latent=truncation_latent,
style_channels=self.style_channels)
@auto_fp16()
def forward(self,
styles,
num_batches=-1,
return_noise=False,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
injected_noise=None,
randomize_noise=True):
"""Forward function.
This function has been integrated with the truncation trick. Please
refer to the usage of `truncation` and `truncation_latent`.
Args:
styles (torch.Tensor | list[torch.Tensor] | callable | None): In
StyleGAN2, you can provide a noise tensor or a latent tensor. Given
a list containing more than one noise or latent tensor, the style
mixing trick will be used in training. You can also directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, ``None``
indicates using the default noise sampler.
num_batches (int, optional): The batch size of sampled data.
Defaults to -1.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
inject_index (int | None, optional): The index number for mixing
style codes. Defaults to None.
truncation (float, optional): Truncation factor. If a value less
than 1. is given, the truncation trick will be adopted. Defaults to 1.
truncation_latent (torch.Tensor, optional): Mean truncation latent.
Defaults to None.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
injected_noise (list[torch.Tensor] | None, optional): Given a list
of tensors, the random noise will be fixed as these input injected
noises. Defaults to None.
randomize_noise (bool, optional): If `False`, images are sampled
with the buffered noise tensor injected to the style conv
block. Defaults to True.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# receive noise and conduct sanity check.
if isinstance(styles, torch.Tensor):
assert styles.shape[1] == self.style_channels
styles = [styles]
elif mmcv.is_seq_of(styles, torch.Tensor):
for t in styles:
assert t.shape[-1] == self.style_channels
# receive a noise generator and sample noise.
elif callable(styles):
device = get_module_device(self)
noise_generator = styles
assert num_batches > 0
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
noise_generator((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [noise_generator((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
# otherwise, we will adopt default noise sampler.
else:
device = get_module_device(self)
assert num_batches > 0 and not input_is_latent
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
torch.randn((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [torch.randn((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
if not input_is_latent:
noise_batch = styles
styles = [self.style_mapping(s) for s in styles]
else:
noise_batch = None
if injected_noise is None:
if randomize_noise:
injected_noise = [None] * self.num_injected_noises
else:
injected_noise = [
getattr(self, f'injected_noise_{i}')
for i in range(self.num_injected_noises)
]
# use truncation trick
if truncation < 1:
style_t = []
# calculate truncation latent on the fly
if truncation_latent is None and not hasattr(
self, 'truncation_latent'):
self.truncation_latent = self.get_mean_latent()
truncation_latent = self.truncation_latent
elif truncation_latent is None and hasattr(self,
'truncation_latent'):
truncation_latent = self.truncation_latent
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
# no style mixing
if len(styles) < 2:
inject_index = self.num_latents
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
# style mixing
else:
if inject_index is None:
inject_index = random.randint(1, self.num_latents - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.num_latents - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
# 4x4 stage
out = self.constant_input(latent)
out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
skip = self.to_rgb1(out, latent[:, 1])
_index = 1
# 8x8 ---> higher resolutions
for up_conv, conv, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], injected_noise[1::2],
injected_noise[2::2], self.to_rgbs):
out = up_conv(out, latent[:, _index], noise=noise1)
out = conv(out, latent[:, _index + 1], noise=noise2)
skip = to_rgb(out, latent[:, _index + 2], skip)
_index += 2
# make sure the output image is torch.float32 to avoid a RuntimeError
# in other modules
img = skip.to(torch.float32)
if return_latents or return_noise:
output_dict = dict(
fake_img=img,
latent=latent,
inject_index=inject_index,
noise_batch=noise_batch)
return output_dict
return img
@MODULES.register_module()
class StyleGAN2Discriminator(nn.Module):
"""StyleGAN2 Discriminator.
The architecture of this discriminator is proposed in StyleGAN2. More
details can be found in: Analyzing and Improving the Image Quality of
StyleGAN CVPR2020.
You can load pretrained model through passing information into
``pretrained`` argument. We have already offered official weights as
follows:
- stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
- stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
- stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
- stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
- stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa
If you want to load the EMA model, you can use the following code:
.. code-block:: python
# ckpt_http is one of the valid paths from an http source
discriminator = StyleGAN2Discriminator(1024,
pretrained=dict(
ckpt_path=ckpt_http,
prefix='discriminator'))
Of course, you can also download the checkpoint in advance and set
``ckpt_path`` with local path.
Note that our implementation adopts BGR images as input, while the
original StyleGAN2 provides RGB images to the discriminator. Thus, we
provide the ``input_bgr2rgb`` argument to convert the image space. If
your images follow the RGB order, please set it to ``True`` accordingly.
Args:
in_size (int): The input size of images.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
blur_kernel (list, optional): The blur kernel. Defaults
to [1, 3, 3, 1].
mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
Defaults to dict(group_size=4, channel_groups=1).
num_fp16_scales (int, optional): The number of resolutions to use auto
fp16 training. Defaults to 0.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
out_fp32 (bool, optional): Whether to convert the output feature map to
`torch.float32`. Defaults to `True`.
convert_input_fp32 (bool, optional): Whether to convert input type to
fp32 if not `fp16_enabled`. This argument is designed to deal with
the cases where some modules are run in FP16 and others in FP32.
Defaults to True.
input_bgr2rgb (bool, optional): Whether to reformat the input channels
to the `rgb` order. Several of the converted weights we provide
expect inputs in the `rgb` order, so you can set this argument to
True if you want to finetune on converted weights. Defaults to False.
pretrained (dict | None, optional): Information for pretrained models.
The necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the discriminator part from the whole state dict.
Defaults to None.
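A minimal scoring sketch (illustrative only):
.. code-block:: python
import torch
disc = StyleGAN2Discriminator(in_size=256)
img = torch.randn(2, 3, 256, 256)
score = disc(img)  # tensor of shape (2, 1)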
"""
def __init__(self,
in_size,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
mbstd_cfg=dict(group_size=4, channel_groups=1),
num_fp16_scales=0,
fp16_enabled=False,
out_fp32=True,
convert_input_fp32=True,
input_bgr2rgb=False,
pretrained=None):
super().__init__()
self.num_fp16_scale = num_fp16_scales
self.fp16_enabled = fp16_enabled
self.convert_input_fp32 = convert_input_fp32
self.out_fp32 = out_fp32
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
log_size = int(np.log2(in_size))
in_channels = channels[in_size]
_use_fp16 = num_fp16_scales > 0
convs = [
ConvDownLayer(3, channels[in_size], 1, fp16_enabled=_use_fp16)
]
for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]
# add fp16 training for higher resolutions
_use_fp16 = (log_size - i) < num_fp16_scales or fp16_enabled
convs.append(
ResBlock(
in_channels,
out_channel,
blur_kernel,
fp16_enabled=_use_fp16,
convert_input_fp32=convert_input_fp32))
in_channels = out_channel
self.convs = nn.Sequential(*convs)
self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinearActModule(
channels[4] * 4 * 4,
channels[4],
act_cfg=dict(type='fused_bias')),
EqualLinearActModule(channels[4], 1),
)
self.input_bgr2rgb = input_bgr2rgb
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
@auto_fp16()
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Input image tensor.
Returns:
torch.Tensor: Predict score for the input image.
"""
# This setting was used to finetune on converted weights
if self.input_bgr2rgb:
x = x[:, [2, 1, 0], ...]
x = self.convs(x)
x = self.mbstd_layer(x)
if not self.final_conv.fp16_enabled and self.convert_input_fp32:
x = x.to(torch.float32)
x = self.final_conv(x)
x = x.view(x.shape[0], -1)
x = self.final_linear(x)
return x
@MODULES.register_module()
class ADAStyleGAN2Discriminator(StyleGAN2Discriminator):
def __init__(self, in_size, *args, data_aug=None, **kwargs):
"""StyleGANv2 Discriminator with adaptive augmentation.
Args:
in_size (int): The input size of images.
data_aug (dict, optional): Config for data
augmentation. Defaults to None.
"""
super().__init__(in_size, *args, **kwargs)
self.with_ada = data_aug is not None
if self.with_ada:
self.ada_aug = build_module(data_aug)
self.ada_aug.requires_grad = False
self.log_size = int(np.log2(in_size))
def forward(self, x):
"""Forward function."""
if self.with_ada:
x = self.ada_aug.aug_pipeline(x)
return super().forward(x)
@MODULES.register_module()
class ADAAug(nn.Module):
"""Data Augmentation Module for Adaptive Discriminator augmentation.
Args:
aug_pipeline (dict, optional): Config for augmentation pipeline.
Defaults to None.
update_interval (int, optional): Interval for updating
augmentation probability. Defaults to 4.
augment_initial_p (float, optional): Initial augmentation
probability. Defaults to 0..
ada_target (float, optional): ADA target. Defaults to 0.6.
ada_kimg (int, optional): ADA training duration. Defaults to 500.
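A hedged sketch of the update loop (illustrative only; the trainer is
assumed to accumulate the count and the sum of ``sign(D(real))`` into
``log_buffer`` before each call, and the ``AugmentPipe`` kwargs below
are just example augmentations):
.. code-block:: python
import torch
aug = ADAAug(aug_pipeline=dict(xflip=1, rotate90=1, xint=1))
disc_real_logits = torch.randn(16)  # stands in for D(real) outputs
aug.log_buffer[0] += disc_real_logits.numel()
aug.log_buffer[1] += disc_real_logits.sign().sum()
# every `update_interval` iterations, the augment probability moves
# towards `ada_target`
aug.update(iteration=3, num_batches=16)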
"""
def __init__(self,
aug_pipeline=None,
update_interval=4,
augment_initial_p=0.,
ada_target=0.6,
ada_kimg=500):
super().__init__()
self.aug_pipeline = AugmentPipe(**aug_pipeline)
self.update_interval = update_interval
self.ada_kimg = ada_kimg
self.ada_target = ada_target
self.aug_pipeline.p.copy_(torch.tensor(augment_initial_p))
# this log buffer stores two numbers: num_scalars, sum_scalars.
self.register_buffer('log_buffer', torch.zeros((2, )))
def update(self, iteration=0, num_batches=0):
"""Update Augment probability.
Args:
iteration (int, optional): Training iteration.
Defaults to 0.
num_batches (int, optional): The number of reals batches.
Defaults to 0.
"""
if (iteration + 1) % self.update_interval == 0:
adjust_step = float(num_batches * self.update_interval) / float(
self.ada_kimg * 1000.)
# get the mean value as the ada heuristic
ada_heuristic = self.log_buffer[1] / self.log_buffer[0]
adjust = np.sign(ada_heuristic.item() -
self.ada_target) * adjust_step
# update the augment p
# Note that p may be bigger than 1.0
self.aug_pipeline.p.copy_((self.aug_pipeline.p + adjust).max(
constant(0, device=self.log_buffer.device)))
self.log_buffer = self.log_buffer * 0.
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import mmcv
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .utils import get_mean_latent
@MODULES.register_module()
class StyleGANv3Generator(nn.Module):
"""StyleGAN3 Generator.
In StyleGAN3, we make several changes to the StyleGAN2 generator,
including transformed Fourier features, filtered nonlinearities and
non-critical sampling, etc. More details can be found in: Alias-Free
Generative Adversarial Networks, NeurIPS 2021.
Ref: https://github.com/NVlabs/stylegan3
Args:
out_size (int): The output size of the StyleGAN3 generator.
style_channels (int): The number of channels for style code.
img_channels (int): The number of output's channels.
noise_size (int, optional): Size of the input noise vector.
Defaults to 512.
rgb2bgr (bool, optional): Whether to reformat the output channels
to the `bgr` order. We provide several pre-trained StyleGAN3
weights whose output channel order is `rgb`. You can set this
argument to True to use those weights. Defaults to False.
pretrained (str | dict, optional): Path for the pretrained model or
dict containing information for pretrained models whose necessary
key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
the generator part from the whole state dict. Defaults to None.
synthesis_cfg (dict, optional): Config for synthesis network. Defaults
to dict(type='SynthesisNetwork').
mapping_cfg (dict, optional): Config for mapping network. Defaults to
dict(type='MappingNetwork').
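A minimal sampling sketch (illustrative settings only):
.. code-block:: python
generator = StyleGANv3Generator(out_size=256, style_channels=512, img_channels=3)
fake_imgs = generator(None, num_batches=2, truncation=0.7)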
"""
def __init__(self,
out_size,
style_channels,
img_channels,
noise_size=512,
rgb2bgr=False,
pretrained=None,
synthesis_cfg=dict(type='SynthesisNetwork'),
mapping_cfg=dict(type='MappingNetwork')):
super().__init__()
self.noise_size = noise_size
self.style_channels = style_channels
self.out_size = out_size
self.img_channels = img_channels
self.rgb2bgr = rgb2bgr
self._synthesis_cfg = deepcopy(synthesis_cfg)
self._synthesis_cfg.setdefault('style_channels', style_channels)
self._synthesis_cfg.setdefault('out_size', out_size)
self._synthesis_cfg.setdefault('img_channels', img_channels)
self.synthesis = build_module(self._synthesis_cfg)
self.num_ws = self.synthesis.num_ws
self._mapping_cfg = deepcopy(mapping_cfg)
self._mapping_cfg.setdefault('noise_size', noise_size)
self._mapping_cfg.setdefault('style_channels', style_channels)
self._mapping_cfg.setdefault('num_ws', self.num_ws)
self.style_mapping = build_module(self._mapping_cfg)
if pretrained is not None:
self._load_pretrained_model(**pretrained)
def _load_pretrained_model(self,
ckpt_path,
prefix='',
map_location='cpu',
strict=True):
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def forward(self,
noise,
num_batches=0,
input_is_latent=False,
truncation=1,
num_truncation_layer=None,
update_emas=False,
force_fp32=True,
return_noise=False,
return_latents=False):
"""Forward Function for stylegan3.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, ``None``
indicates using the default noise sampler.
num_batches (int, optional): The batch size of sampled data.
Defaults to 0.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
truncation (float, optional): Truncation factor. If a value less
than 1. is given, the truncation trick will be adopted. Defaults to 1.
num_truncation_layer (int, optional): Number of layers use
truncated latent. Defaults to None.
update_emas (bool, optional): Whether to update the moving average
of the mean latent. Defaults to False.
force_fp32 (bool, optional): Whether to force fp32 inference,
ignoring the weights' dtype. Defaults to True.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# if input is latent, set noise size as the style_channels
noise_size = (
self.style_channels if input_is_latent else self.noise_size)
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, noise_size))
device = get_module_device(self)
noise_batch = noise_batch.to(device)
if input_is_latent:
ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
else:
ws = self.style_mapping(
noise_batch,
truncation=truncation,
num_truncation_layer=num_truncation_layer,
update_emas=update_emas)
out_img = self.synthesis(
ws, update_emas=update_emas, force_fp32=force_fp32)
if self.rgb2bgr:
out_img = out_img[:, [2, 1, 0], ...]
if return_noise or return_latents:
output = dict(fake_img=out_img, noise_batch=noise_batch, latent=ws)
return output
return out_img
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
if hasattr(self.style_mapping, 'w_avg'):
return self.style_mapping.w_avg
return get_mean_latent(self, num_samples, **kwargs)
def get_training_kwargs(self, phase):
"""Get training kwargs. In StyleGANv3, we enable fp16, and update
mangitude ema during training of discriminator. This function is used
to pass related arguments.
Args:
phase (str): Current training phase.
Returns:
dict: Training kwargs.
"""
if phase == 'disc':
return dict(update_emas=True, force_fp32=False)
if phase == 'gen':
return dict(force_fp32=False)
return {}
# Copyright (c) OpenMMLab. All rights reserved.
from .styleganv2_modules import (Blur, ConstantInput, ModulatedConv2d,
ModulatedStyleConv, ModulatedToRGB,
NoiseInjection)
from .styleganv3_modules import (MappingNetwork, SynthesisInput,
SynthesisLayer, SynthesisNetwork)
__all__ = [
'Blur', 'ModulatedStyleConv', 'ModulatedToRGB', 'NoiseInjection',
'ConstantInput', 'MappingNetwork', 'SynthesisInput', 'SynthesisLayer',
'SynthesisNetwork', 'ModulatedConv2d'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmgen.models.architectures.pggan import (EqualizedLRConvModule,
EqualizedLRConvUpModule,
EqualizedLRLinearModule)
from mmgen.models.architectures.stylegan.modules import (Blur, ConstantInput,
NoiseInjection)
class AdaptiveInstanceNorm(nn.Module):
r"""Adaptive Instance Normalization Module.
Ref: https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py # noqa
Args:
in_channel (int): The number of input's channel.
style_dim (int): Style latent dimension.
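A small shape sketch (illustrative only): the affine layer maps the
style code to per-channel ``gamma`` and ``beta``, which rescale the
instance-normalized feature map.
.. code-block:: python
import torch
adain = AdaptiveInstanceNorm(in_channel=64, style_dim=512)
feat = torch.randn(2, 64, 8, 8)
style = torch.randn(2, 512)
out = adain(feat, style)  # same shape as ``feat``: (2, 64, 8, 8)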
"""
def __init__(self, in_channel, style_dim):
super().__init__()
self.norm = nn.InstanceNorm2d(in_channel)
self.affine = EqualizedLRLinearModule(style_dim, in_channel * 2)
self.affine.bias.data[:in_channel] = 1
self.affine.bias.data[in_channel:] = 0
def forward(self, input, style):
"""Forward function.
Args:
input (Tensor): Input tensor with shape (n, c, h, w).
style (Tensor): Input style tensor with shape (n, c).
Returns:
Tensor: Forward results.
"""
style = self.affine(style).unsqueeze(2).unsqueeze(3)
gamma, beta = style.chunk(2, 1)
out = self.norm(input)
out = gamma * out + beta
return out
class StyleConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
padding=1,
initial=False,
blur_kernel=[1, 2, 1],
upsample=False,
fused=False):
"""Convolutional style blocks composing of noise injector, AdaIN module
and convolution layers.
Args:
in_channels (int): The channel number of the input tensor.
out_channels (itn): The channel number of the output tensor.
kernel_size (int): The kernel size of convolution layers.
style_channels (int): The number of channels for style code.
padding (int, optional): Padding of convolution layers.
Defaults to 1.
initial (bool, optional): Whether this is the first StyleConv of
StyleGAN's generator. Defaults to False.
blur_kernel (list, optional): The blurry kernel.
Defaults to [1, 2, 1].
upsample (bool, optional): Whether perform upsampling.
Defaults to False.
fused (bool, optional): Whether use fused upconv.
Defaults to False.
"""
super().__init__()
if initial:
self.conv1 = ConstantInput(in_channels)
else:
if upsample:
if fused:
self.conv1 = nn.Sequential(
EqualizedLRConvUpModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=dict(type='LeakyReLU',
negative_slope=0.2)),
Blur(blur_kernel, pad=(1, 1)),
)
else:
self.conv1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None), Blur(blur_kernel, pad=(1, 1)))
else:
self.conv1 = EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None)
self.noise_injector1 = NoiseInjection()
self.activate1 = nn.LeakyReLU(0.2)
self.adain1 = AdaptiveInstanceNorm(out_channels, style_channels)
self.conv2 = EqualizedLRConvModule(
out_channels,
out_channels,
kernel_size,
padding=padding,
act_cfg=None)
self.noise_injector2 = NoiseInjection()
self.activate2 = nn.LeakyReLU(0.2)
self.adain2 = AdaptiveInstanceNorm(out_channels, style_channels)
def forward(self,
x,
style1,
style2,
noise1=None,
noise2=None,
return_noise=False):
"""Forward function.
Args:
x (Tensor): Input tensor.
style1 (Tensor): Input style tensor with shape (n, c).
style2 (Tensor): Input style tensor with shape (n, c).
noise1 (Tensor, optional): Noise tensor with shape (n, c, h, w).
Defaults to None.
noise2 (Tensor, optional): Noise tensor with shape (n, c, h, w).
Defaults to None.
return_noise (bool, optional): If True, ``noise1`` and ``noise2``
will be returned with ``out``. Defaults to False.
Returns:
Tensor | tuple[Tensor]: Forward results.
"""
out = self.conv1(x)
if return_noise:
out, noise1 = self.noise_injector1(
out, noise=noise1, return_noise=return_noise)
else:
out = self.noise_injector1(
out, noise=noise1, return_noise=return_noise)
out = self.activate1(out)
out = self.adain1(out, style1)
out = self.conv2(out)
if return_noise:
out, noise2 = self.noise_injector2(
out, noise=noise2, return_noise=return_noise)
else:
out = self.noise_injector2(
out, noise=noise2, return_noise=return_noise)
out = self.activate2(out)
out = self.adain2(out, style2)
if return_noise:
return out, noise1, noise2
return out
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.activation import build_activation_layer
from mmcv.ops.fused_bias_leakyrelu import (FusedBiasLeakyReLU,
fused_bias_leakyrelu)
from mmcv.ops.upfirdn2d import upfirdn2d
from mmcv.runner.dist_utils import get_dist_info
from mmgen.core.runners.fp16_utils import auto_fp16
from mmgen.models.architectures.pggan import (EqualizedLRConvModule,
EqualizedLRLinearModule,
equalized_lr)
from mmgen.models.common import AllGatherLayer
from mmgen.ops import conv2d, conv_transpose2d
class _FusedBiasLeakyReLU(FusedBiasLeakyReLU):
"""Wrap FusedBiasLeakyReLU to support FP16 training."""
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, ...).
Returns:
Tensor: Output feature map.
"""
return fused_bias_leakyrelu(x, self.bias.to(x.dtype),
self.negative_slope, self.scale)
class EqualLinearActModule(nn.Module):
"""Equalized LR Linear Module with Activation Layer.
This module is modified from ``EqualizedLRLinearModule`` defined in PGGAN.
The major feature updated in this module is the added support for the
activation layers used in StyleGAN2.
Args:
equalized_lr_cfg (dict | None, optional): Config for equalized lr.
Defaults to dict(gain=1., lr_mul=1.).
bias (bool, optional): Whether to use bias item. Defaults to True.
bias_init (float, optional): The value for bias initialization.
Defaults to ``0.``.
act_cfg (dict | None, optional): Config for activation layer.
Defaults to None.
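A minimal usage sketch (illustrative; the ``fused_bias`` activation
relies on the mmcv fused op and may require a GPU):
.. code-block:: python
import torch
layer = EqualLinearActModule(512, 512, act_cfg=dict(type='fused_bias'))
w = layer(torch.randn(4, 512))  # tensor of shape (4, 512)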
"""
def __init__(self,
*args,
equalized_lr_cfg=dict(gain=1., lr_mul=1.),
bias=True,
bias_init=0.,
act_cfg=None,
**kwargs):
super().__init__()
self.with_activation = act_cfg is not None
# w/o bias in linear layer
self.linear = EqualizedLRLinearModule(
*args, bias=False, equalized_lr_cfg=equalized_lr_cfg, **kwargs)
if equalized_lr_cfg is not None:
self.lr_mul = equalized_lr_cfg.get('lr_mul', 1.)
else:
self.lr_mul = 1.
# define bias outside linear layer
if bias:
self.bias = nn.Parameter(
torch.zeros(self.linear.out_features).fill_(bias_init))
else:
self.bias = None
if self.with_activation:
act_cfg = deepcopy(act_cfg)
if act_cfg['type'] == 'fused_bias':
self.act_type = act_cfg.pop('type')
assert self.bias is not None
self.activate = partial(fused_bias_leakyrelu, **act_cfg)
else:
self.act_type = 'normal'
self.activate = build_activation_layer(act_cfg)
else:
self.act_type = None
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, ...).
Returns:
Tensor: Output feature map.
"""
if x.ndim >= 3:
x = x.reshape(x.size(0), -1)
x = self.linear(x)
if self.with_activation and self.act_type == 'fused_bias':
x = self.activate(x, self.bias * self.lr_mul)
elif self.bias is not None and self.with_activation:
x = self.activate(x + self.bias * self.lr_mul)
elif self.bias is not None:
x = x + self.bias * self.lr_mul
elif self.with_activation:
x = self.activate(x)
return x
def _make_kernel(k):
"""Make a normalized 2D FIR kernel from a 1D or 2D kernel array.
A 1D input is expanded into its 2D outer product before normalization.
Args:
k (Array): 1D or 2D blur kernel/filter.
Returns:
Tensor: Normalized 2D kernel.
"""
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class UpsampleUpFIRDn(nn.Module):
"""UpFIRDn for Upsampling.
This module is used in the ``to_rgb`` layers in StyleGAN2 for upsampling
the images.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
factor (int, optional): Upsampling factor. Defaults to 2.
"""
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = _make_kernel(kernel) * (factor**2)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
out = upfirdn2d(
x, self.kernel.to(x.dtype), up=self.factor, down=1, pad=self.pad)
return out
class DownsampleUpFIRDn(nn.Module):
"""UpFIRDn for Downsampling.
This module is mentioned in StyleGAN2 for downsampling the feature maps.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
factor (int, optional): Downsampling factor. Defaults to 2.
"""
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = _make_kernel(kernel)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
"""Forward function.
Args:
input (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
out = upfirdn2d(
input,
self.kernel.to(input.dtype),
up=1,
down=self.factor,
pad=self.pad)
return out
class Blur(nn.Module):
"""Blur module.
This module is applied right after the upsampling operation in StyleGAN2.
Args:
kernel (Array): Blur kernel/filter used in UpFIRDn.
pad (list[int]): Padding for features.
upsample_factor (int, optional): Upsampling factor. Defaults to 1.
"""
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = _make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor**2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
# In Tero's implementation, he uses fp32
return upfirdn2d(x, self.kernel.to(x.dtype), pad=self.pad)
class ModulatedConv2d(nn.Module):
r"""Modulated Conv2d in StyleGANv2.
This module implements the modulated convolution layers proposed in
StyleGAN2. Details can be found in Analyzing and Improving the Image
Quality of StyleGAN, CVPR2020.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blur kernel.
Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to 0..
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
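A minimal shape sketch (illustrative only): each sample's convolution
weight is scaled by its style code and demodulated before a grouped
convolution is applied.
.. code-block:: python
import torch
conv = ModulatedConv2d(64, 128, kernel_size=3, style_channels=512)
feat = torch.randn(2, 64, 16, 16)
style = torch.randn(2, 512)
out = conv(feat, style)  # tensor of shape (2, 128, 16, 16)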
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
padding=None, # self-defined padding
eps=1e-8):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.style_channels = style_channels
self.demodulate = demodulate
# sanity check for kernel size
assert isinstance(self.kernel_size,
int) and (self.kernel_size >= 1
and self.kernel_size % 2 == 1)
self.upsample = upsample
self.downsample = downsample
self.style_bias = style_bias
self.eps = eps
# build style modulation module
style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg
self.style_modulation = EqualLinearActModule(style_channels,
in_channels,
**style_mod_cfg)
# set lr_mul for conv weight
lr_mul_ = 1.
if equalized_lr_cfg is not None:
lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, kernel_size,
kernel_size).div_(lr_mul_))
# build blurry layer for upsampling
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)
# build blurry layer for downsampling
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
# add equalized_lr hook for conv weight
if equalized_lr_cfg is not None:
equalized_lr(self, **equalized_lr_cfg)
self.padding = padding if padding is not None else (kernel_size // 2)
def forward(self, x, style, input_gain=None):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
style (Tensor): Style codes with shape of (N, style_channels).
input_gain (Tensor | None, optional): Per-sample gain applied to
the modulated weight. Defaults to None.
Returns:
Tensor: Output feature map.
"""
n, c, h, w = x.shape
weight = self.weight
# Pre-normalize inputs to avoid FP16 overflow.
if x.dtype == torch.float16 and self.demodulate:
weight = weight * (
1 / np.sqrt(
self.in_channels * self.kernel_size * self.kernel_size) /
weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
) # max_Ikk
style = style / style.norm(
float('inf'), dim=1, keepdim=True) # max_I
# process style code
style = self.style_modulation(style).view(n, 1, c, 1,
1) + self.style_bias
# combine weight and style
weight = weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
if input_gain is not None:
# input_gain shape [batch, in_ch]
input_gain = input_gain.expand(n, self.in_channels)
# weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)
weight = weight.view(n * self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.to(x.dtype)
if self.upsample:
x = x.reshape(1, n * c, h, w)
weight = weight.view(n, self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
self.kernel_size,
self.kernel_size)
x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
x = x.reshape(n, self.out_channels, *x.shape[-2:])
x = self.blur(x)
elif self.downsample:
x = self.blur(x)
x = x.view(1, n * self.in_channels, *x.shape[-2:])
x = conv2d(x, weight, stride=2, padding=0, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
else:
x = x.reshape(1, n * c, h, w)
x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
return x
class NoiseInjection(nn.Module):
"""Noise Injection Module.
In StyleGAN2, this module is adopted to inject a spatial random noise
map into the generator.
Args:
noise_weight_init (float, optional): Initialization weight for noise
injection. Defaults to ``0.``.
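A minimal usage sketch (illustrative only):
.. code-block:: python
import torch
noise_injector = NoiseInjection()
feat = torch.randn(2, 64, 32, 32)
out, noise = noise_injector(feat, return_noise=True)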
"""
def __init__(self, noise_weight_init=0.):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1).fill_(noise_weight_init))
def forward(self, image, noise=None, return_noise=False):
"""Forward Function.
Args:
image (Tensor): Spatial features with a shape of (N, C, H, W).
noise (Tensor, optional): Noises from the outside.
Defaults to None.
return_noise (bool, optional): Whether to return noise tensor.
Defaults to False.
Returns:
Tensor: Output features.
"""
if noise is None:
batch, _, height, width = image.shape
noise = image.new_empty(batch, 1, height, width).normal_()
noise = noise.to(image.dtype)
if return_noise:
return image + self.weight.to(image.dtype) * noise, noise
return image + self.weight.to(image.dtype) * noise
class ConstantInput(nn.Module):
"""Constant Input.
    In StyleGAN2, the noise input at the head of the generator is substituted
    with this constant input module.
Args:
channel (int): Channels for the constant input tensor.
size (int, optional): Spatial size for the constant input.
Defaults to 4.
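
    Example:
        A minimal usage sketch; only the batch dimension of the input is
        used, and the shapes below are chosen for illustration only:

        >>> const_input = ConstantInput(channel=16, size=4)
        >>> x = torch.randn(2, 3)
        >>> const_input(x).shape
        torch.Size([2, 16, 4, 4])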
"""
def __init__(self, channel, size=4):
super().__init__()
if isinstance(size, int):
size = [size, size]
elif mmcv.is_seq_of(size, int):
assert len(
size
) == 2, f'The length of size should be 2 but got {len(size)}'
else:
raise ValueError(f'Got invalid value in size, {size}')
self.input = nn.Parameter(torch.randn(1, channel, *size))
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, ...).
Returns:
Tensor: Output feature map.
"""
batch = x.shape[0]
out = self.input.repeat(batch, 1, 1, 1)
return out
class ModulatedPEConv2d(nn.Module):
r"""Modulated Conv2d in StyleGANv2 with Positional Encoding (PE).
This module is modified from the ``ModulatedConv2d`` in StyleGAN2 to
support the experiments in: Positional Encoding as Spatial Inductive Bias
in GANs, CVPR'2021.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to 0..
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
        no_pad (bool, optional): Whether to remove the padding in
            convolution. Defaults to False.
deconv2conv (bool, optional): Whether to substitute the transposed conv
with (conv2d, upsampling). Defaults to False.
interp_pad (int | None, optional): The padding number of interpolation
pad. Defaults to None.
up_config (dict, optional): Upsampling config.
Defaults to dict(scale_factor=2, mode='nearest').
up_after_conv (bool, optional): Whether to adopt upsampling after
convolution. Defaults to False.
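
    Example:
        A minimal usage sketch of the vanilla (no upsampling or downsampling)
        path; the shapes below are chosen for illustration only:

        >>> conv = ModulatedPEConv2d(8, 16, 3, style_channels=32)
        >>> x = torch.randn(2, 8, 16, 16)
        >>> style = torch.randn(2, 32)
        >>> conv(x, style).shape
        torch.Size([2, 16, 16, 16])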
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
eps=1e-8,
no_pad=False,
deconv2conv=False,
interp_pad=None,
up_config=dict(scale_factor=2, mode='nearest'),
up_after_conv=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.style_channels = style_channels
self.demodulate = demodulate
# sanity check for kernel size
assert isinstance(self.kernel_size,
int) and (self.kernel_size >= 1
and self.kernel_size % 2 == 1)
self.upsample = upsample
self.downsample = downsample
self.style_bias = style_bias
self.eps = eps
self.no_pad = no_pad
self.deconv2conv = deconv2conv
self.interp_pad = interp_pad
self.with_interp_pad = interp_pad is not None
self.up_config = deepcopy(up_config)
self.up_after_conv = up_after_conv
# build style modulation module
style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg
self.style_modulation = EqualLinearActModule(style_channels,
in_channels,
**style_mod_cfg)
# set lr_mul for conv weight
lr_mul_ = 1.
if equalized_lr_cfg is not None:
lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, kernel_size,
kernel_size).div_(lr_mul_))
# build blurry layer for upsampling
if upsample and not self.deconv2conv:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)
# build blurry layer for downsampling
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
# add equalized_lr hook for conv weight
if equalized_lr_cfg is not None:
equalized_lr(self, **equalized_lr_cfg)
# if `no_pad`, remove all of the padding in conv
self.padding = kernel_size // 2 if not no_pad else 0
def forward(self, x, style):
"""Forward function.
Args:
            x (Tensor): Input features with shape of (N, C, H, W).
style (Tensor): Style latent with shape of (N, C).
Returns:
Tensor: Output feature with shape of (N, C, H, W).
"""
n, c, h, w = x.shape
# process style code
style = self.style_modulation(style).view(n, 1, c, 1,
1) + self.style_bias
# combine weight and style
weight = self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
weight = weight.view(n * self.out_channels, c, self.kernel_size,
self.kernel_size)
if self.upsample and not self.deconv2conv:
x = x.reshape(1, n * c, h, w)
weight = weight.view(n, self.out_channels, c, self.kernel_size,
self.kernel_size)
weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
self.kernel_size,
self.kernel_size)
x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
x = x.reshape(n, self.out_channels, *x.shape[-2:])
x = self.blur(x)
elif self.upsample and self.deconv2conv:
if self.up_after_conv:
x = x.reshape(1, n * c, h, w)
x = conv2d(x, weight, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[2:4])
if self.with_interp_pad:
h_, w_ = x.shape[-2:]
up_cfg_ = deepcopy(self.up_config)
up_scale = up_cfg_.pop('scale_factor')
size_ = (h_ * up_scale + self.interp_pad,
w_ * up_scale + self.interp_pad)
x = F.interpolate(x, size=size_, **up_cfg_)
else:
x = F.interpolate(x, **self.up_config)
if not self.up_after_conv:
h_, w_ = x.shape[-2:]
x = x.view(1, n * c, h_, w_)
x = conv2d(x, weight, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[2:4])
elif self.downsample:
x = self.blur(x)
x = x.view(1, n * self.in_channels, *x.shape[-2:])
x = conv2d(x, weight, stride=2, padding=0, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
else:
x = x.view(1, n * c, h, w)
x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
x = x.view(n, self.out_channels, *x.shape[-2:])
return x
class ModulatedStyleConv(nn.Module):
"""Modulated Style Convolution.
    In this module, the modulated conv2d, noise injector and activation
    layers are integrated together.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to ``0.``.
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
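
    Example:
        A minimal usage sketch; the shapes below are chosen for illustration
        only:

        >>> conv = ModulatedStyleConv(
        ...     8, 16, 3, style_channels=32, upsample=True)
        >>> x = torch.randn(2, 8, 4, 4)
        >>> style = torch.randn(2, 32)
        >>> conv(x, style).shape
        torch.Size([2, 16, 8, 8])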
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
fp16_enabled=False,
conv_clamp=256):
super().__init__()
# add support for fp16
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
self.conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=demodulate,
upsample=upsample,
blur_kernel=blur_kernel,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias)
self.noise_injector = NoiseInjection()
self.activate = _FusedBiasLeakyReLU(out_channels)
# if self.fp16_enabled:
# self.half()
@auto_fp16(apply_to=('x', 'noise'))
def forward(self, x, style, noise=None, return_noise=False):
"""Forward Function.
Args:
            x (Tensor): Input features with shape of (N, C, H, W).
style (Tensor): Style latent with shape of (N, C).
noise (Tensor, optional): Noise for injection. Defaults to None.
return_noise (bool, optional): Whether to return noise tensors.
Defaults to False.
Returns:
Tensor: Output features with shape of (N, C, H, W)
"""
out = self.conv(x, style)
if return_noise:
out, noise = self.noise_injector(
out, noise=noise, return_noise=return_noise)
else:
out = self.noise_injector(
out, noise=noise, return_noise=return_noise)
# TODO: FP16 in activate layers
out = self.activate(out)
if self.fp16_enabled:
out = torch.clamp(out, min=-self.conv_clamp, max=self.conv_clamp)
if return_noise:
return out, noise
return out
class ModulatedPEStyleConv(nn.Module):
"""Modulated Style Convolution with Positional Encoding.
This module is modified from the ``ModulatedStyleConv`` in StyleGAN2 to
support the experiments in: Positional Encoding as Spatial Inductive Bias
in GANs, CVPR'2021.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
style_channels (int): Channels for the style codes.
demodulate (bool, optional): Whether to adopt demodulation.
Defaults to True.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to 0..
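
    Example:
        A minimal sketch of the ``deconv2conv`` upsampling path (conv2d
        followed by interpolation); the shapes below are chosen for
        illustration only:

        >>> conv = ModulatedPEStyleConv(
        ...     8, 16, 3, style_channels=32, upsample=True,
        ...     deconv2conv=True, up_after_conv=True)
        >>> x = torch.randn(2, 8, 8, 8)
        >>> style = torch.randn(2, 32)
        >>> conv(x, style).shape
        torch.Size([2, 16, 16, 16])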
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
style_channels,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
**kwargs):
super().__init__()
self.conv = ModulatedPEConv2d(
in_channels,
out_channels,
kernel_size,
style_channels,
demodulate=demodulate,
upsample=upsample,
blur_kernel=blur_kernel,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias,
**kwargs)
self.noise_injector = NoiseInjection()
self.activate = _FusedBiasLeakyReLU(out_channels)
def forward(self, x, style, noise=None, return_noise=False):
"""Forward Function.
Args:
            x (Tensor): Input features with shape of (N, C, H, W).
style (Tensor): Style latent with shape of (N, C).
noise (Tensor, optional): Noise for injection. Defaults to None.
return_noise (bool, optional): Whether to return noise tensors.
Defaults to False.
Returns:
Tensor: Output features with shape of (N, C, H, W)
"""
out = self.conv(x, style)
if return_noise:
out, noise = self.noise_injector(
out, noise=noise, return_noise=return_noise)
else:
out = self.noise_injector(
out, noise=noise, return_noise=return_noise)
out = self.activate(out)
if return_noise:
return out, noise
return out
class ModulatedToRGB(nn.Module):
"""To RGB layer.
    This module is designed to output the image tensor in StyleGAN2.
Args:
in_channels (int): Input channels.
style_channels (int): Channels for the style codes.
out_channels (int, optional): Output channels. Defaults to 3.
upsample (bool, optional): Whether to adopt upsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
style_mod_cfg (dict, optional): Configs for style modulation module.
Defaults to dict(bias_init=1.).
style_bias (float, optional): Bias value for style code.
Defaults to 0..
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
out_fp32 (bool, optional): Whether to convert the output feature map to
`torch.float32`. Defaults to `True`.
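
    Example:
        A minimal usage sketch without the skip branch; the shapes below are
        chosen for illustration only:

        >>> to_rgb = ModulatedToRGB(16, style_channels=32, upsample=False)
        >>> x = torch.randn(2, 16, 8, 8)
        >>> style = torch.randn(2, 32)
        >>> to_rgb(x, style).shape
        torch.Size([2, 3, 8, 8])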
"""
def __init__(self,
in_channels,
style_channels,
out_channels=3,
upsample=True,
blur_kernel=[1, 3, 3, 1],
style_mod_cfg=dict(bias_init=1.),
style_bias=0.,
fp16_enabled=False,
conv_clamp=256,
out_fp32=True):
super().__init__()
if upsample:
self.upsample = UpsampleUpFIRDn(blur_kernel)
# add support for fp16
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
self.conv = ModulatedConv2d(
in_channels,
out_channels=out_channels,
kernel_size=1,
style_channels=style_channels,
demodulate=False,
style_mod_cfg=style_mod_cfg,
style_bias=style_bias)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
        # enforce the output to be fp32 (following Tero's implementation)
self.out_fp32 = out_fp32
@auto_fp16(apply_to=('x', 'style'))
def forward(self, x, style, skip=None):
"""Forward Function.
Args:
            x (Tensor): Input features with shape of (N, C, H, W).
style (Tensor): Style latent with shape of (N, C).
skip (Tensor, optional): Tensor for skip link. Defaults to None.
Returns:
Tensor: Output features with shape of (N, C, H, W)
"""
out = self.conv(x, style)
out = out + self.bias.to(x.dtype)
if self.fp16_enabled:
out = torch.clamp(out, min=-self.conv_clamp, max=self.conv_clamp)
# Here, Tero adopts FP16 at `skip`.
if skip is not None:
skip = self.upsample(skip)
out = out + skip
return out
class ConvDownLayer(nn.Sequential):
"""Convolution and Downsampling layer.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
downsample (bool, optional): Whether to adopt downsampling in features.
Defaults to False.
blur_kernel (list[int], optional): Blurry kernel.
Defaults to [1, 3, 3, 1].
bias (bool, optional): Whether to use bias parameter. Defaults to True.
act_cfg (dict, optional): Activation configs.
Defaults to dict(type='fused_bias').
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
conv_clamp (float, optional): Clamp the convolutional layer results to
avoid gradient overflow. Defaults to `256.0`.
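
    Example:
        A minimal usage sketch; the shapes below are chosen for illustration
        only:

        >>> layer = ConvDownLayer(8, 16, 3, downsample=True)
        >>> x = torch.randn(2, 8, 16, 16)
        >>> layer(x).shape
        torch.Size([2, 16, 8, 8])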
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
act_cfg=dict(type='fused_bias'),
fp16_enabled=False,
conv_clamp=256.):
self.fp16_enabled = fp16_enabled
self.conv_clamp = float(conv_clamp)
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
self.with_fused_bias = act_cfg is not None and act_cfg.get(
'type') == 'fused_bias'
if self.with_fused_bias:
conv_act_cfg = None
else:
conv_act_cfg = act_cfg
layers.append(
EqualizedLRConvModule(
in_channels,
out_channels,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not self.with_fused_bias,
norm_cfg=None,
act_cfg=conv_act_cfg,
equalized_lr_cfg=dict(mode='fan_in', gain=1.)))
if self.with_fused_bias:
layers.append(_FusedBiasLeakyReLU(out_channels))
super(ConvDownLayer, self).__init__(*layers)
@auto_fp16(apply_to=('x', ))
def forward(self, x):
x = super().forward(x)
if self.fp16_enabled:
x = torch.clamp(x, min=-self.conv_clamp, max=self.conv_clamp)
return x
class ResBlock(nn.Module):
"""Residual block used in the discriminator of StyleGAN2.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
fp16_enabled (bool, optional): Whether to use fp16 training in this
module. Defaults to False.
convert_input_fp32 (bool, optional): Whether to convert input type to
fp32 if not `fp16_enabled`. This argument is designed to deal with
the cases where some modules are run in FP16 and others in FP32.
Defaults to True.
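
    Example:
        A minimal usage sketch; the residual path halves the spatial size,
        and the shapes below are chosen for illustration only:

        >>> block = ResBlock(8, 16)
        >>> x = torch.randn(2, 8, 32, 32)
        >>> block(x).shape
        torch.Size([2, 16, 16, 16])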
"""
def __init__(self,
in_channels,
out_channels,
blur_kernel=[1, 3, 3, 1],
fp16_enabled=False,
convert_input_fp32=True):
super().__init__()
self.fp16_enabled = fp16_enabled
self.convert_input_fp32 = convert_input_fp32
self.conv1 = ConvDownLayer(
in_channels, in_channels, 3, blur_kernel=blur_kernel)
self.conv2 = ConvDownLayer(
in_channels,
out_channels,
3,
downsample=True,
blur_kernel=blur_kernel)
self.skip = ConvDownLayer(
in_channels,
out_channels,
1,
downsample=True,
act_cfg=None,
bias=False,
blur_kernel=blur_kernel)
@auto_fp16()
def forward(self, input):
"""Forward function.
Args:
input (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map.
"""
# TODO: study whether this explicit datatype transfer will harm the
# apex training speed
if not self.fp16_enabled and self.convert_input_fp32:
input = input.to(torch.float32)
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / np.sqrt(2)
return out
class ModMBStddevLayer(nn.Module):
"""Modified MiniBatch Stddev Layer.
This layer is modified from ``MiniBatchStddevLayer`` used in PGGAN. In
StyleGAN2, the authors add a new feature, `channel_groups`, into this
layer.
    Note that to accelerate the training procedure, we also add a new feature
    of ``sync_std`` to support multi-node/multi-machine training. This
    feature is still in beta and has only been tested at the 256 scale.
Args:
group_size (int, optional): The size of groups in batch dimension.
Defaults to 4.
channel_groups (int, optional): The size of groups in channel
dimension. Defaults to 1.
sync_std (bool, optional): Whether to use synchronized std feature.
Defaults to False.
sync_groups (int | None, optional): The size of groups in node
dimension. Defaults to None.
eps (float, optional): Epsilon value to avoid computation error.
Defaults to 1e-8.
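
    Example:
        A minimal usage sketch; a single stddev feature map is concatenated
        along the channel dimension, and the shapes below are chosen for
        illustration only:

        >>> layer = ModMBStddevLayer(group_size=4)
        >>> x = torch.randn(4, 8, 8, 8)
        >>> layer(x).shape
        torch.Size([4, 9, 8, 8])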
"""
def __init__(self,
group_size=4,
channel_groups=1,
sync_std=False,
sync_groups=None,
eps=1e-8):
super().__init__()
self.group_size = group_size
self.eps = eps
self.channel_groups = channel_groups
self.sync_std = sync_std
self.sync_groups = group_size if sync_groups is None else sync_groups
if self.sync_std:
assert torch.distributed.is_initialized(
), 'Only in distributed training can the sync_std be activated.'
mmcv.print_log('Adopt synced minibatch stddev layer', 'mmgen')
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input feature map with shape of (N, C, H, W).
Returns:
Tensor: Output feature map with shape of (N, C+1, H, W).
"""
if self.sync_std:
# concatenate all features
all_features = torch.cat(AllGatherLayer.apply(x), dim=0)
# get the exact features we need in calculating std-dev
rank, ws = get_dist_info()
local_bs = all_features.shape[0] // ws
start_idx = local_bs * rank
# avoid the case where start idx near the tail of features
if start_idx + self.sync_groups > all_features.shape[0]:
start_idx = all_features.shape[0] - self.sync_groups
end_idx = min(local_bs * rank + self.sync_groups,
all_features.shape[0])
x = all_features[start_idx:end_idx]
# batch size should be smaller than or equal to group size. Otherwise,
# batch size should be divisible by the group size.
        assert x.shape[
            0] <= self.group_size or x.shape[0] % self.group_size == 0, (
                'Batch size should be smaller than or equal '
                'to group size. Otherwise,'
                ' batch size should be divisible by the group size.'
                f' But got batch size {x.shape[0]},'
                f' group size {self.group_size}')
        assert x.shape[1] % self.channel_groups == 0, (
            'The number of feature channels must be divisible by '
            f'"channel_groups". channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')
n, c, h, w = x.shape
group_size = min(n, self.group_size)
# [G, M, Gc, C', H, W]
y = torch.reshape(x, (group_size, -1, self.channel_groups,
c // self.channel_groups, h, w))
y = torch.var(y, dim=0, unbiased=False)
y = torch.sqrt(y + self.eps)
# [M, 1, 1, 1]
y = y.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)
y = y.repeat(group_size, 1, h, w)
return torch.cat([x, y], dim=1)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import scipy
import torch
import torch.nn as nn
from mmgen.models.builder import MODULES
from mmgen.ops import bias_act, conv2d_gradfix, filtered_lrelu
def modulated_conv2d(
x,
w,
s,
demodulate=True,
padding=0,
input_gain=None,
):
"""Modulated Conv2d in StyleGANv3.
Args:
x (torch.Tensor): Input tensor with shape (batch_size, in_channels,
height, width).
w (torch.Tensor): Weight of modulated convolution with shape
(out_channels, in_channels, kernel_height, kernel_width).
s (torch.Tensor): Style tensor with shape (batch_size, in_channels).
demodulate (bool): Whether apply weight demodulation. Defaults to True.
padding (int or list[int]): Convolution padding. Defaults to 0.
input_gain (list[int]): Scaling factors for input. Defaults to None.
Returns:
torch.Tensor: Convolution Output.
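
    Example:
        A minimal usage sketch; the shapes below are chosen for illustration
        only:

        >>> x = torch.randn(2, 8, 16, 16)
        >>> w = torch.randn(16, 8, 3, 3)
        >>> s = torch.randn(2, 8)
        >>> modulated_conv2d(x, w, s, padding=1).shape
        torch.Size([2, 16, 16, 16])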
"""
batch_size = int(x.shape[0])
_, in_channels, kh, kw = w.shape
# Pre-normalize inputs.
if demodulate:
w = w * w.square().mean([1, 2, 3], keepdim=True).rsqrt()
s = s * s.square().mean().rsqrt()
# Modulate weights.
w = w.unsqueeze(0) # [NOIkk]
w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Demodulate weights.
if demodulate:
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Apply input scaling.
if input_gain is not None:
input_gain = input_gain.expand(batch_size, in_channels) # [NI]
w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Execute as one fused op using grouped convolution.
x = x.reshape(1, -1, *x.shape[2:])
w = w.reshape(-1, in_channels, kh, kw)
x = conv2d_gradfix.conv2d(
input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
x = x.reshape(batch_size, -1, *x.shape[2:])
return x
class FullyConnectedLayer(nn.Module):
"""Fully connected layer used in StyleGANv3.
Args:
in_features (int): Number of channels in the input feature.
out_features (int): Number of channels in the out feature.
activation (str, optional): Activation function with choices 'relu',
'lrelu', 'linear'. 'linear' means no extra activation.
Defaults to 'linear'.
bias (bool, optional): Whether to use additive bias. Defaults to True.
lr_multiplier (float, optional): Equalized learning rate multiplier.
Defaults to 1..
weight_init (float, optional): Weight multiplier for initialization.
Defaults to 1..
bias_init (float, optional): Initial bias. Defaults to 0..
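
    Example:
        A minimal usage sketch; the sizes below are chosen for illustration
        only:

        >>> fc = FullyConnectedLayer(8, 16, activation='lrelu')
        >>> fc(torch.randn(2, 8)).shape
        torch.Size([2, 16])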
"""
def __init__(self,
in_features,
out_features,
activation='linear',
bias=True,
lr_multiplier=1.,
weight_init=1.,
bias_init=0.):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.activation = activation
self.weight = torch.nn.Parameter(
torch.randn([out_features, in_features]) *
(weight_init / lr_multiplier))
bias_init = np.broadcast_to(
np.asarray(bias_init, dtype=np.float32), [out_features])
self.bias = torch.nn.Parameter(
torch.from_numpy(bias_init / lr_multiplier)) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
def forward(self, x):
"""Forward function."""
w = self.weight.to(x.dtype) * self.weight_gain
b = self.bias
if b is not None:
b = b.to(x.dtype)
if self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == 'linear' and b is not None:
x = torch.addmm(b.unsqueeze(0), x, w.t())
else:
x = x.matmul(w.t())
x = bias_act.bias_act(x, b, act=self.activation)
return x
@MODULES.register_module()
class MappingNetwork(nn.Module):
"""Style mapping network used in StyleGAN3. The main difference between it
and styleganv1,v2 is that mean latent is registered as a buffer and dynamic
updated during training.
Args:
noise_size (int, optional): Size of the input noise vector.
        c_dim (int, optional): Size of the conditional input. Defaults to 0.
style_channels (int): The number of channels for style code.
        num_ws (int): The number of times to repeat the w latent.
num_layers (int, optional): The number of layers of mapping network.
Defaults to 2.
lr_multiplier (float, optional): Equalized learning rate multiplier.
Defaults to 0.01.
w_avg_beta (float, optional): The value used for update `w_avg`.
Defaults to 0.998.
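
    Example:
        A minimal usage sketch; the sizes below are chosen for illustration
        only:

        >>> mapping = MappingNetwork(
        ...     noise_size=64, style_channels=64, num_ws=10)
        >>> z = torch.randn(2, 64)
        >>> mapping(z).shape
        torch.Size([2, 10, 64])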
"""
def __init__(self,
noise_size,
style_channels,
num_ws,
c_dim=0,
num_layers=2,
lr_multiplier=0.01,
w_avg_beta=0.998):
super().__init__()
self.noise_size = noise_size
self.c_dim = c_dim
self.style_channels = style_channels
self.num_ws = num_ws
self.num_layers = num_layers
self.w_avg_beta = w_avg_beta
# Construct layers.
self.embed = FullyConnectedLayer(
self.c_dim, self.style_channels) if self.c_dim > 0 else None
features = [
self.noise_size + (self.style_channels if self.c_dim > 0 else 0)
] + [self.style_channels] * self.num_layers
for idx, in_features, out_features in zip(
range(num_layers), features[:-1], features[1:]):
layer = FullyConnectedLayer(
in_features,
out_features,
activation='lrelu',
lr_multiplier=lr_multiplier)
setattr(self, f'fc{idx}', layer)
self.register_buffer('w_avg', torch.zeros([style_channels]))
def forward(self,
z,
c=None,
truncation=1,
num_truncation_layer=None,
update_emas=False):
"""Style mapping function.
Args:
z (torch.Tensor): Input noise tensor.
c (torch.Tensor, optional): Input label tensor. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            num_truncation_layer (int, optional): Number of layers that use
                the truncated latent. Defaults to None.
            update_emas (bool, optional): Whether to update the moving
                average of the mean latent. Defaults to False.
Returns:
torch.Tensor: W-plus latent.
"""
if num_truncation_layer is None:
num_truncation_layer = self.num_ws
# Embed, normalize, and concatenate inputs.
x = z.to(torch.float32)
x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
if self.c_dim > 0:
y = self.embed(c.to(torch.float32))
y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
x = torch.cat([x, y], dim=1) if x is not None else y
# Execute layers.
for idx in range(self.num_layers):
x = getattr(self, f'fc{idx}')(x)
# Update moving average of W.
if update_emas:
self.w_avg.copy_(x.detach().mean(dim=0).lerp(
self.w_avg, self.w_avg_beta))
# Broadcast and apply truncation.
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
if truncation != 1:
x[:, :num_truncation_layer] = self.w_avg.lerp(
x[:, :num_truncation_layer], truncation)
return x
class SynthesisInput(nn.Module):
"""Module which generate input for synthesis layer.
Args:
style_channels (int): The number of channels for style code.
channels (int): The number of output channel.
        size (int): The size of the sampling grid.
        sampling_rate (int): Sampling rate used to construct the sampling
            grid.
bandwidth (float): Bandwidth of random frequencies.
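
    Example:
        A minimal usage sketch; the values below are chosen for illustration
        only:

        >>> synthesis_input = SynthesisInput(
        ...     style_channels=64, channels=8, size=16, sampling_rate=16,
        ...     bandwidth=2)
        >>> w = torch.randn(2, 64)
        >>> synthesis_input(w).shape
        torch.Size([2, 8, 16, 16])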
"""
def __init__(self, style_channels, channels, size, sampling_rate,
bandwidth):
super().__init__()
self.style_channels = style_channels
self.channels = channels
self.size = np.broadcast_to(np.asarray(size), [2])
self.sampling_rate = sampling_rate
self.bandwidth = bandwidth
# Draw random frequencies from uniform 2D disc.
freqs = torch.randn([self.channels, 2])
radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
freqs /= radii * radii.square().exp().pow(0.25)
freqs *= bandwidth
phases = torch.rand([self.channels]) - 0.5
# Setup parameters and buffers.
self.weight = torch.nn.Parameter(
torch.randn([self.channels, self.channels]))
self.affine = FullyConnectedLayer(
style_channels, 4, weight_init=0, bias_init=[1, 0, 0, 0])
self.register_buffer('transform', torch.eye(
3, 3)) # User-specified inverse transform wrt. resulting image.
self.register_buffer('freqs', freqs)
self.register_buffer('phases', phases)
def forward(self, w):
"""Forward function."""
# Introduce batch dimension.
transforms = self.transform.unsqueeze(0) # [batch, row, col]
freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
phases = self.phases.unsqueeze(0) # [batch, channel]
# Apply learned transformation.
t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
t = t / t[:, :2].norm(
dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
m_r = torch.eye(
3, device=w.device).unsqueeze(0).repeat(
[w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
m_r[:, 0, 0] = t[:, 0] # r'_c
m_r[:, 0, 1] = -t[:, 1] # r'_s
m_r[:, 1, 0] = t[:, 1] # r'_s
m_r[:, 1, 1] = t[:, 0] # r'_c
m_t = torch.eye(
3, device=w.device).unsqueeze(0).repeat(
[w.shape[0], 1,
1]) # Inverse translation wrt. resulting image.
m_t[:, 0, 2] = -t[:, 2] # t'_x
m_t[:, 1, 2] = -t[:, 3] # t'_y
# First rotate resulting image, then translate
# and finally apply user-specified transform.
transforms = m_r @ m_t @ transforms
# Transform frequencies.
phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
freqs = freqs @ transforms[:, :2, :2]
# Dampen out-of-band frequencies
# that may occur due to the user-specified transform.
amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) /
(self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
# Construct sampling grid.
theta = torch.eye(2, 3, device=w.device)
theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
grids = torch.nn.functional.affine_grid(
theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]],
align_corners=False)
# Compute Fourier features.
x = (grids.unsqueeze(3) @ freqs.permute(
0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(
3) # [batch, height, width, channel]
x = x + phases.unsqueeze(1).unsqueeze(2)
x = torch.sin(x * (np.pi * 2))
x = x * amplitudes.unsqueeze(1).unsqueeze(2)
# Apply trainable mapping.
weight = self.weight / np.sqrt(self.channels)
x = x @ weight.t()
# Ensure correct shape.
x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
return x
class SynthesisLayer(nn.Module):
"""Layer of Synthesis network for stylegan3.
Args:
style_channels (int): The number of channels for style code.
        is_torgb (bool): Whether the output of this layer is transformed to
            an rgb image.
        is_critically_sampled (bool): Whether the filter cutoff is set
            exactly at the bandlimit.
use_fp16 (bool, optional): Whether to use fp16 training in this
module. If this flag is `True`, the whole module will be wrapped
with ``auto_fp16``.
in_channels (int): The channel number of the input feature map.
out_channels (int): The channel number of the output feature map.
in_size (int): The input size of feature map.
out_size (int): The output size of feature map.
in_sampling_rate (int): Sampling rate for upsampling filter.
out_sampling_rate (int): Sampling rate for downsampling filter.
in_cutoff (float): Cutoff frequency for upsampling filter.
out_cutoff (float): Cutoff frequency for downsampling filter.
in_half_width (float): The approximate width of the transition region
for upsampling filter.
out_half_width (float): The approximate width of the transition region
for downsampling filter.
conv_kernel (int, optional): The kernel of modulated convolution.
Defaults to 3.
filter_size (int, optional): Base filter size. Defaults to 6.
        lrelu_upsampling (int, optional): Upsampling rate for
            `filtered_lrelu`. Defaults to 2.
        use_radial_filters (bool, optional): Whether to use a radially
            symmetric jinc-based filter in the downsampling filter.
            Defaults to False.
conv_clamp (int, optional): Clamp bound for convolution.
Defaults to 256.
magnitude_ema_beta (float, optional): Beta coefficient for calculating
input magnitude ema. Defaults to 0.999.
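
    Example:
        A minimal construction sketch; all values below are chosen for
        illustration only and follow the relation
        ``half_width = max(stopband, sampling_rate / 2) - cutoff``:

        >>> layer = SynthesisLayer(
        ...     style_channels=64, is_torgb=False,
        ...     is_critically_sampled=False, use_fp16=False,
        ...     in_channels=8, out_channels=8, in_size=16, out_size=16,
        ...     in_sampling_rate=16, out_sampling_rate=16,
        ...     in_cutoff=2, out_cutoff=2,
        ...     in_half_width=6, out_half_width=6)
        >>> x = torch.randn(2, 8, 16, 16)
        >>> w = torch.randn(2, 64)
        >>> out = layer(x, w)  # output spatial size follows `out_size`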
"""
def __init__(
self,
style_channels,
is_torgb,
is_critically_sampled,
use_fp16,
in_channels,
out_channels,
in_size,
out_size,
in_sampling_rate,
out_sampling_rate,
in_cutoff,
out_cutoff,
in_half_width,
out_half_width,
conv_kernel=3,
filter_size=6,
lrelu_upsampling=2,
use_radial_filters=False,
conv_clamp=256,
magnitude_ema_beta=0.999,
):
super().__init__()
self.style_channels = style_channels
self.is_torgb = is_torgb
self.is_critically_sampled = is_critically_sampled
self.use_fp16 = use_fp16
self.in_channels = in_channels
self.out_channels = out_channels
self.in_size = np.broadcast_to(np.asarray(in_size), [2])
self.out_size = np.broadcast_to(np.asarray(out_size), [2])
self.in_sampling_rate = in_sampling_rate
self.out_sampling_rate = out_sampling_rate
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (
1 if is_torgb else lrelu_upsampling)
self.in_cutoff = in_cutoff
self.out_cutoff = out_cutoff
self.in_half_width = in_half_width
self.out_half_width = out_half_width
self.conv_kernel = 1 if is_torgb else conv_kernel
self.conv_clamp = conv_clamp
self.magnitude_ema_beta = magnitude_ema_beta
# Setup parameters and buffers.
self.affine = FullyConnectedLayer(
self.style_channels, self.in_channels, bias_init=1)
self.weight = torch.nn.Parameter(
torch.randn([
self.out_channels, self.in_channels, self.conv_kernel,
self.conv_kernel
]))
self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
self.register_buffer('magnitude_ema', torch.ones([]))
# Design upsampling filter.
self.up_factor = int(
np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
self.up_taps = (
filter_size *
self.up_factor if self.up_factor > 1 and not self.is_torgb else 1)
self.register_buffer(
'up_filter',
self.design_lowpass_filter(
numtaps=self.up_taps,
cutoff=self.in_cutoff,
width=self.in_half_width * 2,
fs=self.tmp_sampling_rate))
# Design downsampling filter.
self.down_factor = int(
np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
assert (self.out_sampling_rate *
self.down_factor == self.tmp_sampling_rate)
self.down_taps = (
filter_size * self.down_factor
if self.down_factor > 1 and not self.is_torgb else 1)
self.down_radial = (
use_radial_filters and not self.is_critically_sampled)
self.register_buffer(
'down_filter',
self.design_lowpass_filter(
numtaps=self.down_taps,
cutoff=self.out_cutoff,
width=self.out_half_width * 2,
fs=self.tmp_sampling_rate,
radial=self.down_radial))
# Compute padding.
pad_total = (
self.out_size - 1
) * self.down_factor + 1 # Desired output size before downsampling.
pad_total -= (self.in_size + self.conv_kernel -
1) * self.up_factor # Input size after upsampling.
pad_total += self.up_taps + self.down_taps - 2
pad_lo = (pad_total + self.up_factor) // 2
pad_hi = pad_total - pad_lo
self.padding = [
int(pad_lo[0]),
int(pad_hi[0]),
int(pad_lo[1]),
int(pad_hi[1])
]
def forward(self, x, w, force_fp32=False, update_emas=False):
"""Forward function for synthesis layer.
Args:
x (torch.Tensor): Input feature map tensor.
w (torch.Tensor): Input style tensor.
            force_fp32 (bool, optional): Whether to force fp32 computation
                regardless of `use_fp16`. Defaults to False.
            update_emas (bool, optional): Whether to update the moving
                average of the input magnitude. Defaults to False.
Returns:
torch.Tensor: Output feature map tensor.
"""
# Track input magnitude.
if update_emas:
with torch.autograd.profiler.record_function(
'update_magnitude_ema'):
magnitude_cur = x.detach().to(torch.float32).square().mean()
self.magnitude_ema.copy_(
magnitude_cur.lerp(self.magnitude_ema,
self.magnitude_ema_beta))
input_gain = self.magnitude_ema.rsqrt()
# Execute affine layer.
styles = self.affine(w)
if self.is_torgb:
weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel**2))
styles = styles * weight_gain
# Execute modulated conv2d.
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and
x.device.type == 'cuda') else torch.float32
x = modulated_conv2d(
x=x.to(dtype),
w=self.weight,
s=styles,
padding=self.conv_kernel - 1,
demodulate=(not self.is_torgb),
input_gain=input_gain)
# Execute bias, filtered leaky ReLU, and clamping.
gain = 1 if self.is_torgb else np.sqrt(2)
slope = 1 if self.is_torgb else 0.2
x = filtered_lrelu.filtered_lrelu(
x=x,
fu=self.up_filter,
fd=self.down_filter,
b=self.bias.to(x.dtype),
up=self.up_factor,
down=self.down_factor,
padding=self.padding,
gain=gain,
slope=slope,
clamp=self.conv_clamp)
# Ensure correct shape and dtype.
assert x.dtype == dtype
return x
@staticmethod
def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
"""Design lowpass filter giving related arguments.
Args:
numtaps (int): Length of the filter. `numtaps` must be odd if a
passband includes the Nyquist frequency.
            cutoff (float): Cutoff frequency of the filter.
width (float): The approximate width of the transition region.
fs (float): The sampling frequency of the signal.
            radial (bool, optional): Whether to use a radially symmetric
                jinc-based filter. Defaults to False.
Returns:
torch.Tensor: Kernel of lowpass filter.
"""
assert numtaps >= 1
# Identity filter.
if numtaps == 1:
return None
# Separable Kaiser low-pass filter.
if not radial:
f = scipy.signal.firwin(
numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
return torch.as_tensor(f, dtype=torch.float32)
# Radially symmetric jinc-based filter.
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
r = np.hypot(*np.meshgrid(x, x))
f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
beta = scipy.signal.kaiser_beta(
scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
w = np.kaiser(numtaps, beta)
f *= np.outer(w, w)
f /= np.sum(f)
return torch.as_tensor(f, dtype=torch.float32)
@MODULES.register_module()
class SynthesisNetwork(nn.Module):
"""Synthesis network for stylegan3.
Args:
style_channels (int): The number of channels for style code.
out_size (int): The resolution of output image.
img_channels (int): The number of channels for output image.
channel_base (int, optional): Overall multiplier for the number of
channels. Defaults to 32768.
channel_max (int, optional): Maximum number of channels in any layer.
Defaults to 512.
num_layers (int, optional): Total number of layers, excluding Fourier
features and ToRGB. Defaults to 14.
num_critical (int, optional): Number of critically sampled layers at
the end. Defaults to 2.
first_cutoff (int, optional): Cutoff frequency of the first layer.
Defaults to 2.
first_stopband (int, optional): Minimum stopband of the first layer.
Defaults to 2**2.1.
last_stopband_rel (float, optional): Minimum stopband of the last
layer, expressed relative to the cutoff. Defaults to 2**0.3.
margin_size (int, optional): Number of additional pixels outside the
image. Defaults to 10.
output_scale (float, optional): Scale factor for output value.
Defaults to 0.25.
        num_fp16_res (int, optional): Number of the highest resolutions that
            use fp16. Defaults to 4.
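
    Example:
        A minimal usage sketch; the output size below is chosen for
        illustration only:

        >>> synthesis = SynthesisNetwork(
        ...     style_channels=512, out_size=64, img_channels=3)
        >>> ws = torch.randn(2, synthesis.num_ws, 512)
        >>> synthesis(ws).shape
        torch.Size([2, 3, 64, 64])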
"""
def __init__(
self,
style_channels,
out_size,
img_channels,
channel_base=32768,
channel_max=512,
num_layers=14,
num_critical=2,
first_cutoff=2,
first_stopband=2**2.1,
last_stopband_rel=2**0.3,
margin_size=10,
output_scale=0.25,
num_fp16_res=4,
**layer_kwargs,
):
super().__init__()
self.style_channels = style_channels
self.num_ws = num_layers + 2
self.out_size = out_size
self.img_channels = img_channels
self.num_layers = num_layers
self.num_critical = num_critical
self.margin_size = margin_size
self.output_scale = output_scale
self.num_fp16_res = num_fp16_res
# Geometric progression of layer cutoffs and min. stopbands.
last_cutoff = self.out_size / 2 # f_{c,N}
last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
exponents = np.minimum(
np.arange(self.num_layers + 1) /
(self.num_layers - self.num_critical), 1)
cutoffs = first_cutoff * (last_cutoff /
first_cutoff)**exponents # f_c[i]
stopbands = first_stopband * (last_stopband /
first_stopband)**exponents # f_t[i]
# Compute remaining layer parameters.
sampling_rates = np.exp2(
np.ceil(np.log2(np.minimum(stopbands * 2, self.out_size)))) # s[i]
half_widths = np.maximum(stopbands,
sampling_rates / 2) - cutoffs # f_h[i]
sizes = sampling_rates + self.margin_size * 2
sizes[-2:] = self.out_size
channels = np.rint(
np.minimum((channel_base / 2) / cutoffs, channel_max))
channels[-1] = self.img_channels
# Construct layers.
self.input = SynthesisInput(
style_channels=self.style_channels,
channels=int(channels[0]),
size=int(sizes[0]),
sampling_rate=sampling_rates[0],
bandwidth=cutoffs[0])
self.layer_names = []
for idx in range(self.num_layers + 1):
prev = max(idx - 1, 0)
is_torgb = (idx == self.num_layers)
is_critically_sampled = (
idx >= self.num_layers - self.num_critical)
use_fp16 = (
sampling_rates[idx] * (2**self.num_fp16_res) > self.out_size)
layer = SynthesisLayer(
style_channels=self.style_channels,
is_torgb=is_torgb,
is_critically_sampled=is_critically_sampled,
use_fp16=use_fp16,
in_channels=int(channels[prev]),
out_channels=int(channels[idx]),
in_size=int(sizes[prev]),
out_size=int(sizes[idx]),
in_sampling_rate=int(sampling_rates[prev]),
out_sampling_rate=int(sampling_rates[idx]),
in_cutoff=cutoffs[prev],
out_cutoff=cutoffs[idx],
in_half_width=half_widths[prev],
out_half_width=half_widths[idx],
**layer_kwargs)
name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
setattr(self, name, layer)
self.layer_names.append(name)
def forward(self, ws, **layer_kwargs):
"""Forward function."""
ws = ws.to(torch.float32).unbind(dim=1)
# Execute layers.
x = self.input(ws[0])
for name, w in zip(self.layer_names, ws[1:]):
x = getattr(self, name)(x, w, **layer_kwargs)
if self.output_scale != 1:
x = x * self.output_scale
# Ensure correct shape and dtype.
x = x.to(torch.float32)
return x
# Copyright (c) OpenMMLab. All rights reserved.
import random
from copy import deepcopy
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .modules.styleganv2_modules import (ConstantInput, ConvDownLayer,
EqualLinearActModule,
ModMBStddevLayer,
ModulatedPEStyleConv, ModulatedToRGB,
ResBlock)
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class MSStyleGANv2Generator(nn.Module):
"""StyleGAN2 Generator.
In StyleGAN2, we use a static architecture composing of a style mapping
module and number of convolutional style blocks. More details can be found
in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.
Args:
out_size (int): The output size of the StyleGAN2 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
blur_kernel (list, optional): The blurry kernel. Defaults
to [1, 3, 3, 1].
lr_mlp (float, optional): The learning rate for the style mapping
layer. Defaults to 0.01.
default_style_mode (str, optional): The default mode of style mixing.
In training, we defaultly adopt mixing style mode. However, in the
evaluation, we use 'single' style mode. `['mix', 'single']` are
currently supported. Defaults to 'mix'.
eval_style_mode (str, optional): The evaluation mode of style mixing.
Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
        no_pad (bool, optional): Whether to remove the padding in
            convolution. Defaults to False.
        deconv2conv (bool, optional): Whether to substitute the transposed
            conv with (conv2d, upsampling). Defaults to False.
        interp_pad (int | None, optional): The padding number of
            interpolation pad. Defaults to None.
        up_config (dict, optional): Upsampling config.
            Defaults to dict(scale_factor=2, mode='nearest').
        up_after_conv (bool, optional): Whether to adopt upsampling after
            convolution. Defaults to False.
        head_pos_encoding (dict | None, optional): Config for the positional
            encoding at the head of the generator. If None, a constant input
            is used. Defaults to None.
        head_pos_size (tuple[int], optional): Spatial size of the head
            positional encoding. Defaults to (4, 4).
        interp_head (bool, optional): Whether to interpolate the head
            positional encoding to the target size. Defaults to False.
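
    Example:
        A minimal sampling sketch with the default noise sampler; the sizes
        below are chosen for illustration only:

        >>> generator = MSStyleGANv2Generator(out_size=64, style_channels=512)
        >>> fake_img = generator(None, num_batches=2)
        >>> fake_img.shape
        torch.Size([2, 3, 64, 64])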
"""
def __init__(self,
out_size,
style_channels,
num_mlps=8,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9,
no_pad=False,
deconv2conv=False,
interp_pad=None,
up_config=dict(scale_factor=2, mode='nearest'),
up_after_conv=False,
head_pos_encoding=None,
head_pos_size=(4, 4),
interp_head=False):
super().__init__()
self.out_size = out_size
self.style_channels = style_channels
self.num_mlps = num_mlps
self.channel_multiplier = channel_multiplier
self.lr_mlp = lr_mlp
self._default_style_mode = default_style_mode
self.default_style_mode = default_style_mode
self.eval_style_mode = eval_style_mode
self.mix_prob = mix_prob
self.no_pad = no_pad
self.deconv2conv = deconv2conv
self.interp_pad = interp_pad
self.with_interp_pad = interp_pad is not None
self.up_config = deepcopy(up_config)
self.up_after_conv = up_after_conv
self.head_pos_encoding = head_pos_encoding
self.head_pos_size = head_pos_size
self.interp_head = interp_head
# define style mapping layers
mapping_layers = [PixelNorm()]
for _ in range(num_mlps):
mapping_layers.append(
EqualLinearActModule(
style_channels,
style_channels,
equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
act_cfg=dict(type='fused_bias')))
self.style_mapping = nn.Sequential(*mapping_layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
in_ch = self.channels[4]
# constant input layer
if self.head_pos_encoding:
if self.head_pos_encoding['type'] in [
'CatersianGrid', 'CSG', 'CSG2d'
]:
in_ch = 2
self.head_pos_enc = build_module(self.head_pos_encoding)
else:
size_ = 4
if self.no_pad:
size_ += 2
self.constant_input = ConstantInput(self.channels[4], size=size_)
# 4x4 stage
self.conv1 = ModulatedPEStyleConv(
in_ch,
self.channels[4],
kernel_size=3,
style_channels=style_channels,
blur_kernel=blur_kernel,
deconv2conv=self.deconv2conv,
no_pad=self.no_pad,
up_config=self.up_config,
interp_pad=self.interp_pad)
self.to_rgb1 = ModulatedToRGB(
self.channels[4], style_channels, upsample=False)
# generator backbone (8x8 --> higher resolutions)
self.log_size = int(np.log2(self.out_size))
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
in_channels_ = self.channels[4]
for i in range(3, self.log_size + 1):
out_channels_ = self.channels[2**i]
self.convs.append(
ModulatedPEStyleConv(
in_channels_,
out_channels_,
3,
style_channels,
upsample=True,
blur_kernel=blur_kernel,
deconv2conv=self.deconv2conv,
no_pad=self.no_pad,
up_config=self.up_config,
interp_pad=self.interp_pad,
up_after_conv=self.up_after_conv))
self.convs.append(
ModulatedPEStyleConv(
out_channels_,
out_channels_,
3,
style_channels,
upsample=False,
blur_kernel=blur_kernel,
deconv2conv=self.deconv2conv,
no_pad=self.no_pad,
up_config=self.up_config,
interp_pad=self.interp_pad,
up_after_conv=self.up_after_conv))
self.to_rgbs.append(
ModulatedToRGB(out_channels_, style_channels, upsample=True))
in_channels_ = out_channels_
self.num_latents = self.log_size * 2 - 2
self.num_injected_noises = self.num_latents - 1
# register buffer for injected noises
noises = self.make_injected_noise()
for layer_idx in range(self.num_injected_noises):
self.register_buffer(f'injected_noise_{layer_idx}',
noises[layer_idx])
def train(self, mode=True):
if mode:
if self.default_style_mode != self._default_style_mode:
mmcv.print_log(
f'Switch to train style mode: {self._default_style_mode}',
'mmgen')
self.default_style_mode = self._default_style_mode
else:
if self.default_style_mode != self.eval_style_mode:
mmcv.print_log(
f'Switch to evaluation style mode: {self.eval_style_mode}',
'mmgen')
self.default_style_mode = self.eval_style_mode
return super(MSStyleGANv2Generator, self).train(mode)
def make_injected_noise(self, chosen_scale=0):
device = get_module_device(self)
base_scale = 2**2 + chosen_scale
noises = [torch.randn(1, 1, base_scale, base_scale, device=device)]
for i in range(3, self.log_size + 1):
for n in range(2):
_pad = 0
if self.no_pad and not self.up_after_conv and n == 0:
_pad = 2
noises.append(
torch.randn(
1,
1,
base_scale * 2**(i - 2) + _pad,
base_scale * 2**(i - 2) + _pad,
device=device))
return noises
def get_mean_latent(self, num_samples=4096, **kwargs):
"""Get mean latent of W space in this generator.
Args:
num_samples (int, optional): Number of sample times. Defaults
to 4096.
Returns:
Tensor: Mean latent of this generator.
"""
return get_mean_latent(self, num_samples, **kwargs)
def style_mixing(self,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7,
chosen_scale=0):
return style_mixing(
self,
n_source=n_source,
n_target=n_target,
inject_index=inject_index,
truncation_latent=truncation_latent,
truncation=truncation,
style_channels=self.style_channels,
chosen_scale=chosen_scale)
def forward(self,
styles,
num_batches=-1,
return_noise=False,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
injected_noise=None,
randomize_noise=True,
chosen_scale=0):
"""Forward function.
This function has been integrated with the truncation trick. Please
refer to the usage of `truncation` and `truncation_latent`.
Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates using the default noise
                sampler.
            num_batches (int, optional): The number of batches (batch size of
                the sampled noise). Defaults to -1.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
return_latents (bool, optional): If True, ``latent`` will be
returned in a dict with ``fake_img``. Defaults to False.
inject_index (int | None, optional): The index number for mixing
style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
truncation_latent (torch.Tensor, optional): Mean truncation latent.
Defaults to None.
input_is_latent (bool, optional): If `True`, the input tensor is
the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): If a tensor is
                given, the random noise will be fixed to this injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensors injected into the style conv
                blocks. Defaults to True.
Returns:
torch.Tensor | dict: Generated image tensor or dictionary \
containing more data.
"""
# receive noise and conduct sanity check.
if isinstance(styles, torch.Tensor):
assert styles.shape[1] == self.style_channels
styles = [styles]
elif mmcv.is_seq_of(styles, torch.Tensor):
for t in styles:
assert t.shape[-1] == self.style_channels
# receive a noise generator and sample noise.
elif callable(styles):
device = get_module_device(self)
noise_generator = styles
assert num_batches > 0
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
noise_generator((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [noise_generator((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
# otherwise, we will adopt default noise sampler.
else:
device = get_module_device(self)
assert num_batches > 0 and not input_is_latent
if self.default_style_mode == 'mix' and random.random(
) < self.mix_prob:
styles = [
torch.randn((num_batches, self.style_channels))
for _ in range(2)
]
else:
styles = [torch.randn((num_batches, self.style_channels))]
styles = [s.to(device) for s in styles]
if not input_is_latent:
noise_batch = styles
styles = [self.style_mapping(s) for s in styles]
else:
noise_batch = None
if injected_noise is None:
if randomize_noise:
injected_noise = [None] * self.num_injected_noises
elif chosen_scale > 0:
if not hasattr(self, f'injected_noise_{chosen_scale}_0'):
noises_ = self.make_injected_noise(chosen_scale)
for i in range(self.num_injected_noises):
setattr(self, f'injected_noise_{chosen_scale}_{i}',
noises_[i])
injected_noise = [
getattr(self, f'injected_noise_{chosen_scale}_{i}')
for i in range(self.num_injected_noises)
]
else:
injected_noise = [
getattr(self, f'injected_noise_{i}')
for i in range(self.num_injected_noises)
]
# use truncation trick
if truncation < 1:
style_t = []
# calculate truncation latent on the fly
if truncation_latent is None and not hasattr(
self, 'truncation_latent'):
self.truncation_latent = self.get_mean_latent()
truncation_latent = self.truncation_latent
elif truncation_latent is None and hasattr(self,
'truncation_latent'):
truncation_latent = self.truncation_latent
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
styles = style_t
# no style mixing
if len(styles) < 2:
inject_index = self.num_latents
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
# style mixing
else:
if inject_index is None:
inject_index = random.randint(1, self.num_latents - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.num_latents - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
if isinstance(chosen_scale, int):
chosen_scale = (chosen_scale, chosen_scale)
# 4x4 stage
if self.head_pos_encoding:
if self.interp_head:
out = self.head_pos_enc.make_grid2d(self.head_pos_size[0],
self.head_pos_size[1],
latent.size(0))
h_in = self.head_pos_size[0] + chosen_scale[0]
w_in = self.head_pos_size[1] + chosen_scale[1]
out = F.interpolate(
out,
size=(h_in, w_in),
mode='bilinear',
align_corners=True)
else:
out = self.head_pos_enc.make_grid2d(
self.head_pos_size[0] + chosen_scale[0],
self.head_pos_size[1] + chosen_scale[1], latent.size(0))
out = out.to(latent)
else:
out = self.constant_input(latent)
if chosen_scale[0] != 0 or chosen_scale[1] != 0:
out = F.interpolate(
out,
size=(out.shape[2] + chosen_scale[0],
out.shape[3] + chosen_scale[1]),
mode='bilinear',
align_corners=True)
out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
skip = self.to_rgb1(out, latent[:, 1])
_index = 1
# 8x8 ---> higher resolutions
for up_conv, conv, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], injected_noise[1::2],
injected_noise[2::2], self.to_rgbs):
out = up_conv(out, latent[:, _index], noise=noise1)
out = conv(out, latent[:, _index + 1], noise=noise2)
skip = to_rgb(out, latent[:, _index + 2], skip)
_index += 2
img = skip
if return_latents or return_noise:
output_dict = dict(
fake_img=img,
latent=latent,
inject_index=inject_index,
noise_batch=noise_batch,
injected_noise=injected_noise)
return output_dict
return img
@MODULES.register_module()
class MSStyleGAN2Discriminator(nn.Module):
"""StyleGAN2 Discriminator.
The architecture of this discriminator is proposed in StyleGAN2. More
details can be found in: Analyzing and Improving the Image Quality of
StyleGAN CVPR2020.
Args:
in_size (int): The input size of images.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
blur_kernel (list, optional): The blurry kernel. Defaults
to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        with_adaptive_pool (bool, optional): Whether to use adaptive average
            pooling before the final linear layers. Defaults to False.
        pool_size (tuple[int], optional): Output size of the adaptive
            average pooling. Defaults to (2, 2).
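
    Example:
        A minimal usage sketch; the input size below is chosen for
        illustration only:

        >>> disc = MSStyleGAN2Discriminator(in_size=64)
        >>> img = torch.randn(2, 3, 64, 64)
        >>> disc(img).shape
        torch.Size([2, 1])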
"""
def __init__(self,
in_size,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
mbstd_cfg=dict(group_size=4, channel_groups=1),
with_adaptive_pool=False,
pool_size=(2, 2)):
super().__init__()
self.with_adaptive_pool = with_adaptive_pool
self.pool_size = pool_size
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
log_size = int(np.log2(in_size))
in_channels = channels[in_size]
convs = [ConvDownLayer(3, channels[in_size], 1)]
for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]
convs.append(ResBlock(in_channels, out_channel, blur_kernel))
in_channels = out_channel
self.convs = nn.Sequential(*convs)
self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
if self.with_adaptive_pool:
self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_size)
linear_in_channels = channels[4] * pool_size[0] * pool_size[1]
else:
linear_in_channels = channels[4] * 4 * 4
self.final_linear = nn.Sequential(
EqualLinearActModule(
linear_in_channels,
channels[4],
act_cfg=dict(type='fused_bias')),
EqualLinearActModule(channels[4], 1),
)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Input image tensor.
Returns:
torch.Tensor: Predict score for the input image.
"""
x = self.convs(x)
x = self.mbstd_layer(x)
x = self.final_conv(x)
if self.with_adaptive_pool:
x = self.adaptive_pool(x)
x = x.view(x.shape[0], -1)
x = self.final_linear(x)
return x
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..common import get_module_device
@torch.no_grad()
def get_mean_latent(generator, num_samples=4096, bs_per_repeat=1024):
"""Get mean latent of W space in Style-based GANs.
Args:
generator (nn.Module): Generator of a Style-based GAN.
num_samples (int, optional): Number of sample times. Defaults to 4096.
bs_per_repeat (int, optional): Batch size of noises per sample.
Defaults to 1024.
Returns:
Tensor: Mean latent of this generator.
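
    Example:
        A minimal usage sketch; ``generator`` can be any style-based
        generator exposing ``style_mapping`` and ``style_channels`` (here a
        generator with ``style_channels=512`` is assumed), and
        ``num_samples`` should be divisible by ``bs_per_repeat``:

        >>> mean_w = get_mean_latent(generator, num_samples=2048,
        ...                          bs_per_repeat=512)
        >>> mean_w.shape
        torch.Size([1, 512])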
"""
device = get_module_device(generator)
mean_style = None
n_repeat = num_samples // bs_per_repeat
assert n_repeat * bs_per_repeat == num_samples
for _ in range(n_repeat):
style = generator.style_mapping(
torch.randn(bs_per_repeat,
generator.style_channels).to(device)).mean(
0, keepdim=True)
if mean_style is None:
mean_style = style
else:
mean_style += style
mean_style /= float(n_repeat)
return mean_style
@torch.no_grad()
def style_mixing(generator,
n_source,
n_target,
inject_index=1,
truncation_latent=None,
truncation=0.7,
style_channels=512,
**kwargs):
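    """Generate an image grid by mixing styles from source and target codes.
    Args:
        generator (nn.Module): Generator of a Style-based GAN.
        n_source (int): Number of source (column) images.
        n_target (int): Number of target (row) images.
        inject_index (int, optional): Index of the latent at which the
            second style code is injected. Defaults to 1.
        truncation_latent (torch.Tensor | None, optional): Mean latent used
            for the truncation trick. Defaults to None.
        truncation (float, optional): Truncation factor. Defaults to 0.7.
        style_channels (int, optional): The number of channels for style
            code. Defaults to 512.
    Returns:
        torch.Tensor: Image grid of shape
            ((n_source + 1) * (n_target + 1), 3, H, W).
    """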
device = get_module_device(generator)
source_code = torch.randn(n_source, style_channels).to(device)
target_code = torch.randn(n_target, style_channels).to(device)
source_image = generator(
source_code,
truncation_latent=truncation_latent,
truncation=truncation,
**kwargs)
h, w = source_image.shape[-2:]
images = [torch.ones(1, 3, h, w).to(device) * -1]
target_image = generator(
target_code,
truncation_latent=truncation_latent,
truncation=truncation,
**kwargs)
images.append(source_image)
for i in range(n_target):
image = generator(
[target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
truncation_latent=truncation_latent,
truncation=truncation,
inject_index=inject_index,
**kwargs)
images.append(target_image[i].unsqueeze(0))
images.append(image)
images = torch.cat(images, 0)
return images
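# Hedged usage sketch (illustration only): both helpers expect a style-based
# generator from this repo, i.e. a module exposing ``style_mapping`` and
# ``style_channels``. Assuming ``generator`` is such an already-built module:
#
#     mean_w = get_mean_latent(generator, num_samples=4096)
#     grid = style_mixing(generator, n_source=4, n_target=3, inject_index=1,
#                         truncation_latent=mean_w, truncation=0.7,
#                         style_channels=generator.style_channels)
#     # ``grid`` stacks (n_target + 1) * (n_source + 1) images; the top-left
#     # cell is a constant -1 (black) placeholder.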
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import WGANGPDiscriminator, WGANGPGenerator
__all__ = ['WGANGPDiscriminator', 'WGANGPGenerator']
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.upsample import build_upsample_layer
from mmgen.models.builder import MODULES
from ..common import get_module_device
from .modules import ConvLNModule, WGANDecisionHead, WGANNoiseTo2DFeat
@MODULES.register_module()
class WGANGPGenerator(nn.Module):
r"""Generator for WGANGP.
    Implementation details for the WGANGP generator are the same as the
    training configuration (a) described in the PGGAN paper:
PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf # noqa
#. Adopt convolution architecture specified in appendix A.2;
#. Use batchnorm in the generator except for the final output layer;
#. Use ReLU in the generator except for the final output layer;
#. Use Tanh in the last layer;
#. Initialize all weights using He’s initializer.
Args:
noise_size (int): Size of the input noise vector.
out_scale (int): Output scale for the generated image.
conv_module_cfg (dict, optional): Config for the convolution
module used in this generator. Defaults to None.
upsample_cfg (dict, optional): Config for the upsampling operation.
Defaults to None.
"""
_default_channels_per_scale = {
'4': 512,
'8': 512,
'16': 256,
'32': 128,
'64': 64,
'128': 32
}
_default_conv_module_cfg = dict(
conv_cfg=None,
kernel_size=3,
stride=1,
padding=1,
bias=True,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
order=('conv', 'norm', 'act'))
_default_upsample_cfg = dict(type='nearest', scale_factor=2)
def __init__(self,
noise_size,
out_scale,
conv_module_cfg=None,
upsample_cfg=None):
super().__init__()
# set initial params
self.noise_size = noise_size
self.out_scale = out_scale
self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
if conv_module_cfg is not None:
self.conv_module_cfg.update(conv_module_cfg)
self.upsample_cfg = upsample_cfg if upsample_cfg else deepcopy(
self._default_upsample_cfg)
# set noise2feat head
self.noise2feat = WGANNoiseTo2DFeat(
self.noise_size, self._default_channels_per_scale['4'])
# set conv_blocks
self.conv_blocks = nn.ModuleList()
self.conv_blocks.append(ConvModule(512, 512, **self.conv_module_cfg))
log2scale = int(np.log2(self.out_scale))
for i in range(3, log2scale + 1):
self.conv_blocks.append(
build_upsample_layer(self._default_upsample_cfg))
self.conv_blocks.append(
ConvModule(self._default_channels_per_scale[str(2**(i - 1))],
self._default_channels_per_scale[str(2**i)],
**self.conv_module_cfg))
self.conv_blocks.append(
ConvModule(self._default_channels_per_scale[str(2**i)],
self._default_channels_per_scale[str(2**i)],
**self.conv_module_cfg))
self.to_rgb = ConvModule(
self._default_channels_per_scale[str(self.out_scale)],
kernel_size=1,
out_channels=3,
act_cfg=dict(type='Tanh'))
def forward(self, noise, num_batches=0, return_noise=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size. Defaults to
0.
return_noise (bool, optional): If True, ``noise_batch`` will be
returned in a dict with ``fake_img``. Defaults to False.
        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img`` and ``noise_batch`` will be returned.
"""
# receive noise and conduct sanity check.
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
# noise vector to 2D feature
x = self.noise2feat(noise_batch)
for conv in self.conv_blocks:
x = conv(x)
out_img = self.to_rgb(x)
if return_noise:
output = dict(fake_img=out_img, noise_batch=noise_batch)
return output
return out_img
@MODULES.register_module()
class WGANGPDiscriminator(nn.Module):
r"""Discriminator for WGANGP.
    Implementation details for the WGANGP discriminator are the same as the
    training configuration (a) described in the PGGAN paper:
PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf # noqa
#. Adopt convolution architecture specified in appendix A.2;
#. Add layer normalization to all conv3x3 and conv4x4 layers;
#. Use LeakyReLU in the discriminator except for the final output layer;
#. Initialize all weights using He’s initializer.
Args:
in_channel (int): The channel number of the input image.
in_scale (int): The scale of the input image.
conv_module_cfg (dict, optional): Config for the convolution module
used in this discriminator. Defaults to None.
"""
_default_channels_per_scale = {
'4': 512,
'8': 512,
'16': 256,
'32': 128,
'64': 64,
'128': 32
}
_default_conv_module_cfg = dict(
conv_cfg=None,
kernel_size=3,
stride=1,
padding=1,
bias=True,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=dict(type='LN2d'),
order=('conv', 'norm', 'act'))
_default_upsample_cfg = dict(type='nearest', scale_factor=2)
def __init__(self, in_channel, in_scale, conv_module_cfg=None):
super().__init__()
# set initial params
self.in_channel = in_channel
self.in_scale = in_scale
self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
if conv_module_cfg is not None:
self.conv_module_cfg.update(conv_module_cfg)
# set from_rgb head
self.from_rgb = ConvModule(
3,
kernel_size=1,
out_channels=self._default_channels_per_scale[str(self.in_scale)],
act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
# set conv_blocks
self.conv_blocks = nn.ModuleList()
log2scale = int(np.log2(self.in_scale))
for i in range(log2scale, 2, -1):
self.conv_blocks.append(
ConvLNModule(
self._default_channels_per_scale[str(2**i)],
self._default_channels_per_scale[str(2**i)],
feature_shape=(self._default_channels_per_scale[str(2**i)],
2**i, 2**i),
**self.conv_module_cfg))
self.conv_blocks.append(
ConvLNModule(
self._default_channels_per_scale[str(2**i)],
self._default_channels_per_scale[str(2**(i - 1))],
feature_shape=(self._default_channels_per_scale[str(
2**(i - 1))], 2**i, 2**i),
**self.conv_module_cfg))
self.conv_blocks.append(nn.AvgPool2d(kernel_size=2, stride=2))
self.decision = WGANDecisionHead(
self._default_channels_per_scale['4'],
self._default_channels_per_scale['4'],
1,
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
norm_cfg=self.conv_module_cfg['norm_cfg'])
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): Fake or real image tensor.
Returns:
torch.Tensor: Prediction for the reality of the input image.
"""
# noise vector to 2D feature
x = self.from_rgb(x)
for conv in self.conv_blocks:
x = conv(x)
x = self.decision(x)
return x
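if __name__ == '__main__':
    # Hedged usage sketch (illustration, not part of the original file):
    # pair the generator with the discriminator and compute the standard
    # WGAN-GP gradient penalty from the paper. This is a generic sketch of
    # the penalty term, not necessarily the loss module this repo uses.
    gen = WGANGPGenerator(noise_size=128, out_scale=128)
    disc = WGANGPDiscriminator(in_channel=3, in_scale=128)
    fake = gen(None, num_batches=2)  # (2, 3, 128, 128)
    real = torch.randn_like(fake)  # stand-in for a real image batch
    print(disc(fake).shape)  # expected: torch.Size([2, 1])
    # gradient penalty on random interpolates between real and fake samples
    eps = torch.rand(2, 1, 1, 1)
    interp = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    grad = torch.autograd.grad(
        disc(interp).sum(), interp, create_graph=True)[0]
    gp = ((grad.flatten(1).norm(2, dim=1) - 1)**2).mean()
    print(float(gp))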
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from mmcv.cnn import (PLUGIN_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init)
from mmgen.models.builder import MODULES
@MODULES.register_module()
class WGANNoiseTo2DFeat(nn.Module):
"""Module used in WGAN-GP to transform 1D noise tensor in order [N, C] to
2D shape feature tensor in order [N, C, H, W].
Args:
noise_size (int): Size of the input noise vector.
out_channels (int): The channel number of the output feature.
act_cfg (dict, optional): Config for the activation layer. Defaults to
dict(type='ReLU').
norm_cfg (dict, optional): Config dict to build norm layer. Defaults to
dict(type='BN').
        order (tuple, optional): The order of linear/norm/activation layers.
            It is a sequence of "linear", "norm" and "act". Defaults to
            ('linear', 'act', 'norm').
"""
def __init__(self,
noise_size,
out_channels,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
order=('linear', 'act', 'norm')):
super().__init__()
self.noise_size = noise_size
self.out_channels = out_channels
self.with_activation = act_cfg is not None
self.with_norm = norm_cfg is not None
self.order = order
assert len(order) == 3 and set(order) == set(['linear', 'act', 'norm'])
# w/o bias, because the bias is added after reshaping the tensor to
# 2D feature
self.linear = nn.Linear(noise_size, out_channels * 16, bias=False)
if self.with_activation:
self.activation = build_activation_layer(act_cfg)
# add bias for reshaped 2D feature.
self.register_parameter(
'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))
if self.with_norm:
_, self.norm = build_norm_layer(norm_cfg, out_channels)
self._init_weight()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input noise tensor with shape (n, c).
Returns:
Tensor: Forward results with shape (n, c, 4, 4).
"""
assert x.ndim == 2
for order in self.order:
if order == 'linear':
x = self.linear(x)
# [n, c, 4, 4]
x = torch.reshape(x, (-1, self.out_channels, 4, 4))
x = x + self.bias
elif order == 'act' and self.with_activation:
x = self.activation(x)
elif order == 'norm' and self.with_norm:
x = self.norm(x)
return x
def _init_weight(self):
"""Initialize weights for the model."""
nn.init.normal_(self.linear.weight, 0., 1.)
if self.bias is not None:
nn.init.constant_(self.bias, 0.)
if self.with_norm:
constant_init(self.norm, 1, bias=0)
class WGANDecisionHead(nn.Module):
"""Module used in WGAN-GP to get the final prediction result with 4x4
resolution input tensor in the bottom of the discriminator.
Args:
in_channels (int): Number of channels in input feature map.
mid_channels (int): Number of channels in feature map after
convolution.
out_channels (int): The channel number of the final output layer.
bias (bool, optional): Whether to use bias parameter. Defaults to True.
act_cfg (dict, optional): Config for the activation layer. Defaults to
dict(type='ReLU').
out_act (dict, optional): Config for the activation layer of output
layer. Defaults to None.
norm_cfg (dict, optional): Config dict to build norm layer. Defaults to
dict(type='LN2d').
"""
def __init__(self,
in_channels,
mid_channels,
out_channels,
bias=True,
act_cfg=dict(type='ReLU'),
out_act=None,
norm_cfg=dict(type='LN2d')):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.out_channels = out_channels
self.with_out_activation = out_act is not None
# setup conv layer
self.conv = ConvLNModule(
in_channels,
feature_shape=(mid_channels, 1, 1),
kernel_size=4,
out_channels=mid_channels,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
order=('conv', 'norm', 'act'))
# setup linear layer
self.linear = nn.Linear(
self.mid_channels, self.out_channels, bias=bias)
if self.with_out_activation:
self.out_activation = build_activation_layer(out_act)
self._init_weight()
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
x = self.conv(x)
x = torch.reshape(x, (x.shape[0], -1))
x = self.linear(x)
if self.with_out_activation:
x = self.out_activation(x)
return x
def _init_weight(self):
"""Initialize weights for the model."""
nn.init.normal_(self.linear.weight, 0., 1.)
nn.init.constant_(self.linear.bias, 0.)
@PLUGIN_LAYERS.register_module()
class ConvLNModule(ConvModule):
r"""ConvModule with Layer Normalization.
    In this module, we inherit the default ``mmcv.cnn.ConvModule`` and
    handle the case where ``norm_cfg`` is 'LN2d' or 'GN'. We adopt 'GN' as a
    replacement for layer normalization referring to:
https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch/blob/master/module.py # noqa
Args:
        feature_shape (tuple): The shape of the feature map on which the
            layer normalization will be applied.
"""
def __init__(self, *args, feature_shape=None, **kwargs):
if 'norm_cfg' in kwargs and kwargs['norm_cfg'] is not None and kwargs[
'norm_cfg']['type'] in ['LN2d', 'GN']:
nkwargs = deepcopy(kwargs)
nkwargs['norm_cfg'] = None
super().__init__(*args, **nkwargs)
self.with_norm = True
self.norm_name = kwargs['norm_cfg']['type']
if self.norm_name == 'LN2d':
norm = nn.LayerNorm(feature_shape)
self.add_module(self.norm_name, norm)
else:
norm = nn.GroupNorm(1, feature_shape[0])
self.add_module(self.norm_name, norm)
else:
super().__init__(*args, **kwargs)
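if __name__ == '__main__':
    # Hedged usage sketch (illustration, not part of the original file):
    # the noise head maps (N, noise_size) to (N, out_channels, 4, 4), and
    # ConvLNModule swaps the norm layer for ``nn.LayerNorm`` over the given
    # ``feature_shape`` when ``norm_cfg`` is 'LN2d'.
    head = WGANNoiseTo2DFeat(128, 512)
    feat = head(torch.randn(2, 128))
    print(feat.shape)  # expected: torch.Size([2, 512, 4, 4])
    conv = ConvLNModule(
        512,
        out_channels=512,
        kernel_size=3,
        padding=1,
        feature_shape=(512, 4, 4),
        norm_cfg=dict(type='LN2d'),
        act_cfg=dict(type='ReLU'))
    print(conv(feat).shape)  # expected: torch.Size([2, 512, 4, 4])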
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
MODELS = Registry('model')
MODULES = Registry('module')
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
        cfg (dict, list[dict]): The config of modules; it is either a dict
            or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.ModuleList(modules)
return build_from_cfg(cfg, registry, default_args)
def build_model(cfg, train_cfg=None, test_cfg=None):
"""Build model (GAN)."""
return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_module(cfg, default_args=None):
"""Build a module or modules from a list."""
return build(cfg, MODULES, default_args)
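if __name__ == '__main__':
    # Hedged usage sketch (illustration, not part of the original file):
    # any class decorated with ``@MODULES.register_module()`` can be built
    # from a plain config dict, assuming the module defining it has been
    # imported so that the registration has run.
    cfg = dict(type='WGANGPGenerator', noise_size=128, out_scale=64)
    gen = build_module(cfg)
    print(type(gen).__name__)  # expected: WGANGPGenerator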
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import AllGatherLayer
from .model_utils import GANImageBuffer, set_requires_grad
__all__ = ['set_requires_grad', 'AllGatherLayer', 'GANImageBuffer']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.distributed as dist
class AllGatherLayer(autograd.Function):
"""All gather layer with backward propagation path.
    This module makes ``dist.all_gather()`` participate in the backward
    graph, i.e., it gathers tensors from all processes while still allowing
    gradients to flow back. Such an operation is widely used in MoCo and
    other contrastive learning algorithms.
"""
@staticmethod
def forward(ctx, x):
"""Forward function."""
ctx.save_for_backward(x)
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grad_outputs):
"""Backward function."""
x, = ctx.saved_tensors
        # each rank only needs the gradient w.r.t. its own input slice
        grad_out = grad_outputs[dist.get_rank()]
return grad_out
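if __name__ == '__main__':
    # Hedged smoke test (illustration, not part of the original file):
    # AllGatherLayer needs an initialized process group; a single-process
    # 'gloo' group is enough to check that gradients flow back through the
    # gather.
    import os
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(4, 8, requires_grad=True)
    gathered = AllGatherLayer.apply(x)  # tuple with world_size tensors
    sum(t.sum() for t in gathered).backward()
    print(x.grad.shape)  # expected: torch.Size([4, 8])
    dist.destroy_process_group()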
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
def set_requires_grad(nets, requires_grad=False):
"""Set requires_grad for all the networks.
Args:
nets (nn.Module | list[nn.Module]): A list of networks or a single
network.
        requires_grad (bool): Whether the networks require gradients or not.
            Defaults to False.
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
class GANImageBuffer:
"""This class implements an image buffer that stores previously generated
images.
This buffer allows us to update the discriminator using a history of
generated images rather than the ones produced by the latest generator
to reduce model oscillation.
Args:
buffer_size (int): The size of image buffer. If buffer_size = 0,
no buffer will be created.
        buffer_ratio (float): The probability of returning an image
            previously stored in the buffer. Defaults to 0.5.
"""
def __init__(self, buffer_size, buffer_ratio=0.5):
self.buffer_size = buffer_size
# create an empty buffer
if self.buffer_size > 0:
self.img_num = 0
self.image_buffer = []
self.buffer_ratio = buffer_ratio
def query(self, images):
"""Query current image batch using a history of generated images.
Args:
images (Tensor): Current image batch without history information.
"""
if self.buffer_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
# if the buffer is not full, keep inserting current images
if self.img_num < self.buffer_size:
self.img_num = self.img_num + 1
self.image_buffer.append(image)
return_images.append(image)
else:
use_buffer = np.random.random() < self.buffer_ratio
                # with probability ``buffer_ratio``, return a previously
                # stored image and insert the current image into the buffer
if use_buffer:
random_id = np.random.randint(0, self.buffer_size)
image_tmp = self.image_buffer[random_id].clone()
self.image_buffer[random_id] = image
return_images.append(image_tmp)
                # with probability (1 - ``buffer_ratio``), return the
                # current image
else:
return_images.append(image)
# collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images
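if __name__ == '__main__':
    # Hedged usage sketch (illustration, not part of the original file):
    # once the buffer has filled up, ``query`` returns a batch where each
    # image is either the current fake or, with probability
    # ``buffer_ratio``, an older one swapped out of the buffer.
    buffer = GANImageBuffer(buffer_size=8, buffer_ratio=0.5)
    fakes = torch.randn(4, 3, 32, 32)
    mixed = buffer.query(fakes)
    print(mixed.shape)  # expected: torch.Size([4, 3, 32, 32])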
# Copyright (c) OpenMMLab. All rights reserved.
from .base_diffusion import BasicGaussianDiffusion
from .sampler import UniformTimeStepSampler
__all__ = ['BasicGaussianDiffusion', 'UniformTimeStepSampler']
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from abc import ABCMeta
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..architectures.common import get_module_device
from ..builder import MODELS, build_module
from .utils import _get_label_batch, _get_noise_batch, var_to_tensor
@MODELS.register_module()
class BasicGaussianDiffusion(nn.Module, metaclass=ABCMeta):
"""Basic module for gaussian Diffusion Denoising Probabilistic Models. A
diffusion probabilistic model (which we will call a 'diffusion model' for
brevity) is a parameterized Markov chain trained using variational
inference to produce samples matching the data after finite time.
The design of this module implements DDPM and improve-DDPM according to
"Denoising Diffusion Probabilistic Models" (2020) and "Improved Denoising
Diffusion Probabilistic Models" (2021).
Args:
denoising (dict): Config for denoising model.
ddpm_loss (dict): Config for losses of DDPM.
betas_cfg (dict): Config for betas in diffusion process.
num_timesteps (int, optional): The number of timesteps of the diffusion
process. Defaults to 1000.
        num_classes (int, optional): The number of conditional classes.
            Defaults to 0.
        sample_method (str, optional): Sample method for the denoising
            process. Supports 'DDPM' and 'DDIM'. Defaults to 'DDPM'.
        timestep_sampler (str, optional): How to sample timesteps in the
            training process. Defaults to 'UniformTimeStepSampler'.
train_cfg (dict | None, optional): Config for training schedule.
Defaults to None.
test_cfg (dict | None, optional): Config for testing schedule. Defaults
to None.
"""
def __init__(self,
denoising,
ddpm_loss,
betas_cfg,
num_timesteps=1000,
num_classes=0,
sample_method='DDPM',
timestep_sampler='UniformTimeStepSampler',
train_cfg=None,
test_cfg=None):
super().__init__()
self.fp16_enable = False
# build denoising module in this function
self.num_classes = num_classes
self.num_timesteps = num_timesteps
self.sample_method = sample_method
self._denoising_cfg = deepcopy(denoising)
self.denoising = build_module(
denoising,
default_args=dict(
num_classes=num_classes, num_timesteps=num_timesteps))
# get output-related configs from denoising
self.denoising_var_mode = self.denoising.var_mode
self.denoising_mean_mode = self.denoising.mean_mode
        # output_channels in denoising may be doubled, therefore we
        # get the number of image channels from the config
image_channels = self._denoising_cfg['in_channels']
# image_size should be the attribute of denoising network
image_size = self.denoising.image_size
image_shape = torch.Size([image_channels, image_size, image_size])
self.image_shape = image_shape
self.get_noise = partial(
_get_noise_batch,
image_shape=image_shape,
num_timesteps=self.num_timesteps)
self.get_label = partial(
_get_label_batch, num_timesteps=self.num_timesteps)
# build sampler
if timestep_sampler is not None:
self.sampler = build_module(
timestep_sampler,
default_args=dict(num_timesteps=num_timesteps))
else:
self.sampler = None
# build losses
if ddpm_loss is not None:
self.ddpm_loss = build_module(
ddpm_loss, default_args=dict(sampler=self.sampler))
if not isinstance(self.ddpm_loss, nn.ModuleList):
self.ddpm_loss = nn.ModuleList([self.ddpm_loss])
else:
self.ddpm_loss = None
self.betas_cfg = deepcopy(betas_cfg)
self.train_cfg = deepcopy(train_cfg) if train_cfg else None
self.test_cfg = deepcopy(test_cfg) if test_cfg else None
self._parse_train_cfg()
if test_cfg is not None:
self._parse_test_cfg()
self.prepare_diffusion_vars()
def _parse_train_cfg(self):
"""Parsing train config and set some attributes for training."""
if self.train_cfg is None:
self.train_cfg = dict()
self.use_ema = self.train_cfg.get('use_ema', False)
if self.use_ema:
self.denoising_ema = deepcopy(self.denoising)
self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')
def _parse_test_cfg(self):
"""Parsing test config and set some attributes for testing."""
if self.test_cfg is None:
self.test_cfg = dict()
# whether to use exponential moving average for testing
self.use_ema = self.test_cfg.get('use_ema', False)
if self.use_ema:
self.denoising_ema = deepcopy(self.denoising)
def _get_loss(self, outputs_dict):
losses_dict = {}
# forward losses
for loss_fn in self.ddpm_loss:
losses_dict[loss_fn.loss_name()] = loss_fn(outputs_dict)
loss, log_vars = self._parse_losses(losses_dict)
# update collected log_var from loss_fn
for loss_fn in self.ddpm_loss:
if hasattr(loss_fn, 'log_vars'):
log_vars.update(loss_fn.log_vars)
return loss, log_vars
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
which may be a weighted sum of all losses, log_vars contains \
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensor')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def train_step(self,
data,
optimizer,
ddp_reducer=None,
loss_scaler=None,
use_apex_amp=False,
running_status=None):
"""The iteration step during training.
        This method defines an iteration step during training. Different
        from other repos in the **MM** series, we allow the back propagation
        and optimizer updating to directly follow the iterative training
        schedule of DDPMs.
        Of course, you can also move the back propagation outside of this
        method and optimize the parameters in the optimizer hook. But this
        will cause extra GPU memory cost as a result of retaining the
        computational graph. Otherwise, the training schedule should be
        modified in the detailed implementation.
        Args:
            data (dict): Input data from the data loader.
            optimizer (dict): Dict contains optimizer for denoising network.
            ddp_reducer (:obj:`Reducer` | None, optional): The reducer of the
                DDP model, used to prepare for backward in DDP training.
                Defaults to None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                Loss scaler for fp16 training. Defaults to None.
            use_apex_amp (bool, optional): Whether to use apex.amp to scale
                the loss. Defaults to False.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.
        """
# get running status
if running_status is not None:
curr_iter = running_status['iteration']
else:
            # dirty workaround for when running status is not provided
if not hasattr(self, 'iteration'):
self.iteration = 0
curr_iter = self.iteration
real_imgs = data[self.real_img_key]
# denoising training
optimizer['denoising'].zero_grad()
denoising_dict_ = self.reconstruction_step(
data,
timesteps=self.sampler,
sample_model='orig',
return_noise=True)
denoising_dict_['iteration'] = curr_iter
denoising_dict_['real_imgs'] = real_imgs
denoising_dict_['loss_scaler'] = loss_scaler
loss, log_vars = self._get_loss(denoising_dict_)
# prepare for backward in ddp. If you do not call this function before
# back propagation, the ddp will not dynamically find the used params
# in current computation.
if ddp_reducer is not None:
ddp_reducer.prepare_for_backward(_find_tensors(loss))
if loss_scaler:
# add support for fp16
loss_scaler.scale(loss).backward()
elif use_apex_amp:
from apex import amp
with amp.scale_loss(
loss, optimizer['denoising'],
loss_id=0) as scaled_loss_disc:
scaled_loss_disc.backward()
else:
loss.backward()
if loss_scaler:
loss_scaler.unscale_(optimizer['denoising'])
# note that we do not contain clip_grad procedure
loss_scaler.step(optimizer['denoising'])
# loss_scaler.update will be called in runner.train()
else:
optimizer['denoising'].step()
        # images used for visualization
results = dict(
real_imgs=real_imgs,
x_0_pred=denoising_dict_['x_0_pred'],
x_t=denoising_dict_['diffusion_batches'],
x_t_1=denoising_dict_['fake_img'])
outputs = dict(
log_vars=log_vars, num_samples=real_imgs.shape[0], results=results)
if hasattr(self, 'iteration'):
self.iteration += 1
return outputs
def reconstruction_step(self,
data_batch,
noise=None,
label=None,
timesteps=None,
sample_model='orig',
return_noise=False,
**kwargs):
"""Reconstruction step at corresponding `timestep`. To be noted that,
denoisint target ``x_t`` for each timestep are all generated from real
images, but not the denoising result from denoising network.
``sample_from_noise`` focus on generate samples start from **random
(or given) noise**. Therefore, we design this function to realize a
reconstruction process for the given images.
If `timestep` is None, automatically perform reconstruction at all
timesteps.
Args:
data_batch (dict): Input data from dataloader.
noise (torch.Tensor | callable | None): Noise used in diffusion
process. You can directly give a batch of noise through a
``torch.Tensor`` or offer a callable function to sample a
batch of noise data. Otherwise, the ``None`` indicates to use
the default noise sampler. Defaults to None.
label (torch.Tensor | None , optional): The conditional label of
the input image. Defaults to None.
            timesteps (int | list | torch.Tensor | callable | None): Target
                timesteps to perform reconstruction. Defaults to None.
            sample_model (str, optional): Which model to use to sample fake
                images. Defaults to `'orig'`.
return_noise (bool, optional): If True,``noise_batch``, ``label``
and all other intermedia variables will be returned together
with ``fake_img`` in a dict. Defaults to False.
Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with required
                data, including generated images, will be returned.
"""
assert sample_model in [
'orig', 'ema'
], ('We only support \'orig\' and \'ema\' for '
f'\'reconstruction_step\', but receive \'{sample_model}\'.')
denoising_model = self.denoising if sample_model == 'orig' \
else self.denoising_ema
# 0. prepare for timestep, noise and label
device = get_module_device(self)
real_imgs = data_batch[self.real_img_key]
num_batches = real_imgs.shape[0]
if timesteps is None:
# default to performing the whole reconstruction process
timesteps = torch.LongTensor([
t for t in range(self.num_timesteps)
]).view(self.num_timesteps, 1)
timesteps = timesteps.repeat([1, num_batches])
if isinstance(timesteps, (int, list)):
timesteps = torch.LongTensor(timesteps)
elif callable(timesteps):
timestep_generator = timesteps
timesteps = timestep_generator(num_batches)
else:
            assert isinstance(timesteps, torch.Tensor), (
                'We only support int, list, torch.Tensor or callable for '
                f'timesteps, but received {type(timesteps)}.')
if timesteps.ndim == 1:
timesteps = timesteps.unsqueeze(0)
timesteps = timesteps.to(get_module_device(self))
if noise is not None:
assert 'noise' not in data_batch, (
'Receive \'noise\' in both data_batch and passed arguments.')
if noise is None:
noise = data_batch['noise'] if 'noise' in data_batch else None
if self.num_classes > 0:
if label is not None:
assert 'label' not in data_batch, (
'Receive \'label\' in both data_batch '
'and passed arguments.')
if label is None:
label = data_batch['label'] if 'label' in data_batch else None
label_batches = self.get_label(
label, num_batches=num_batches).to(device)
else:
label_batches = None
output_dict = defaultdict(list)
# loop all timesteps
for timestep in timesteps:
# 1. get diffusion results and parameters
noise_batches = self.get_noise(
noise, num_batches=num_batches).to(device)
diffusion_batches = self.q_sample(real_imgs, timestep,
noise_batches)
# 2. get denoising results.
denoising_batches = self.denoising_step(
denoising_model,
diffusion_batches,
timestep,
label=label_batches,
return_noise=return_noise,
clip_denoised=not self.training)
# 3. get ground truth by q_posterior
target_batches = self.q_posterior_mean_variance(
real_imgs, diffusion_batches, timestep, logvar=True)
if return_noise:
output_dict_ = dict(
timesteps=timestep,
noise=noise_batches,
diffusion_batches=diffusion_batches)
if self.num_classes > 0:
output_dict_['label'] = label_batches
output_dict_.update(denoising_batches)
output_dict_.update(target_batches)
else:
output_dict_ = dict(fake_img=denoising_batches)
# update output of `timestep` to output_dict
for k, v in output_dict_.items():
if k in output_dict:
output_dict[k].append(v)
else:
output_dict[k] = [v]
        # 4. concatenate lists into tensors
for k, v in output_dict.items():
output_dict[k] = torch.cat(v, dim=0)
# 5. return results
if return_noise:
return output_dict
return output_dict['fake_img']
def sample_from_noise(self,
noise,
num_batches=0,
sample_model='ema/orig',
label=None,
**kwargs):
"""Sample images from noises by using Denoising model.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
sample_model (str, optional): The model to sample. If ``ema/orig``
is passed, this method will try to sample from ema (if
``self.use_ema == True``) and orig model. Defaults to
'ema/orig'.
label (torch.Tensor | None , optional): The conditional label.
Defaults to None.
Returns:
torch.Tensor | dict: The output may be the direct synthesized
images in ``torch.Tensor``. Otherwise, a dict with queried
data, including generated images, will be returned.
"""
# get sample function by name
sample_fn_name = f'{self.sample_method.upper()}_sample'
if not hasattr(self, sample_fn_name):
raise AttributeError(
f'Cannot find sample method [{sample_fn_name}] correspond '
f'to [{self.sample_method}].')
sample_fn = getattr(self, sample_fn_name)
if sample_model == 'ema':
assert self.use_ema
_model = self.denoising_ema
elif sample_model == 'ema/orig' and self.use_ema:
_model = self.denoising_ema
else:
_model = self.denoising
outputs = sample_fn(
_model,
noise=noise,
num_batches=num_batches,
label=label,
**kwargs)
if isinstance(outputs, dict) and 'noise_batch' in outputs:
# return_noise is True
noise = outputs['x_t']
label = outputs['label']
kwargs['timesteps_noise'] = outputs['noise_batch']
fake_img = outputs['fake_img']
else:
fake_img = outputs
if sample_model == 'ema/orig' and self.use_ema:
_model = self.denoising
outputs_ = sample_fn(
_model, noise=noise, num_batches=num_batches, **kwargs)
if isinstance(outputs_, dict) and 'noise_batch' in outputs_:
# return_noise is True
fake_img_ = outputs_['fake_img']
else:
fake_img_ = outputs_
if isinstance(fake_img, dict):
# save_intermedia is True
fake_img = {
k: torch.cat([fake_img[k], fake_img_[k]], dim=0)
for k in fake_img.keys()
}
else:
fake_img = torch.cat([fake_img, fake_img_], dim=0)
return fake_img
@torch.no_grad()
def DDPM_sample(self,
model,
noise=None,
num_batches=0,
label=None,
save_intermedia=False,
timesteps_noise=None,
return_noise=False,
show_pbar=False,
**kwargs):
"""DDPM sample from random noise.
Args:
model (torch.nn.Module): Denoising model used to sample images.
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
label (torch.Tensor | None , optional): The conditional label.
Defaults to None.
            save_intermedia (bool, optional): Whether to save the denoising
                results of intermediate timesteps. If set as True, a dict
                whose keys and values are denoising timesteps and denoising
                results will be returned. Otherwise, only the final denoising
                result will be returned. Defaults to False.
            timesteps_noise (torch.Tensor, optional): Noise term used in each
                denoising timestep. If given, the input noise will be shaped
                to [num_timesteps, b, c, h, w]. If set as None, noise of each
                denoising timestep will be randomly sampled. Defaults to
                None.
            return_noise (bool, optional): If True, a dict contains
                ``noise_batch``, ``x_t`` and ``label`` will be returned
                together with the denoising results, and the key of denoising
                results is ``fake_img``. To be noted that ``noise_batch``
                will be shaped as [num_timesteps, b, c, h, w]. Defaults to
                False.
show_pbar (bool, optional): If True, a progress bar will be
displayed. Defaults to False.
Returns:
            torch.Tensor | dict: If ``save_intermedia``, a dict containing
                the denoising results of each timestep will be returned.
                Otherwise, only the final denoising result will be returned.
"""
device = get_module_device(self)
noise = self.get_noise(noise, num_batches=num_batches).to(device)
x_t = noise.clone()
if save_intermedia:
# save input
intermedia = {self.num_timesteps: x_t.clone()}
# use timesteps noise if defined
if timesteps_noise is not None:
timesteps_noise = self.get_noise(
timesteps_noise, num_batches=num_batches,
timesteps_noise=True).to(device)
batched_timesteps = torch.arange(self.num_timesteps - 1, -1,
-1).long().to(device)
if show_pbar:
pbar = mmcv.ProgressBar(self.num_timesteps)
for t in batched_timesteps:
batched_t = t.expand(x_t.shape[0])
step_noise = timesteps_noise[t, ...] \
if timesteps_noise is not None else None
x_t = self.denoising_step(
model, x_t, batched_t, noise=step_noise, label=label, **kwargs)
if save_intermedia:
intermedia[int(t)] = x_t.cpu().clone()
if show_pbar:
pbar.update()
denoising_results = intermedia if save_intermedia else x_t
if show_pbar:
sys.stdout.write('\n')
if return_noise:
return dict(
noise_batch=timesteps_noise,
x_t=noise,
label=label,
fake_img=denoising_results)
return denoising_results
def prepare_diffusion_vars(self):
"""Prepare for variables used in the diffusion process."""
self.betas = self.get_betas()
self.alphas = 1.0 - self.betas
        self.alphas_bar = np.cumprod(self.alphas, axis=0)
self.alphas_bar_prev = np.append(1.0, self.alphas_bar[:-1])
self.alphas_bar_next = np.append(self.alphas_bar[1:], 0.0)
# calculations for diffusion q(x_t | x_0) and others
self.sqrt_alphas_bar = np.sqrt(self.alphas_bar)
self.sqrt_one_minus_alphas_bar = np.sqrt(1.0 - self.alphas_bar)
self.log_one_minus_alphas_bar = np.log(1.0 - self.alphas_bar)
        self.sqrt_recip_alphas_bar = np.sqrt(1.0 / self.alphas_bar)
self.sqrt_recipm1_alphas_bar = np.sqrt(1.0 / self.alphas_bar - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.tilde_betas_t = self.betas * (1 - self.alphas_bar_prev) / (
1 - self.alphas_bar)
        # clip the log var since tilde_betas_t equals 0 at t = 0
self.log_tilde_betas_t_clipped = np.log(
np.append(self.tilde_betas_t[1], self.tilde_betas_t[1:]))
self.tilde_mu_t_coef1 = np.sqrt(
self.alphas_bar_prev) / (1 - self.alphas_bar) * self.betas
self.tilde_mu_t_coef2 = np.sqrt(
self.alphas) * (1 - self.alphas_bar_prev) / (1 - self.alphas_bar)
def get_betas(self):
"""Get betas by defined schedule method in diffusion process."""
self.betas_schedule = self.betas_cfg.pop('type')
if self.betas_schedule == 'linear':
return self.linear_beta_schedule(self.num_timesteps,
**self.betas_cfg)
elif self.betas_schedule == 'cosine':
return self.cosine_beta_schedule(self.num_timesteps,
**self.betas_cfg)
else:
            raise AttributeError(
                f'Unknown method name {self.betas_schedule} '
                'for beta schedule.')
@staticmethod
def linear_beta_schedule(diffusion_timesteps, beta_0=1e-4, beta_T=2e-2):
r"""Linear schedule from Ho et al, extended to work for any number of
diffusion steps.
Args:
diffusion_timesteps (int): The number of betas to produce.
beta_0 (float, optional): `\beta` at timestep 0. Defaults to 1e-4.
beta_T (float, optional): `\beta` at timestep `T` (the final
diffusion timestep). Defaults to 2e-2.
Returns:
np.ndarray: Betas used in diffusion process.
"""
scale = 1000 / diffusion_timesteps
beta_0 = scale * beta_0
beta_T = scale * beta_T
return np.linspace(
beta_0, beta_T, diffusion_timesteps, dtype=np.float64)
@staticmethod
def cosine_beta_schedule(diffusion_timesteps, max_beta=0.999, s=0.008):
r"""Create a beta schedule that discretizes the given alpha_t_bar
function, which defines the cumulative product of `(1-\beta)` over time
from `t = [0, 1]`.
Args:
diffusion_timesteps (int): The number of betas to produce.
max_beta (float, optional): The maximum beta to use; use values
lower than 1 to prevent singularities. Defaults to 0.999.
            s (float, optional): Small offset to prevent `\beta` from being
                too small near `t = 0`. Defaults to 0.008.
Returns:
np.ndarray: Betas used in diffusion process.
"""
def f(t, T, s):
return np.cos((t / T + s) / (1 + s) * np.pi / 2)**2
betas = []
for t in range(diffusion_timesteps):
alpha_bar_t = f(t + 1, diffusion_timesteps, s)
alpha_bar_t_1 = f(t, diffusion_timesteps, s)
betas_t = 1 - alpha_bar_t / alpha_bar_t_1
betas.append(min(betas_t, max_beta))
return np.array(betas)
def q_sample(self, x_0, t, noise=None):
r"""Get diffusion result at timestep `t` by `q(x_t | x_0)`.
Args:
x_0 (torch.Tensor): Original image without diffusion.
t (torch.Tensor): Target diffusion timestep.
            noise (torch.Tensor, optional): Noise used in the
                reparameterization trick. Defaults to None.
Returns:
torch.tensor: Diffused image `x_t`.
"""
device = get_module_device(self)
num_batches = x_0.shape[0]
tar_shape = x_0.shape
noise = self.get_noise(noise, num_batches=num_batches)
mean = var_to_tensor(self.sqrt_alphas_bar, t, tar_shape, device)
std = var_to_tensor(self.sqrt_one_minus_alphas_bar, t, tar_shape,
device)
return x_0 * mean + noise * std
def q_mean_log_variance(self, x_0, t):
r"""Get mean and log_variance of diffusion process `q(x_t | x_0)`.
Args:
x_0 (torch.tensor): The original image before diffusion, shape as
[bz, ch, H, W].
t (torch.tensor): Target timestep, shape as [bz, ].
Returns:
            Tuple(torch.tensor): Tuple containing the mean and log variance.
"""
device = get_module_device(self)
tar_shape = x_0.shape
mean = var_to_tensor(self.sqrt_alphas_bar, t, tar_shape, device) * x_0
logvar = var_to_tensor(self.log_one_minus_alphas_bar, t, tar_shape,
device)
return mean, logvar
def q_posterior_mean_variance(self,
x_0,
x_t,
t,
need_var=True,
logvar=False):
r"""Get mean and variance of diffusion posterior
`q(x_{t-1} | x_t, x_0)`.
Args:
            x_0 (torch.tensor): The original image before diffusion, shape as
                [bz, ch, H, W].
            x_t (torch.tensor): Diffused image at timestep `t`, shape as
                [bz, ch, H, W].
            t (torch.tensor): Target timestep, shape as [bz, ].
            need_var (bool, optional): If set as ``True``, this function will
                return a dict containing ``var_posterior``. Otherwise, only
                the mean will be returned and ``logvar`` will be ignored.
                Defaults to True.
            logvar (bool, optional): If set as ``True``, the returned dict
                will additionally contain ``logvar_posterior``. This argument
                is only considered when ``need_var == True``. Defaults to
                False.
        Returns:
            torch.Tensor | dict: If ``need_var``, a dict containing
                ``mean_posterior`` and ``var_posterior`` will be returned.
                Otherwise, only the mean will be returned. If ``need_var``
                and ``logvar`` are set as True simultaneously, the returned
                dict will additionally contain ``logvar_posterior``.
"""
device = get_module_device(self)
tar_shape = x_0.shape
tilde_mu_t_coef1 = var_to_tensor(self.tilde_mu_t_coef1, t, tar_shape,
device)
tilde_mu_t_coef2 = var_to_tensor(self.tilde_mu_t_coef2, t, tar_shape,
device)
posterior_mean = tilde_mu_t_coef1 * x_0 + tilde_mu_t_coef2 * x_t
# do not need variance, just return mean
if not need_var:
return posterior_mean
posterior_var = var_to_tensor(self.tilde_betas_t, t, tar_shape, device)
out_dict = dict(
mean_posterior=posterior_mean, var_posterior=posterior_var)
if logvar:
posterior_logvar = var_to_tensor(self.log_tilde_betas_t_clipped, t,
tar_shape, device)
out_dict['logvar_posterior'] = posterior_logvar
return out_dict
def p_mean_variance(self,
denoising_output,
x_t,
t,
clip_denoised=True,
denoised_fn=None):
r"""Get mean, variance, log variance of denoising process
`p(x_{t-1} | x_{t})` and predicted `x_0`.
Args:
denoising_output (dict[torch.Tensor]): The output from denoising
model.
x_t (torch.Tensor): Diffused image at timestep `t` to denoising.
t (torch.Tensor): Current timestep.
            clip_denoised (bool, optional): Whether to clip the sample
                results into [-1, 1]. Defaults to True.
            denoised_fn (callable, optional): If not None, a function which
                applies to the predicted ``x_0`` before it is passed to the
                following sampling procedure. Note that this function is
                applied before ``clip_denoised``. Defaults to None.
Returns:
dict: A dict contains ``var_pred``, ``logvar_pred``, ``mean_pred``
and ``x_0_pred``.
"""
target_shape = x_t.shape
device = get_module_device(self)
# prepare for var and logvar
if self.denoising_var_mode.upper() == 'LEARNED':
# NOTE: the output actually LEARNED_LOG_VAR
logvar_pred = denoising_output['logvar']
varpred = torch.exp(logvar_pred)
elif self.denoising_var_mode.upper() == 'LEARNED_RANGE':
# NOTE: the output actually LEARNED_FACTOR
var_factor = denoising_output['factor']
lower_bound_logvar = var_to_tensor(self.log_tilde_betas_t_clipped,
t, target_shape, device)
upper_bound_logvar = var_to_tensor(
np.log(self.betas), t, target_shape, device)
logvar_pred = var_factor * upper_bound_logvar + (
1 - var_factor) * lower_bound_logvar
varpred = torch.exp(logvar_pred)
elif self.denoising_var_mode.upper() == 'FIXED_LARGE':
# use betas as var
varpred = var_to_tensor(
np.append(self.tilde_betas_t[1], self.betas), t, target_shape,
device)
logvar_pred = torch.log(varpred)
elif self.denoising_var_mode.upper() == 'FIXED_SMALL':
# use posterior (tilde_betas) as var
varpred = var_to_tensor(self.tilde_betas_t, t, target_shape,
device)
logvar_pred = var_to_tensor(self.log_tilde_betas_t_clipped, t,
target_shape, device)
else:
raise AttributeError('Unknown denoising var output type '
f'[{self.denoising_var_mode}].')
def process_x_0(x):
if denoised_fn is not None and callable(denoised_fn):
x = denoised_fn(x)
return x.clamp(-1, 1) if clip_denoised else x
# prepare for mean and x_0
if self.denoising_mean_mode.upper() == 'EPS':
eps_pred = denoising_output['eps_t_pred']
# We can get x_{t-1} with eps in two following approaches:
# 1. eps --(Eq 15)--> \hat{x_0} --(Eq 7)--> \tilde_mu --> x_{t-1}
# 2. eps --(Eq 11)--> \mu_{\theta} --(Eq 7)--> x_{t-1}
# We can verify \tilde_mu in method 1 and \mu_{\theta} in method 2
# are almost same (error of 1e-4) with the same eps input.
            # In our implementation, we use method (1) to be consistent with
            # the official ones.
# If you want to calculate \mu_{\theta} with method 2, you can
# use the following code:
# coef1 = var_to_tensor(
# np.sqrt(1.0 / self.alphas), t, tar_shape)
# coef2 = var_to_tensor(
# self.betas / self.sqrt_one_minus_alphas_bar, t, tar_shape)
# mu_theta = coef1 * (x_t - coef2 * eps)
x_0_pred = process_x_0(self.pred_x_0_from_eps(eps_pred, x_t, t))
mean_pred = self.q_posterior_mean_variance(
x_0_pred, x_t, t, need_var=False)
elif self.denoising_mean_mode.upper() == 'START_X':
x_0_pred = process_x_0(denoising_output['x_0_pred'])
mean_pred = self.q_posterior_mean_variance(
x_0_pred, x_t, t, need_var=False)
elif self.denoising_mean_mode.upper() == 'PREVIOUS_X':
            # NOTE: the output is actually PREVIOUS_X_MEAN (MU_THETA),
            # because it actually predicts \mu_{\theta}
mean_pred = denoising_output['x_tm1_pred']
x_0_pred = process_x_0(self.pred_x_0_from_x_tm1(mean_pred, x_t, t))
else:
raise AttributeError('Unknown denoising mean output type '
f'[{self.denoising_mean_mode}].')
output_dict = dict(
var_pred=varpred,
logvar_pred=logvar_pred,
mean_pred=mean_pred,
x_0_pred=x_0_pred)
        # avoid returning duplicate variables
return {
k: output_dict[k]
for k in output_dict.keys() if k not in denoising_output
}
def denoising_step(self,
model,
x_t,
t,
noise=None,
label=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
return_noise=False):
"""Single denoising step. Get `x_{t-1}` from ``x_t`` and ``t``.
Args:
model (torch.nn.Module): Denoising model used to sample images.
x_t (torch.Tensor): Input diffused image.
t (torch.Tensor): Current timestep.
noise (torch.Tensor | callable | None): Noise for
reparameterization trick. You can directly give a batch of
noise through a ``torch.Tensor`` or offer a callable function
to sample a batch of noise data. Otherwise, the ``None``
indicates to use the default noise sampler.
label (torch.Tensor | callable | None): You can directly give a
batch of label through a ``torch.Tensor`` or offer a callable
function to sample a batch of label data. Otherwise, the
``None`` indicates to use the default label sampler.
            clip_denoised (bool, optional): Whether to clip sample results
                into [-1, 1]. Defaults to True.
denoised_fn (callable, optional): If not None, a function which
applies to the predicted ``x_0`` prediction before it is used
to sample. Applies before ``clip_denoised``. Defaults to None.
model_kwargs (dict, optional): Arguments passed to denoising model.
Defaults to None.
            return_noise (bool, optional): If True, ``noise_repar``, outputs
                from the denoising model and ``p_mean_variance`` will be
                returned in a dict together with ``fake_img``. Defaults to
                False.
        Return:
            torch.Tensor | dict: If not ``return_noise``, only the denoised
                image will be returned. Otherwise, a dict containing
                ``fake_img``, ``noise_repar`` and the outputs from the
                denoising model and ``p_mean_variance`` will be returned.
"""
# init model_kwargs as dict if not passed
if model_kwargs is None:
model_kwargs = dict()
model_kwargs.update(dict(return_noise=return_noise))
denoising_output = model(x_t, t, label=label, **model_kwargs)
p_output = self.p_mean_variance(denoising_output, x_t, t,
clip_denoised, denoised_fn)
mean_pred = p_output['mean_pred']
var_pred = p_output['var_pred']
num_batches = x_t.shape[0]
device = get_module_device(self)
# get noise for reparameterization
noise = self.get_noise(noise, num_batches=num_batches).to(device)
nonzero_mask = ((t != 0).float().view(-1,
*([1] * (len(x_t.shape) - 1))))
        # Here we directly use var_pred instead of logvar_pred;
        # the difference is only around 1e-12.
# logvar_pred = p_output['logvar_pred']
# sample = mean_pred + \
# nonzero_mask * torch.exp(0.5 * logvar_pred) * noise
sample = mean_pred + nonzero_mask * torch.sqrt(var_pred) * noise
if return_noise:
return dict(
fake_img=sample,
noise_repar=noise,
**denoising_output,
**p_output)
return sample
def pred_x_0_from_eps(self, eps, x_t, t):
r"""Predict x_0 from eps by Equ 15 in DDPM paper:
.. math::
x_0 = \frac{(x_t - \sqrt{(1-\bar{\alpha}_t)} * eps)}
{\sqrt{\bar{\alpha}_t}}
        Args:
            eps (torch.Tensor): Predicted noise `\epsilon_t`.
            x_t (torch.Tensor): Diffused image at timestep `t`.
            t (torch.Tensor): Current timestep.
Returns:
torch.tensor: Predicted ``x_0``.
"""
device = get_module_device(self)
tar_shape = x_t.shape
        coef1 = var_to_tensor(self.sqrt_recip_alphas_bar, t, tar_shape,
                              device)
coef2 = var_to_tensor(self.sqrt_recipm1_alphas_bar, t, tar_shape,
device)
return x_t * coef1 - eps * coef2
def pred_x_0_from_x_tm1(self, x_tm1, x_t, t):
r"""
Predict `x_0` from `x_{t-1}`. (actually from `\mu_{\theta}`).
`(\mu_{\theta} - coef2 * x_t) / coef1`, where `coef1` and `coef2`
are from Eq 6 of the DDPM paper.
NOTE: This function actually predict ``x_0`` from ``mu_theta`` (mean
of ``x_{t-1}``).
Args:
x_tm1 (torch.Tensor): `x_{t-1}` used to predict `x_0`.
x_t (torch.Tensor): `x_{t}` used to predict `x_0`.
t (torch.Tensor): Current timestep.
Returns:
torch.Tensor: Predicted `x_0`.
"""
device = get_module_device(self)
tar_shape = x_t.shape
coef1 = var_to_tensor(self.tilde_mu_t_coef1, t, tar_shape, device)
coef2 = var_to_tensor(self.tilde_mu_t_coef2, t, tar_shape, device)
x_0 = (x_tm1 - coef2 * x_t) / coef1
return x_0
def forward_train(self, data, **kwargs):
"""Deprecated forward function in training."""
        raise NotImplementedError(
            'In MMGeneration, we do NOT recommend users to call '
            'this function, because the train_step function is designed for '
            'the training process.')
def forward_test(self, data, **kwargs):
"""Testing function for Diffusion Denosing Probability Models.
Args:
data (torch.Tensor | dict | None): Input data. This data will be
passed to different methods.
"""
mode = kwargs.pop('mode', 'sampling')
if mode == 'sampling':
return self.sample_from_noise(data, **kwargs)
elif mode == 'reconstruction':
            # this mode is designed for evaluating likelihood metrics
return self.reconstruction_step(data, **kwargs)
raise NotImplementedError('Other specific testing functions should'
' be implemented by the sub-classes.')
def forward(self, data, return_loss=False, **kwargs):
"""Forward function.
Args:
data (dict | torch.Tensor): Input data dictionary.
return_loss (bool, optional): Whether in training or testing.
Defaults to False.
Returns:
dict: Output dictionary.
"""
if return_loss:
return self.forward_train(data, **kwargs)
return self.forward_test(data, **kwargs)
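if __name__ == '__main__':
    # Hedged, self-contained sketch (illustration, not part of the original
    # file) of the diffusion math used above: build a linear beta schedule,
    # derive alpha_bar by cumulative product, and diffuse an image with the
    # reparameterization x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,
    # mirroring ``prepare_diffusion_vars`` and ``q_sample``.
    num_timesteps = 1000
    betas = np.linspace(1e-4, 2e-2, num_timesteps, dtype=np.float64)
    alphas_bar = np.cumprod(1.0 - betas)
    x_0 = torch.randn(2, 3, 32, 32)
    eps = torch.randn_like(x_0)
    t = 500
    x_t = float(np.sqrt(alphas_bar[t])) * x_0 \
        + float(np.sqrt(1.0 - alphas_bar[t])) * eps
    print(x_t.shape)  # expected: torch.Size([2, 3, 32, 32])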
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from ..builder import MODULES
@MODULES.register_module()
class UniformTimeStepSampler:
"""Timestep sampler for DDPM-based models. This sampler sample all
timesteps with the same probabilistic.
Args:
num_timesteps (int): Total timesteps of the diffusion process.
"""
def __init__(self, num_timesteps):
self.num_timesteps = num_timesteps
self.prob = [1 / self.num_timesteps for _ in range(self.num_timesteps)]
def sample(self, batch_size):
"""Sample timesteps.
Args:
batch_size (int): The desired batch size of the sampled timesteps.
Returns:
torch.Tensor: Sampled timesteps.
"""
# use numpy to make sure our implementation is consistent with the
# official ones.
return torch.from_numpy(
np.random.choice(
self.num_timesteps, size=(batch_size, ), p=self.prob)).long()
def __call__(self, batch_size):
"""Return sampled results."""
return self.sample(batch_size)
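if __name__ == '__main__':
    # Hedged usage sketch (illustration, not part of the original file).
    sampler = UniformTimeStepSampler(num_timesteps=1000)
    timesteps = sampler(4)
    print(timesteps.shape, timesteps.dtype)  # torch.Size([4]) torch.int64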
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def _get_noise_batch(noise,
image_shape,
num_timesteps=0,
num_batches=0,
timesteps_noise=False):
"""Get noise batch. Support get sequeue of noise along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If timesteps_noise is True, we output noise which dimension is 5.
- Input is [bz, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n*bz, c, h, w]: View to [n, bz, c, h, w]
- Dim of the input is 5: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate noise shape as [n, bz, c, h, w]
- Otherwise: Raise error
    If timesteps_noise is False, we output noise whose dimension is 4 and
    ignore ``num_timesteps``.
- Dim of the input is 3: Unsqueeze to [1, c, h, w], ignore ``num_batches``
- Dim of the input is 4: Return input, ignore ``num_batches``
- Callable or None: Generate noise shape as [bz, c, h, w]
- Otherwise: Raise error
    To be noted that we do not move the generated noise to the target device
    in this function, because we cannot know which device the noise should
    be moved to.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
image_shape (torch.Size): Size of images in the diffusion process.
        num_timesteps (int, optional): Total timesteps of the diffusion and
            denoising process. Defaults to 0.
        num_batches (int, optional): The number of batch size. To be noted
            that this argument only works when the input ``noise`` is
            callable or ``None``. Defaults to 0.
        timesteps_noise (bool, optional): If True, returned noise will be
            shaped as [n, bz, c, h, w], otherwise as [bz, c, h, w].
            Defaults to False.
Returns:
torch.Tensor: Generated noise with desired shape.
"""
if isinstance(noise, torch.Tensor):
# conduct sanity check for the last three dimension
assert noise.shape[-3:] == image_shape
if timesteps_noise:
if noise.ndim == 4:
assert num_batches > 0 and num_timesteps > 0
# noise shape as [n, c, h, w], expand to [n, bz, c, h, w]
if noise.shape[0] == num_timesteps:
noise_batch = noise.view(num_timesteps, 1, *image_shape)
noise_batch = noise_batch.expand(-1, num_batches, -1, -1,
-1)
# noise shape as [bz, c, h, w], expand to [n, bz, c, h, w]
elif noise.shape[0] == num_batches:
noise_batch = noise.view(1, num_batches, *image_shape)
noise_batch = noise_batch.expand(num_timesteps, -1, -1, -1,
-1)
                # noise shape as [n*bz, c, h, w], reshape to
                # [n, bz, c, h, w]
elif noise.shape[0] == num_timesteps * num_batches:
noise_batch = noise.view(num_timesteps, -1, *image_shape)
else:
raise ValueError(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f'(n, bz, c, h, w). But receive {noise.shape}.')
elif noise.ndim == 5:
# direct return noise
noise_batch = noise
else:
raise ValueError(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f'(n, bz, c, h, w). But receive {noise.shape}.')
else:
if noise.ndim == 3:
# reshape noise to [1, c, h, w]
noise_batch = noise[None, ...]
elif noise.ndim == 4:
# do nothing
noise_batch = noise
else:
raise ValueError(
'The noise should be in shape of (n, c, h, w) or'
f'(c, h, w), but got {noise.shape}')
# receive a noise generator and sample noise.
elif callable(noise):
assert num_batches > 0
noise_generator = noise
if timesteps_noise:
assert num_timesteps > 0
# generate noise shape as [n, bz, c, h, w]
noise_batch = noise_generator(
(num_timesteps, num_batches, *image_shape))
else:
# generate noise shape as [bz, c, h, w]
noise_batch = noise_generator((num_batches, *image_shape))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
if timesteps_noise:
assert num_timesteps > 0
# generate noise shape as [n, bz, c, h, w]
noise_batch = torch.randn(
(num_timesteps, num_batches, *image_shape))
else:
# generate noise shape as [bz, c, h, w]
noise_batch = torch.randn((num_batches, *image_shape))
return noise_batch
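# Hedged shape examples for ``_get_noise_batch`` (illustration only):
#
#     image_shape = torch.Size([3, 32, 32])
#     _get_noise_batch(None, image_shape, num_batches=4).shape
#     # -> torch.Size([4, 3, 32, 32])
#     _get_noise_batch(None, image_shape, num_batches=4, num_timesteps=10,
#                      timesteps_noise=True).shape
#     # -> torch.Size([10, 4, 3, 32, 32])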
def _get_label_batch(label,
num_timesteps=0,
num_classes=0,
num_batches=0,
timesteps_noise=False):
"""Get label batch. Support get sequeue of label along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If num_classes <= 0, return None.
If timesteps_noise is True, we output label which dimension is 2.
- Input is [bz, ]: Expand to [n, bz]
- Input is [n, ]: Expand to [n, bz]
- Input is [n*bz, ]: View to [n, bz]
- Dim of the input is 2: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate label shape as [n, bz]
- Otherwise: Raise error
    If timesteps_noise is False, we output label whose dimension is 1 and
    ignore ``num_timesteps``.
    - Dim of the input is 0: Unsqueeze to [1, ], ignore ``num_batches``
    - Dim of the input is 1: Return the input, ignore ``num_batches``
- Callable or None: Generate label shape as [bz, ]
- Otherwise: Raise error
It's to be noted that, we do not move the generated label to target device
in this function because we can not get which device the noise should move
to.
Args:
label (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_timesteps (int, optional): Total timestpes of the diffusion and
denoising process. Defaults to 0.
num_batches (int, optional): The number of batch size. To be noted that
this argument only work when the input ``noise`` is callable or
``None``. Defaults to 0.
timesteps_noise (bool, optional): If True, returned noise will shape
as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
Defaults to False.
Returns:
torch.Tensor: Generated label with desired shape.
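
    Example:
        A minimal usage sketch; the number of classes, batch size and
        number of timesteps below are illustrative only:

        >>> import torch
        >>> # default sampler, label shaped as [bz, ]
        >>> _get_label_batch(None, num_classes=10, num_batches=4).shape
        torch.Size([4])
        >>> # expand a [bz, ] label to a [n, bz] timesteps label
        >>> label = torch.randint(0, 10, (4, ))
        >>> _get_label_batch(label, num_timesteps=5, num_classes=10,
        ...                  num_batches=4, timesteps_noise=True).shape
        torch.Size([5, 4])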
"""
    # no label is output if num_classes is 0
if num_classes == 0:
assert label is None, ('\'label\' should be None '
'if \'num_classes == 0\'.')
return None
# receive label and conduct sanity check.
if isinstance(label, torch.Tensor):
if timesteps_noise:
if label.ndim == 1:
assert num_batches > 0 and num_timesteps > 0
# [n, ] to [n, bz]
if label.shape[0] == num_timesteps:
label_batch = label.view(num_timesteps, 1)
label_batch = label_batch.expand(-1, num_batches)
# [bz, ] to [n, bz]
elif label.shape[0] == num_batches:
label_batch = label.view(1, num_batches)
label_batch = label_batch.expand(num_timesteps, -1)
# [n*bz, ] to [n, bz]
elif label.shape[0] == num_timesteps * num_batches:
label_batch = label.view(num_timesteps, -1)
else:
                    raise ValueError(
                        'The timesteps label should be in shape of '
                        '(n, ), (bz, ), (n*bz, ) or (n, bz). But received '
                        f'{label.shape}.')
            elif label.ndim == 2:
                # label is already [n, bz], directly return it
                label_batch = label
else:
                raise ValueError(
                    'The timesteps label should be in shape of '
                    '(n, ), (bz, ), (n*bz, ) or (n, bz). But received '
                    f'{label.shape}.')
else:
# dimension is 0, expand to [1, ]
if label.ndim == 0:
label_batch = label[None, ...]
# dimension is 1, do nothing
elif label.ndim == 1:
label_batch = label
else:
                raise ValueError(
                    'The label should be in shape of (bz, ) or a '
                    f'zero-dimension tensor, but got {label.shape}.')
    # receive a label generator and sample labels.
elif callable(label):
assert num_batches > 0
label_generator = label
if timesteps_noise:
assert num_timesteps > 0
# generate label shape as [n, bz]
label_batch = label_generator((num_timesteps, num_batches))
else:
# generate label shape as [bz, ]
label_batch = label_generator((num_batches, ))
    # otherwise, we will adopt the default label sampler.
else:
assert num_batches > 0
if timesteps_noise:
assert num_timesteps > 0
# generate label shape as [n, bz]
label_batch = torch.randint(0, num_classes,
(num_timesteps, num_batches))
else:
# generate label shape as [bz, ]
label_batch = torch.randint(0, num_classes, (num_batches, ))
    return label_batch


def var_to_tensor(var, index, target_shape=None, device=None):
"""Function used to extract variables by given index, and convert into
tensor as given shape.
Args:
var (np.array): Variables to be extracted.
index (torch.Tensor): Target index to extract.
target_shape (torch.Size, optional): If given, the indexed variable
will expand to the given shape. Defaults to None.
device (str): If given, the indexed variable will move to the target
device. Otherwise, indexed variable will on cpu. Defaults to None.
Returns:
torch.Tensor: Converted variable.
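
    Example:
        A minimal usage sketch; the beta schedule and target shape below
        are illustrative only:

        >>> import numpy as np
        >>> import torch
        >>> betas = np.linspace(1e-4, 0.02, 1000)
        >>> timesteps = torch.tensor([0, 10, 20])
        >>> var_to_tensor(betas, timesteps,
        ...               torch.Size([3, 3, 32, 32])).shape
        torch.Size([3, 1, 1, 1])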
"""
    # `var` is a numpy array in the current design, so the index must be
    # moved to cpu before indexing
    var_indexed = torch.from_numpy(var)[index.cpu()].float()

    if device is not None:
        var_indexed = var_indexed.to(device)

    # append trailing singleton dimensions until the indexed variable can
    # broadcast against ``target_shape``
    if target_shape is not None:
        while var_indexed.ndim < len(target_shape):
            var_indexed = var_indexed[..., None]
    return var_indexed