Commit b7536f78 authored by limm's avatar limm
Browse files

Add StyleGAN3 modules and multi-scale StyleGAN2 components to the mmgeneration code

parent 57e0e891
Pipeline #2777 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import scipy
import scipy.signal
import scipy.special
import torch
import torch.nn as nn

from mmgen.models.builder import MODULES
from mmgen.ops import bias_act, conv2d_gradfix, filtered_lrelu
def modulated_conv2d(
    x,
    w,
    s,
    demodulate=True,
    padding=0,
    input_gain=None,
):
    """Modulated Conv2d in StyleGANv3.

    The convolution weight is modulated per sample by the style vector
    ``s`` and the whole batch is executed as a single grouped
    convolution.

    Args:
        x (torch.Tensor): Input tensor with shape (batch_size,
            in_channels, height, width).
        w (torch.Tensor): Weight of modulated convolution with shape
            (out_channels, in_channels, kernel_height, kernel_width).
        s (torch.Tensor): Style tensor with shape (batch_size,
            in_channels).
        demodulate (bool): Whether to apply weight demodulation.
            Defaults to True.
        padding (int or list[int]): Convolution padding. Defaults to 0.
        input_gain (list[int]): Scaling factors for input.
            Defaults to None.

    Returns:
        torch.Tensor: Convolution output.
    """
    num_samples = int(x.shape[0])
    _, num_in_channels, kernel_h, kernel_w = w.shape

    # Pre-normalize weights and styles to unit RMS before modulation.
    if demodulate:
        w = w * w.square().mean(dim=[1, 2, 3], keepdim=True).rsqrt()
        s = s * s.square().mean().rsqrt()

    # Modulate: one weight tensor per sample, each input channel scaled
    # by its style value.  Resulting shape: [N, O, I, kh, kw].
    w = w[None] * s[:, None, :, None, None]

    # Demodulate: rescale so every output channel has unit norm.
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [N, O]
        w = w * dcoefs[:, :, None, None, None]

    # Fold the optional per-input-channel gain into the weights.
    if input_gain is not None:
        input_gain = input_gain.expand(num_samples, num_in_channels)  # [N, I]
        w = w * input_gain[:, None, :, None, None]

    # Execute as one fused op: fold the batch into the channel dimension
    # and run a grouped convolution with one group per sample.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, num_in_channels, kernel_h, kernel_w)
    x = conv2d_gradfix.conv2d(
        input=x, weight=w.to(x.dtype), padding=padding, groups=num_samples)
    return x.reshape(num_samples, -1, *x.shape[2:])
class FullyConnectedLayer(nn.Module):
    """Fully connected layer used in StyleGANv3.

    Weights are stored unscaled and multiplied by an equalized learning
    rate gain at every forward pass.

    Args:
        in_features (int): Number of channels in the input feature.
        out_features (int): Number of channels in the out feature.
        activation (str, optional): Activation function with choices
            'relu', 'lrelu', 'linear'. 'linear' means no extra
            activation. Defaults to 'linear'.
        bias (bool, optional): Whether to use additive bias.
            Defaults to True.
        lr_multiplier (float, optional): Equalized learning rate
            multiplier. Defaults to 1..
        weight_init (float, optional): Weight multiplier for
            initialization. Defaults to 1..
        bias_init (float, optional): Initial bias. Defaults to 0..
    """

    def __init__(self,
                 in_features,
                 out_features,
                 activation='linear',
                 bias=True,
                 lr_multiplier=1.,
                 weight_init=1.,
                 bias_init=0.):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

        # Store weights divided by lr_multiplier; the gain below restores
        # the effective scale at runtime (equalized learning rate).
        init_scale = weight_init / lr_multiplier
        self.weight = torch.nn.Parameter(
            torch.randn([out_features, in_features]) * init_scale)
        if bias:
            b_init = np.broadcast_to(
                np.asarray(bias_init, dtype=np.float32), [out_features])
            self.bias = torch.nn.Parameter(
                torch.from_numpy(b_init / lr_multiplier))
        else:
            self.bias = None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Forward function."""
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
            if self.bias_gain != 1:
                bias = bias * self.bias_gain
        # Fast path: fused multiply-add when no activation is applied.
        if self.activation == 'linear' and bias is not None:
            return torch.addmm(bias.unsqueeze(0), x, weight.t())
        out = x.matmul(weight.t())
        return bias_act.bias_act(out, bias, act=self.activation)
@MODULES.register_module()
class MappingNetwork(nn.Module):
    """Style mapping network used in StyleGAN3.

    The main difference from the StyleGANv1/v2 mapping networks is that
    the mean latent is registered as a buffer (``w_avg``) and updated
    dynamically during training.

    Args:
        noise_size (int, optional): Size of the input noise vector.
        style_channels (int): The number of channels for style code.
        num_ws (int): The repeat times of w latent.
        c_dim (int, optional): Size of the conditioning input ``c``.
            Defaults to 0 (unconditional).
        num_layers (int, optional): The number of layers of mapping
            network. Defaults to 2.
        lr_multiplier (float, optional): Equalized learning rate
            multiplier. Defaults to 0.01.
        w_avg_beta (float, optional): The value used for updating
            `w_avg`. Defaults to 0.998.
    """

    def __init__(self,
                 noise_size,
                 style_channels,
                 num_ws,
                 c_dim=0,
                 num_layers=2,
                 lr_multiplier=0.01,
                 w_avg_beta=0.998):
        super().__init__()
        self.noise_size = noise_size
        self.c_dim = c_dim
        self.style_channels = style_channels
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Optional embedding layer for the conditioning input.
        self.embed = None
        if self.c_dim > 0:
            self.embed = FullyConnectedLayer(self.c_dim, self.style_channels)

        # Feature sizes of the consecutive FC layers: the first layer
        # consumes the (optionally label-augmented) noise vector.
        embed_channels = self.style_channels if self.c_dim > 0 else 0
        features = [self.noise_size + embed_channels]
        features += [self.style_channels] * self.num_layers
        for idx, (n_in, n_out) in enumerate(zip(features[:-1], features[1:])):
            setattr(
                self, f'fc{idx}',
                FullyConnectedLayer(
                    n_in,
                    n_out,
                    activation='lrelu',
                    lr_multiplier=lr_multiplier))
        self.register_buffer('w_avg', torch.zeros([style_channels]))

    def forward(self,
                z,
                c=None,
                truncation=1,
                num_truncation_layer=None,
                update_emas=False):
        """Style mapping function.

        Args:
            z (torch.Tensor): Input noise tensor.
            c (torch.Tensor, optional): Input label tensor. Defaults to
                None.
            truncation (float, optional): Truncation factor. Give value
                less than 1., the truncation trick will be adopted.
                Defaults to 1.
            num_truncation_layer (int, optional): Number of layers using
                the truncated latent. Defaults to None (all layers).
            update_emas (bool, optional): Whether to update the moving
                average of the mean latent. Defaults to False.

        Returns:
            torch.Tensor: W-plus latent.
        """
        if num_truncation_layer is None:
            num_truncation_layer = self.num_ws

        # Normalize the noise (and optional label embedding) to unit RMS,
        # then concatenate.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        if self.c_dim > 0:
            y = self.embed(c.to(torch.float32))
            y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
            x = torch.cat([x, y], dim=1) if x is not None else y

        # Run the stack of fully connected layers.
        for idx in range(self.num_layers):
            fc = getattr(self, f'fc{idx}')
            x = fc(x)

        # Track the moving average of W.
        if update_emas:
            batch_mean = x.detach().mean(dim=0)
            self.w_avg.copy_(batch_mean.lerp(self.w_avg, self.w_avg_beta))

        # Broadcast to `num_ws` copies and apply the truncation trick to
        # the first `num_truncation_layer` of them.
        x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
        if truncation != 1:
            x[:, :num_truncation_layer] = self.w_avg.lerp(
                x[:, :num_truncation_layer], truncation)
        return x
class SynthesisInput(nn.Module):
    """Module which generates the input of the synthesis network from
    transformed Fourier features.

    Args:
        style_channels (int): The number of channels for style code.
        channels (int): The number of output channels.
        size (int): The size of the sampling grid.
        sampling_rate (int): Sampling rate for constructing the sampling
            grid.
        bandwidth (float): Bandwidth of the random frequencies.
    """

    def __init__(self, style_channels, channels, size, sampling_rate,
                 bandwidth):
        super().__init__()
        self.style_channels = style_channels
        self.channels = channels
        # A scalar size is broadcast to two spatial dimensions.
        self.size = np.broadcast_to(np.asarray(size), [2])
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        freqs = torch.randn([self.channels, 2])
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth
        phases = torch.rand([self.channels]) - 0.5

        # Setup parameters and buffers.
        self.weight = torch.nn.Parameter(
            torch.randn([self.channels, self.channels]))
        # Affine layer mapping w to (r_c, r_s, t_x, t_y); bias_init picks
        # the identity transform (1, 0, 0, 0).
        self.affine = FullyConnectedLayer(
            style_channels, 4, weight_init=0, bias_init=[1, 0, 0, 0])
        self.register_buffer('transform', torch.eye(
            3, 3))  # User-specified inverse transform wrt. resulting image.
        self.register_buffer('freqs', freqs)
        self.register_buffer('phases', phases)

    def forward(self, w):
        """Forward function.

        Args:
            w (torch.Tensor): Style tensor, assumed shape
                (batch, style_channels).

        Returns:
            torch.Tensor: Feature map with layout
                [batch, channel, height, width].
        """
        # Introduce batch dimension.
        transforms = self.transform.unsqueeze(0)  # [batch, row, col]
        freqs = self.freqs.unsqueeze(0)  # [batch, channel, xy]
        phases = self.phases.unsqueeze(0)  # [batch, channel]

        # Apply learned transformation.
        t = self.affine(w)  # t = (r_c, r_s, t_x, t_y)
        # Normalize so the rotation part (r'_c, r'_s) lies on the unit
        # circle.
        t = t / t[:, :2].norm(
            dim=1, keepdim=True)  # t' = (r'_c, r'_s, t'_x, t'_y)
        m_r = torch.eye(
            3, device=w.device).unsqueeze(0).repeat(
                [w.shape[0], 1, 1])  # Inverse rotation wrt. resulting image.
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1]  # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(
            3, device=w.device).unsqueeze(0).repeat(
                [w.shape[0], 1,
                 1])  # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2]  # t'_x
        m_t[:, 1, 2] = -t[:, 3]  # t'_y
        # First rotate resulting image, then translate
        # and finally apply user-specified transform.
        transforms = m_r @ m_t @ transforms

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies
        # that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) /
                      (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(
            theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]],
            align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(
            0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(
                3)  # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2)  # [batch, channel, height, width]
        return x
class SynthesisLayer(nn.Module):
    """Layer of Synthesis network for stylegan3.

    Args:
        style_channels (int): The number of channels for style code.
        is_torgb (bool): Whether output of this layer is transformed to
            rgb image.
        is_critically_sampled (bool): Whether filter cutoff is set exactly
            at the bandlimit.
        use_fp16 (bool, optional): Whether to use fp16 training in this
            module. If this flag is `True`, the whole module will be wrapped
            with ``auto_fp16``.
        in_channels (int): The channel number of the input feature map.
        out_channels (int): The channel number of the output feature map.
        in_size (int): The input size of feature map.
        out_size (int): The output size of feature map.
        in_sampling_rate (int): Sampling rate for upsampling filter.
        out_sampling_rate (int): Sampling rate for downsampling filter.
        in_cutoff (float): Cutoff frequency for upsampling filter.
        out_cutoff (float): Cutoff frequency for downsampling filter.
        in_half_width (float): The approximate width of the transition region
            for upsampling filter.
        out_half_width (float): The approximate width of the transition region
            for downsampling filter.
        conv_kernel (int, optional): The kernel of modulated convolution.
            Defaults to 3.
        filter_size (int, optional): Base filter size. Defaults to 6.
        lrelu_upsampling (int, optional): Upsampling rate for
            `filtered_lrelu`. Defaults to 2.
        use_radial_filters (bool, optional): Whether use radially symmetric
            jinc-based filter in downsampling filter. Defaults to False.
        conv_clamp (int, optional): Clamp bound for convolution.
            Defaults to 256.
        magnitude_ema_beta (float, optional): Beta coefficient for calculating
            input magnitude ema. Defaults to 0.999.
    """

    def __init__(
        self,
        style_channels,
        is_torgb,
        is_critically_sampled,
        use_fp16,
        in_channels,
        out_channels,
        in_size,
        out_size,
        in_sampling_rate,
        out_sampling_rate,
        in_cutoff,
        out_cutoff,
        in_half_width,
        out_half_width,
        conv_kernel=3,
        filter_size=6,
        lrelu_upsampling=2,
        use_radial_filters=False,
        conv_clamp=256,
        magnitude_ema_beta=0.999,
    ):
        super().__init__()
        self.style_channels = style_channels
        self.is_torgb = is_torgb
        self.is_critically_sampled = is_critically_sampled
        self.use_fp16 = use_fp16
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Scalar sizes are broadcast to two spatial dimensions.
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        # Intermediate sampling rate between the up- and downsampling
        # filters; ToRGB layers skip the extra lrelu upsampling.
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (
            1 if is_torgb else lrelu_upsampling)
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        # ToRGB layers always use a 1x1 convolution kernel.
        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.conv_clamp = conv_clamp
        self.magnitude_ema_beta = magnitude_ema_beta

        # Setup parameters and buffers.
        self.affine = FullyConnectedLayer(
            self.style_channels, self.in_channels, bias_init=1)
        self.weight = torch.nn.Parameter(
            torch.randn([
                self.out_channels, self.in_channels, self.conv_kernel,
                self.conv_kernel
            ]))
        self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
        # Running estimate of the input magnitude; updated in `forward`
        # when `update_emas` is set.
        self.register_buffer('magnitude_ema', torch.ones([]))

        # Design upsampling filter.
        self.up_factor = int(
            np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
        # A single tap means an identity filter (see
        # `design_lowpass_filter`).
        self.up_taps = (
            filter_size *
            self.up_factor if self.up_factor > 1 and not self.is_torgb else 1)
        self.register_buffer(
            'up_filter',
            self.design_lowpass_filter(
                numtaps=self.up_taps,
                cutoff=self.in_cutoff,
                width=self.in_half_width * 2,
                fs=self.tmp_sampling_rate))

        # Design downsampling filter.
        self.down_factor = int(
            np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert (self.out_sampling_rate *
                self.down_factor == self.tmp_sampling_rate)
        self.down_taps = (
            filter_size * self.down_factor
            if self.down_factor > 1 and not self.is_torgb else 1)
        self.down_radial = (
            use_radial_filters and not self.is_critically_sampled)
        self.register_buffer(
            'down_filter',
            self.design_lowpass_filter(
                numtaps=self.down_taps,
                cutoff=self.out_cutoff,
                width=self.out_half_width * 2,
                fs=self.tmp_sampling_rate,
                radial=self.down_radial))

        # Compute padding.
        pad_total = (
            self.out_size - 1
        ) * self.down_factor + 1  # Desired output size before downsampling.
        pad_total -= (self.in_size + self.conv_kernel -
                      1) * self.up_factor  # Input size after upsampling.
        # Account for the size reduction caused by the two FIR filters.
        pad_total += self.up_taps + self.down_taps - 2
        pad_lo = (pad_total + self.up_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [
            int(pad_lo[0]),
            int(pad_hi[0]),
            int(pad_lo[1]),
            int(pad_hi[1])
        ]

    def forward(self, x, w, force_fp32=False, update_emas=False):
        """Forward function for synthesis layer.

        Args:
            x (torch.Tensor): Input feature map tensor.
            w (torch.Tensor): Input style tensor.
            force_fp32 (bool, optional): Force fp32 regardless of
                `self.use_fp16`. Defaults to False.
            update_emas (bool, optional): Whether update moving average of
                input magnitude. Defaults to False.

        Returns:
            torch.Tensor: Output feature map tensor.
        """
        # Track input magnitude.
        if update_emas:
            with torch.autograd.profiler.record_function(
                    'update_magnitude_ema'):
                magnitude_cur = x.detach().to(torch.float32).square().mean()
                self.magnitude_ema.copy_(
                    magnitude_cur.lerp(self.magnitude_ema,
                                       self.magnitude_ema_beta))
        # Normalize the input by its running magnitude estimate.
        input_gain = self.magnitude_ema.rsqrt()

        # Execute affine layer.
        styles = self.affine(w)
        if self.is_torgb:
            weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel**2))
            styles = styles * weight_gain

        # Execute modulated conv2d; fp16 only on CUDA and when enabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and
                                  x.device.type == 'cuda') else torch.float32
        x = modulated_conv2d(
            x=x.to(dtype),
            w=self.weight,
            s=styles,
            padding=self.conv_kernel - 1,
            demodulate=(not self.is_torgb),
            input_gain=input_gain)

        # Execute bias, filtered leaky ReLU, and clamping.
        gain = 1 if self.is_torgb else np.sqrt(2)
        slope = 1 if self.is_torgb else 0.2
        x = filtered_lrelu.filtered_lrelu(
            x=x,
            fu=self.up_filter,
            fd=self.down_filter,
            b=self.bias.to(x.dtype),
            up=self.up_factor,
            down=self.down_factor,
            padding=self.padding,
            gain=gain,
            slope=slope,
            clamp=self.conv_clamp)

        # Ensure correct shape and dtype.
        assert x.dtype == dtype
        return x

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
        """Design lowpass filter giving related arguments.

        Args:
            numtaps (int): Length of the filter. `numtaps` must be odd if a
                passband includes the Nyquist frequency.
            cutoff (float): Cutoff frequency of filter.
            width (float): The approximate width of the transition region.
            fs (float): The sampling frequency of the signal.
            radial (bool, optional): Whether use radially symmetric
                jinc-based filter. Defaults to False.

        Returns:
            torch.Tensor | None: Kernel of lowpass filter, or None for the
                identity filter (``numtaps == 1``).
        """
        assert numtaps >= 1

        # Identity filter.
        if numtaps == 1:
            return None

        # Separable Kaiser low-pass filter.
        if not radial:
            f = scipy.signal.firwin(
                numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return torch.as_tensor(f, dtype=torch.float32)

        # Radially symmetric jinc-based filter.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = scipy.signal.kaiser_beta(
            scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        # Apply a 2D Kaiser window and normalize to unit DC gain.
        f *= np.outer(w, w)
        f /= np.sum(f)
        return torch.as_tensor(f, dtype=torch.float32)
@MODULES.register_module()
class SynthesisNetwork(nn.Module):
    """Synthesis network for stylegan3.

    Args:
        style_channels (int): The number of channels for style code.
        out_size (int): The resolution of output image.
        img_channels (int): The number of channels for output image.
        channel_base (int, optional): Overall multiplier for the number of
            channels. Defaults to 32768.
        channel_max (int, optional): Maximum number of channels in any layer.
            Defaults to 512.
        num_layers (int, optional): Total number of layers, excluding Fourier
            features and ToRGB. Defaults to 14.
        num_critical (int, optional): Number of critically sampled layers at
            the end. Defaults to 2.
        first_cutoff (int, optional): Cutoff frequency of the first layer.
            Defaults to 2.
        first_stopband (int, optional): Minimum stopband of the first layer.
            Defaults to 2**2.1.
        last_stopband_rel (float, optional): Minimum stopband of the last
            layer, expressed relative to the cutoff. Defaults to 2**0.3.
        margin_size (int, optional): Number of additional pixels outside the
            image. Defaults to 10.
        output_scale (float, optional): Scale factor for output value.
            Defaults to 0.25.
        num_fp16_res (int, optional): Number of highest resolutions that use
            fp16 (layers whose sampling rate exceeds
            ``out_size / 2**num_fp16_res``). Defaults to 4.
    """

    def __init__(
        self,
        style_channels,
        out_size,
        img_channels,
        channel_base=32768,
        channel_max=512,
        num_layers=14,
        num_critical=2,
        first_cutoff=2,
        first_stopband=2**2.1,
        last_stopband_rel=2**0.3,
        margin_size=10,
        output_scale=0.25,
        num_fp16_res=4,
        **layer_kwargs,
    ):
        super().__init__()
        self.style_channels = style_channels
        # One w per synthesis layer plus the input layer and ToRGB.
        self.num_ws = num_layers + 2
        self.out_size = out_size
        self.img_channels = img_channels
        self.num_layers = num_layers
        self.num_critical = num_critical
        self.margin_size = margin_size
        self.output_scale = output_scale
        self.num_fp16_res = num_fp16_res

        # Geometric progression of layer cutoffs and min. stopbands.
        last_cutoff = self.out_size / 2  # f_{c,N}
        last_stopband = last_cutoff * last_stopband_rel  # f_{t,N}
        exponents = np.minimum(
            np.arange(self.num_layers + 1) /
            (self.num_layers - self.num_critical), 1)
        cutoffs = first_cutoff * (last_cutoff /
                                  first_cutoff)**exponents  # f_c[i]
        stopbands = first_stopband * (last_stopband /
                                      first_stopband)**exponents  # f_t[i]

        # Compute remaining layer parameters.
        sampling_rates = np.exp2(
            np.ceil(np.log2(np.minimum(stopbands * 2, self.out_size))))  # s[i]
        half_widths = np.maximum(stopbands,
                                 sampling_rates / 2) - cutoffs  # f_h[i]
        sizes = sampling_rates + self.margin_size * 2
        # The last two layers operate at the full output resolution.
        sizes[-2:] = self.out_size
        channels = np.rint(
            np.minimum((channel_base / 2) / cutoffs, channel_max))
        channels[-1] = self.img_channels

        # Construct layers.
        self.input = SynthesisInput(
            style_channels=self.style_channels,
            channels=int(channels[0]),
            size=int(sizes[0]),
            sampling_rate=sampling_rates[0],
            bandwidth=cutoffs[0])
        self.layer_names = []
        for idx in range(self.num_layers + 1):
            prev = max(idx - 1, 0)
            # The final layer produces the RGB output.
            is_torgb = (idx == self.num_layers)
            is_critically_sampled = (
                idx >= self.num_layers - self.num_critical)
            # fp16 is enabled for the highest-sampling-rate layers.
            use_fp16 = (
                sampling_rates[idx] * (2**self.num_fp16_res) > self.out_size)
            layer = SynthesisLayer(
                style_channels=self.style_channels,
                is_torgb=is_torgb,
                is_critically_sampled=is_critically_sampled,
                use_fp16=use_fp16,
                in_channels=int(channels[prev]),
                out_channels=int(channels[idx]),
                in_size=int(sizes[prev]),
                out_size=int(sizes[idx]),
                in_sampling_rate=int(sampling_rates[prev]),
                out_sampling_rate=int(sampling_rates[idx]),
                in_cutoff=cutoffs[prev],
                out_cutoff=cutoffs[idx],
                in_half_width=half_widths[prev],
                out_half_width=half_widths[idx],
                **layer_kwargs)
            name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
            setattr(self, name, layer)
            self.layer_names.append(name)

    def forward(self, ws, **layer_kwargs):
        """Forward function.

        Args:
            ws (torch.Tensor): W-plus latent; one w per layer, unbound
                along dim 1.

        Returns:
            torch.Tensor: Output image tensor (float32).
        """
        ws = ws.to(torch.float32).unbind(dim=1)

        # Execute layers.
        x = self.input(ws[0])
        for name, w in zip(self.layer_names, ws[1:]):
            x = getattr(self, name)(x, w, **layer_kwargs)
        if self.output_scale != 1:
            x = x * self.output_scale

        # Ensure correct shape and dtype.
        x = x.to(torch.float32)
        return x
# Copyright (c) OpenMMLab. All rights reserved.
import random
from copy import deepcopy
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.architectures import PixelNorm
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODULES, build_module
from .modules.styleganv2_modules import (ConstantInput, ConvDownLayer,
EqualLinearActModule,
ModMBStddevLayer,
ModulatedPEStyleConv, ModulatedToRGB,
ResBlock)
from .utils import get_mean_latent, style_mixing
@MODULES.register_module()
class MSStyleGANv2Generator(nn.Module):
    """StyleGAN2 Generator.

    In StyleGAN2, we use a static architecture composing of a style mapping
    module and number of convolutional style blocks. More details can be found
    in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    Args:
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
        no_pad (bool, optional): If True, the constant input is enlarged by
            2 pixels and the injected-noise grids of the first conv of each
            resolution gain 2 extra pixels. Defaults to False.
        deconv2conv (bool, optional): Forwarded to ``ModulatedPEStyleConv``;
            presumably switches deconvolution for conv + upsampling — see
            that module for semantics. Defaults to False.
        interp_pad (int | None, optional): Forwarded to
            ``ModulatedPEStyleConv``. Defaults to None.
        up_config (dict, optional): Upsampling config forwarded to the style
            conv blocks. Defaults to ``dict(scale_factor=2, mode='nearest')``.
        up_after_conv (bool, optional): Forwarded to
            ``ModulatedPEStyleConv``; also affects injected-noise sizes when
            ``no_pad`` is set. Defaults to False.
        head_pos_encoding (dict | None, optional): Config of a positional
            encoding module that replaces the constant input when given.
            Defaults to None.
        head_pos_size (tuple, optional): Base grid size for the head
            positional encoding. Defaults to (4, 4).
        interp_head (bool, optional): If True, generate the head positional
            encoding at the base size and bilinearly interpolate it to the
            target size. Defaults to False.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 no_pad=False,
                 deconv2conv=False,
                 interp_pad=None,
                 up_config=dict(scale_factor=2, mode='nearest'),
                 up_after_conv=False,
                 head_pos_encoding=None,
                 head_pos_size=(4, 4),
                 interp_head=False):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # `_default_style_mode` remembers the training-time mode so that
        # `train()` can restore it after evaluation.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.no_pad = no_pad
        self.deconv2conv = deconv2conv
        self.interp_pad = interp_pad
        self.with_interp_pad = interp_pad is not None
        # Deep-copy to avoid sharing the (mutable) config dict.
        self.up_config = deepcopy(up_config)
        self.up_after_conv = up_after_conv
        self.head_pos_encoding = head_pos_encoding
        self.head_pos_size = head_pos_size
        self.interp_head = interp_head

        # define style mapping layers
        mapping_layers = [PixelNorm()]
        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))
        self.style_mapping = nn.Sequential(*mapping_layers)

        # Channel count per output resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }
        in_ch = self.channels[4]

        # constant input layer (or a positional-encoding head)
        if self.head_pos_encoding:
            if self.head_pos_encoding['type'] in [
                    'CatersianGrid', 'CSG', 'CSG2d'
            ]:
                in_ch = 2
            self.head_pos_enc = build_module(self.head_pos_encoding)
        else:
            size_ = 4
            if self.no_pad:
                size_ += 2
            self.constant_input = ConstantInput(self.channels[4], size=size_)

        # 4x4 stage
        self.conv1 = ModulatedPEStyleConv(
            in_ch,
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel,
            deconv2conv=self.deconv2conv,
            no_pad=self.no_pad,
            up_config=self.up_config,
            interp_pad=self.interp_pad)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4], style_channels, upsample=False)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))
        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        in_channels_ = self.channels[4]
        # Each resolution adds an upsampling conv, a plain conv, and a
        # ToRGB layer.
        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]
            self.convs.append(
                ModulatedPEStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    deconv2conv=self.deconv2conv,
                    no_pad=self.no_pad,
                    up_config=self.up_config,
                    interp_pad=self.interp_pad,
                    up_after_conv=self.up_after_conv))
            self.convs.append(
                ModulatedPEStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel,
                    deconv2conv=self.deconv2conv,
                    no_pad=self.no_pad,
                    up_config=self.up_config,
                    interp_pad=self.interp_pad,
                    up_after_conv=self.up_after_conv))
            self.to_rgbs.append(
                ModulatedToRGB(out_channels_, style_channels, upsample=True))
            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises
        noises = self.make_injected_noise()
        for layer_idx in range(self.num_injected_noises):
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 noises[layer_idx])

    def train(self, mode=True):
        """Set train/eval mode and switch the style-mixing mode
        accordingly."""
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode
        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode
        return super(MSStyleGANv2Generator, self).train(mode)

    def make_injected_noise(self, chosen_scale=0):
        """Sample noise tensors for each noise-injection layer.

        Args:
            chosen_scale (int, optional): Extra pixels added to the base
                4x4 grid (scaled up with resolution). Defaults to 0.

        Returns:
            list[Tensor]: One noise tensor per injection layer.
        """
        device = get_module_device(self)

        base_scale = 2**2 + chosen_scale
        noises = [torch.randn(1, 1, base_scale, base_scale, device=device)]
        for i in range(3, self.log_size + 1):
            for n in range(2):
                _pad = 0
                # The first conv of each resolution needs 2 extra pixels
                # when running padding-free without up-after-conv.
                if self.no_pad and not self.up_after_conv and n == 0:
                    _pad = 2
                noises.append(
                    torch.randn(
                        1,
                        1,
                        base_scale * 2**(i - 2) + _pad,
                        base_scale * 2**(i - 2) + _pad,
                        device=device))
        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7,
                     chosen_scale=0):
        """Generate style-mixed images; delegates to ``utils.style_mixing``
        with this generator's style channels."""
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation_latent=truncation_latent,
            truncation=truncation,
            style_channels=self.style_channels,
            chosen_scale=chosen_scale)

    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                chosen_scale=0):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            chosen_scale (int | tuple[int], optional): Extra size added to the
                constant input / noise grids, enabling generation at a larger
                spatial size. Defaults to 0.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            # Sample two noise vectors for style mixing with prob
            # `mix_prob`, otherwise one.
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # Map noise z to latent w unless the input is already a latent.
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # `None` lets the conv blocks sample fresh noise.
                injected_noise = [None] * self.num_injected_noises
            elif chosen_scale > 0:
                # Lazily create and cache noise buffers for this scale.
                if not hasattr(self, f'injected_noise_{chosen_scale}_0'):
                    noises_ = self.make_injected_noise(chosen_scale)
                    for i in range(self.num_injected_noises):
                        setattr(self, f'injected_noise_{chosen_scale}_{i}',
                                noises_[i])
                injected_noise = [
                    getattr(self, f'injected_noise_{chosen_scale}_{i}')
                    for i in range(self.num_injected_noises)
                ]
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]

        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent
            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))
            styles = style_t

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        # Normalize `chosen_scale` to an (h, w) pair.
        if isinstance(chosen_scale, int):
            chosen_scale = (chosen_scale, chosen_scale)

        # 4x4 stage
        if self.head_pos_encoding:
            if self.interp_head:
                # Generate at base size, then interpolate to target size.
                out = self.head_pos_enc.make_grid2d(self.head_pos_size[0],
                                                    self.head_pos_size[1],
                                                    latent.size(0))
                h_in = self.head_pos_size[0] + chosen_scale[0]
                w_in = self.head_pos_size[1] + chosen_scale[1]
                out = F.interpolate(
                    out,
                    size=(h_in, w_in),
                    mode='bilinear',
                    align_corners=True)
            else:
                out = self.head_pos_enc.make_grid2d(
                    self.head_pos_size[0] + chosen_scale[0],
                    self.head_pos_size[1] + chosen_scale[1], latent.size(0))
            out = out.to(latent)
        else:
            out = self.constant_input(latent)
            if chosen_scale[0] != 0 or chosen_scale[1] != 0:
                out = F.interpolate(
                    out,
                    size=(out.shape[2] + chosen_scale[0],
                          out.shape[3] + chosen_scale[1]),
                    mode='bilinear',
                    align_corners=True)

        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        img = skip

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch,
                injected_noise=injected_noise)
            return output_dict

        return img
@MODULES.register_module()
class MSStyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    Args:
        in_size (int): The input size of images.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        with_adaptive_pool (bool, optional): Whether to pool the final
            feature map to ``pool_size`` with adaptive average pooling
            before the linear head. Defaults to False.
        pool_size (tuple[int], optional): Output spatial size of the
            adaptive pooling; only used when ``with_adaptive_pool`` is
            True. Defaults to (2, 2).
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 with_adaptive_pool=False,
                 pool_size=(2, 2)):
        super().__init__()
        self.with_adaptive_pool = with_adaptive_pool
        self.pool_size = pool_size

        # channel width at each spatial resolution
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(np.log2(in_size))
        cur_channels = channels[in_size]

        # from-RGB conv followed by one residual downsampling block per
        # halving of the resolution (in_size -> 8)
        conv_layers = [ConvDownLayer(3, channels[in_size], 1)]
        for res_log2 in range(log_size, 2, -1):
            next_channels = channels[2**(res_log2 - 1)]
            conv_layers.append(
                ResBlock(cur_channels, next_channels, blur_kernel))
            cur_channels = next_channels
        self.convs = nn.Sequential(*conv_layers)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
        # "+ 1" accounts for the extra minibatch-stddev feature channel
        self.final_conv = ConvDownLayer(cur_channels + 1, channels[4], 3)

        if self.with_adaptive_pool:
            self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_size)
            linear_in_channels = channels[4] * pool_size[0] * pool_size[1]
        else:
            linear_in_channels = channels[4] * 4 * 4

        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                linear_in_channels,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        x = self.convs(x)
        x = self.mbstd_layer(x)
        x = self.final_conv(x)
        if self.with_adaptive_pool:
            x = self.adaptive_pool(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)
        return x
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..common import get_module_device
@torch.no_grad()
def get_mean_latent(generator, num_samples=4096, bs_per_repeat=1024):
    """Get mean latent of W space in Style-based GANs.

    Args:
        generator (nn.Module): Generator of a Style-based GAN.
        num_samples (int, optional): Number of sample times. Defaults to 4096.
        bs_per_repeat (int, optional): Batch size of noises per sample.
            Defaults to 1024.

    Returns:
        Tensor: Mean latent of this generator.
    """
    device = get_module_device(generator)

    n_repeat = num_samples // bs_per_repeat
    # num_samples must be an exact multiple of the per-repeat batch size
    assert n_repeat * bs_per_repeat == num_samples

    mean_style = None
    for _ in range(n_repeat):
        noise = torch.randn(bs_per_repeat, generator.style_channels).to(device)
        batch_style = generator.style_mapping(noise).mean(0, keepdim=True)
        mean_style = batch_style if mean_style is None \
            else mean_style + batch_style

    return mean_style / float(n_repeat)
@torch.no_grad()
def style_mixing(generator,
                 n_source,
                 n_target,
                 inject_index=1,
                 truncation_latent=None,
                 truncation=0.7,
                 style_channels=512,
                 **kwargs):
    """Generate a style-mixing image grid.

    For each of ``n_target`` rows, the coarse styles (before
    ``inject_index``) come from one target code while the fine styles come
    from the ``n_source`` source codes.

    Args:
        generator (nn.Module): Style-based generator.
        n_source (int): Number of source (column) images.
        n_target (int): Number of target (row) images.
        inject_index (int, optional): Index at which to switch from target
            to source styles. Defaults to 1.
        truncation_latent (Tensor | None, optional): Mean latent used for
            truncation. Defaults to None.
        truncation (float, optional): Truncation factor. Defaults to 0.7.
        style_channels (int, optional): Dimension of the style codes.
            Defaults to 512.

    Returns:
        Tensor: All grid cells concatenated along the batch dimension.
    """
    device = get_module_device(generator)
    source_code = torch.randn(n_source, style_channels).to(device)
    target_code = torch.randn(n_target, style_channels).to(device)

    trunc_kwargs = dict(
        truncation_latent=truncation_latent, truncation=truncation)

    source_image = generator(source_code, **trunc_kwargs, **kwargs)
    h, w = source_image.shape[-2:]
    # top-left placeholder cell: a constant -1 (black) image
    images = [torch.ones(1, 3, h, w).to(device) * -1]
    target_image = generator(target_code, **trunc_kwargs, **kwargs)
    images.append(source_image)

    for row in range(n_target):
        mixed = generator(
            [target_code[row].unsqueeze(0).repeat(n_source, 1), source_code],
            inject_index=inject_index,
            **trunc_kwargs,
            **kwargs)
        images.append(target_image[row].unsqueeze(0))
        images.append(mixed)

    return torch.cat(images, 0)
# Copyright (c) OpenMMLab. All rights reserved.
from .generator_discriminator import WGANGPDiscriminator, WGANGPGenerator
# Public API of the WGAN-GP subpackage.
__all__ = ['WGANGPDiscriminator', 'WGANGPGenerator']
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks.upsample import build_upsample_layer
from mmgen.models.builder import MODULES
from ..common import get_module_device
from .modules import ConvLNModule, WGANDecisionHead, WGANNoiseTo2DFeat
@MODULES.register_module()
class WGANGPGenerator(nn.Module):
    r"""Generator for WGANGP.

    Implementation Details for WGANGP generator the same as training
    configuration (a) described in PGGAN paper:
    PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
    https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf # noqa

    #. Adopt convolution architecture specified in appendix A.2;
    #. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator except for the final output layer;
    #. Use Tanh in the last layer;
    #. Initialize all weights using He's initializer.

    Args:
        noise_size (int): Size of the input noise vector.
        out_scale (int): Output scale for the generated image.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this generator. Defaults to None.
        upsample_cfg (dict, optional): Config for the upsampling operation.
            Defaults to None.
    """

    # channel width per output resolution
    _default_channels_per_scale = {
        '4': 512,
        '8': 512,
        '16': 256,
        '32': 128,
        '64': 64,
        '128': 32
    }
    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='ReLU'),
        norm_cfg=dict(type='BN'),
        order=('conv', 'norm', 'act'))
    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self,
                 noise_size,
                 out_scale,
                 conv_module_cfg=None,
                 upsample_cfg=None):
        super().__init__()
        # set initial params
        self.noise_size = noise_size
        self.out_scale = out_scale
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)
        self.upsample_cfg = upsample_cfg if upsample_cfg else deepcopy(
            self._default_upsample_cfg)

        # set noise2feat head: maps the noise vector to a 4x4 feature map
        self.noise2feat = WGANNoiseTo2DFeat(
            self.noise_size, self._default_channels_per_scale['4'])

        # set conv_blocks
        self.conv_blocks = nn.ModuleList()
        self.conv_blocks.append(ConvModule(512, 512, **self.conv_module_cfg))
        log2scale = int(np.log2(self.out_scale))
        for i in range(3, log2scale + 1):
            # BUGFIX: build the upsample layer from ``self.upsample_cfg`` so
            # a user-provided ``upsample_cfg`` is honored; the original
            # always used ``self._default_upsample_cfg`` here, silently
            # ignoring the documented parameter.
            self.conv_blocks.append(build_upsample_layer(self.upsample_cfg))
            self.conv_blocks.append(
                ConvModule(self._default_channels_per_scale[str(2**(i - 1))],
                           self._default_channels_per_scale[str(2**i)],
                           **self.conv_module_cfg))
            self.conv_blocks.append(
                ConvModule(self._default_channels_per_scale[str(2**i)],
                           self._default_channels_per_scale[str(2**i)],
                           **self.conv_module_cfg))
        # final 1x1 conv with Tanh producing the RGB image
        self.to_rgb = ConvModule(
            self._default_channels_per_scale[str(self.out_scale)],
            kernel_size=1,
            out_channels=3,
            act_cfg=dict(type='Tanh'))

    def forward(self, noise, num_batches=0, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size. Defaults to
                0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output image
                will be returned. Otherwise, a dict contains ``fake_img`` and
                ``noise_batch`` will be returned.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))

        # noise vector to 2D feature
        x = self.noise2feat(noise_batch)
        for conv in self.conv_blocks:
            x = conv(x)
        out_img = self.to_rgb(x)

        if return_noise:
            output = dict(fake_img=out_img, noise_batch=noise_batch)
            return output
        return out_img
@MODULES.register_module()
class WGANGPDiscriminator(nn.Module):
    r"""Discriminator for WGANGP.

    Implementation Details for WGANGP discriminator the same as training
    configuration (a) described in PGGAN paper:
    PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
    https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf # noqa

    #. Adopt convolution architecture specified in appendix A.2;
    #. Add layer normalization to all conv3x3 and conv4x4 layers;
    #. Use LeakyReLU in the discriminator except for the final output layer;
    #. Initialize all weights using He's initializer.

    Args:
        in_channel (int): The channel number of the input image.
        in_scale (int): The scale of the input image.
        conv_module_cfg (dict, optional): Config for the convolution module
            used in this discriminator. Defaults to None.
    """

    # channel width per input resolution
    _default_channels_per_scale = {
        '4': 512,
        '8': 512,
        '16': 256,
        '32': 128,
        '64': 64,
        '128': 32
    }
    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='LN2d'),
        order=('conv', 'norm', 'act'))
    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self, in_channel, in_scale, conv_module_cfg=None):
        super().__init__()
        # set initial params
        self.in_channel = in_channel
        self.in_scale = in_scale
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)

        # from-RGB head mapping the image to the widest feature map
        self.from_rgb = ConvModule(
            3,
            kernel_size=1,
            out_channels=self._default_channels_per_scale[str(self.in_scale)],
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

        # one (conv, conv, avgpool) triple per resolution halving
        self.conv_blocks = nn.ModuleList()
        log2scale = int(np.log2(self.in_scale))
        for level in range(log2scale, 2, -1):
            res = 2**level
            ch = self._default_channels_per_scale[str(res)]
            ch_next = self._default_channels_per_scale[str(res // 2)]
            self.conv_blocks.append(
                ConvLNModule(
                    ch,
                    ch,
                    feature_shape=(ch, res, res),
                    **self.conv_module_cfg))
            self.conv_blocks.append(
                ConvLNModule(
                    ch,
                    ch_next,
                    feature_shape=(ch_next, res, res),
                    **self.conv_module_cfg))
            self.conv_blocks.append(nn.AvgPool2d(kernel_size=2, stride=2))

        # decision head operating on the final 4x4 feature map
        self.decision = WGANDecisionHead(
            self._default_channels_per_scale['4'],
            self._default_channels_per_scale['4'],
            1,
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            norm_cfg=self.conv_module_cfg['norm_cfg'])

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Fake or real image tensor.

        Returns:
            torch.Tensor: Prediction for the reality of the input image.
        """
        feat = self.from_rgb(x)
        for block in self.conv_blocks:
            feat = block(feat)
        return self.decision(feat)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from mmcv.cnn import (PLUGIN_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init)
from mmgen.models.builder import MODULES
@MODULES.register_module()
class WGANNoiseTo2DFeat(nn.Module):
    """Module used in WGAN-GP to transform 1D noise tensor in order [N, C] to
    2D shape feature tensor in order [N, C, H, W].

    Args:
        noise_size (int): Size of the input noise vector.
        out_channels (int): The channel number of the output feature.
        act_cfg (dict, optional): Config for the activation layer. Defaults to
            dict(type='ReLU').
        norm_cfg (dict, optional): Config dict to build norm layer. Defaults to
            dict(type='BN').
        order (tuple, optional): The order of linear/norm/activation layers.
            It is a sequence of "linear", "norm" and "act". Defaults to
            ('linear', 'act', 'norm').
    """

    def __init__(self,
                 noise_size,
                 out_channels,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 order=('linear', 'act', 'norm')):
        super().__init__()
        self.noise_size = noise_size
        self.out_channels = out_channels
        self.with_activation = act_cfg is not None
        self.with_norm = norm_cfg is not None
        self.order = order
        assert len(order) == 3 and set(order) == {'linear', 'act', 'norm'}

        # w/o bias, because the bias is added after reshaping the tensor to
        # 2D feature
        self.linear = nn.Linear(noise_size, out_channels * 16, bias=False)

        if self.with_activation:
            self.activation = build_activation_layer(act_cfg)

        # bias applied to the reshaped 2D feature (one scalar per channel)
        self.register_parameter(
            'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))

        if self.with_norm:
            _, self.norm = build_norm_layer(norm_cfg, out_channels)

        self._init_weight()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input noise tensor with shape (n, c).

        Returns:
            Tensor: Forward results with shape (n, c, 4, 4).
        """
        assert x.ndim == 2
        for op in self.order:
            if op == 'linear':
                x = self.linear(x)
                # reshape the flat projection to a (n, c, 4, 4) feature
                x = x.reshape(-1, self.out_channels, 4, 4)
                x = x + self.bias
            elif op == 'act' and self.with_activation:
                x = self.activation(x)
            elif op == 'norm' and self.with_norm:
                x = self.norm(x)
        return x

    def _init_weight(self):
        """Initialize weights: N(0, 1) linear weight, zero bias, unit norm."""
        nn.init.normal_(self.linear.weight, 0., 1.)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)
class WGANDecisionHead(nn.Module):
    """Module used in WGAN-GP to get the final prediction result with 4x4
    resolution input tensor in the bottom of the discriminator.

    Args:
        in_channels (int): Number of channels in input feature map.
        mid_channels (int): Number of channels in feature map after
            convolution.
        out_channels (int): The channel number of the final output layer.
        bias (bool, optional): Whether to use bias parameter. Defaults to True.
        act_cfg (dict, optional): Config for the activation layer. Defaults to
            dict(type='ReLU').
        out_act (dict, optional): Config for the activation layer of output
            layer. Defaults to None.
        norm_cfg (dict, optional): Config dict to build norm layer. Defaults to
            dict(type='LN2d').
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 bias=True,
                 act_cfg=dict(type='ReLU'),
                 out_act=None,
                 norm_cfg=dict(type='LN2d')):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.with_out_activation = out_act is not None

        # 4x4 conv collapses the 4x4 spatial input to a 1x1 feature
        self.conv = ConvLNModule(
            in_channels,
            feature_shape=(mid_channels, 1, 1),
            kernel_size=4,
            out_channels=mid_channels,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            order=('conv', 'norm', 'act'))

        # linear projection to the final decision score(s)
        self.linear = nn.Linear(
            self.mid_channels, self.out_channels, bias=bias)

        if self.with_out_activation:
            self.out_activation = build_activation_layer(out_act)

        self._init_weight()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        feat = self.conv(x)
        flat = feat.reshape(feat.shape[0], -1)
        score = self.linear(flat)
        if self.with_out_activation:
            score = self.out_activation(score)
        return score

    def _init_weight(self):
        """Initialize the linear layer with N(0, 1) weights and zero bias."""
        nn.init.normal_(self.linear.weight, 0., 1.)
        nn.init.constant_(self.linear.bias, 0.)
@PLUGIN_LAYERS.register_module()
class ConvLNModule(ConvModule):
    r"""ConvModule with Layer Normalization.

    In this module, we inherit default ``mmcv.cnn.ConvModule`` and deal with
    the situation that 'norm_cfg' is 'LN2d' or 'GN'. We adopt 'GN' as a
    replacement for layer normalization referring to:
    https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch/blob/master/module.py # noqa

    Args:
        feature_shape (tuple): The shape of feature map that will be
            normalized.
    """

    def __init__(self, *args, feature_shape=None, **kwargs):
        norm_cfg = kwargs.get('norm_cfg')
        if norm_cfg is not None and norm_cfg['type'] in ('LN2d', 'GN'):
            # ConvModule cannot build these norms itself, so build the conv
            # without a norm_cfg and attach the norm layer manually.
            conv_kwargs = deepcopy(kwargs)
            conv_kwargs['norm_cfg'] = None
            super().__init__(*args, **conv_kwargs)
            self.with_norm = True
            self.norm_name = norm_cfg['type']
            if self.norm_name == 'LN2d':
                self.add_module(self.norm_name, nn.LayerNorm(feature_shape))
            else:
                # GN with a single group acts as a layer-norm substitute
                self.add_module(self.norm_name,
                                nn.GroupNorm(1, feature_shape[0]))
        else:
            super().__init__(*args, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
# Global registries: MODELS holds complete GAN/diffusion systems,
# MODULES holds their sub-components (generators, discriminators, losses...).
MODELS = Registry('model')
MODULES = Registry('module')
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules; it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    # a list of configs yields a ModuleList of individually-built modules
    built = [build_from_cfg(item, registry, default_args) for item in cfg]
    return nn.ModuleList(built)
def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build model (GAN)."""
    default_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, MODELS, default_args)
def build_module(cfg, default_args=None):
    """Build a module or modules from a list."""
    return build(cfg, MODULES, default_args=default_args)
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import AllGatherLayer
from .model_utils import GANImageBuffer, set_requires_grad
# Public API of the common utilities subpackage.
__all__ = ['set_requires_grad', 'AllGatherLayer', 'GANImageBuffer']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.distributed as dist
class AllGatherLayer(autograd.Function):
    """All gather layer with backward propagation path.

    Indeed, this module is to make ``dist.all_gather()`` in the backward graph.
    Such kind of operation has been widely used in Moco and other contrastive
    learning algorithms.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every rank.

        Returns:
            tuple[Tensor]: ``world_size`` tensors; entry ``i`` holds the
            ``x`` contributed by rank ``i``.
        """
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the gradient of the output slice owned by this rank.

        Fix: the original allocated ``torch.zeros_like(x)`` and immediately
        overwrote the reference, and saved ``x`` in ``forward`` only to feed
        that dead allocation; both the dead local and the needless
        ``save_for_backward`` are removed. Gradients are unchanged.
        """
        return grad_outputs[dist.get_rank()]
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not
    """
    modules = nets if isinstance(nets, list) else [nets]
    for module in modules:
        # ``None`` entries are tolerated and simply skipped
        if module is None:
            continue
        for param in module.parameters():
            param.requires_grad = requires_grad
class GANImageBuffer:
    """This class implements an image buffer that stores previously generated
    images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size = 0,
            no buffer will be created.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # create an empty buffer only when buffering is enabled
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.
        """
        # a zero-sized buffer is a no-op passthrough
        if self.buffer_size == 0:
            return images

        returned = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.img_num < self.buffer_size:
                # buffer not yet full: store the image and pass it through
                self.img_num += 1
                self.image_buffer.append(image)
                returned.append(image)
            elif np.random.random() < self.buffer_ratio:
                # with probability buffer_ratio, hand back a stored image
                # and keep the current one in its place
                idx = np.random.randint(0, self.buffer_size)
                stored = self.image_buffer[idx].clone()
                self.image_buffer[idx] = image
                returned.append(stored)
            else:
                # otherwise return the current image untouched
                returned.append(image)
        return torch.cat(returned, 0)
# Copyright (c) OpenMMLab. All rights reserved.
from .base_diffusion import BasicGaussianDiffusion
from .sampler import UniformTimeStepSampler
# Public API of the diffusion subpackage.
__all__ = ['BasicGaussianDiffusion', 'UniformTimeStepSampler']
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from abc import ABCMeta
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..architectures.common import get_module_device
from ..builder import MODELS, build_module
from .utils import _get_label_batch, _get_noise_batch, var_to_tensor
@MODELS.register_module()
class BasicGaussianDiffusion(nn.Module, metaclass=ABCMeta):
"""Basic module for gaussian Diffusion Denoising Probabilistic Models. A
diffusion probabilistic model (which we will call a 'diffusion model' for
brevity) is a parameterized Markov chain trained using variational
inference to produce samples matching the data after finite time.
The design of this module implements DDPM and improve-DDPM according to
"Denoising Diffusion Probabilistic Models" (2020) and "Improved Denoising
Diffusion Probabilistic Models" (2021).
Args:
denoising (dict): Config for denoising model.
ddpm_loss (dict): Config for losses of DDPM.
betas_cfg (dict): Config for betas in diffusion process.
num_timesteps (int, optional): The number of timesteps of the diffusion
process. Defaults to 1000.
num_classes (int | None, optional): The number of conditional classes.
Defaults to None.
sample_method (string, optional): Sample method for the denoising
process. Support 'DDPM' and 'DDIM'. Defaults to 'DDPM'.
timesteps_sampler (string, optional): How to sample timesteps in
training process. Defaults to `UniformTimeStepSampler`.
train_cfg (dict | None, optional): Config for training schedule.
Defaults to None.
test_cfg (dict | None, optional): Config for testing schedule. Defaults
to None.
"""
    def __init__(self,
                 denoising,
                 ddpm_loss,
                 betas_cfg,
                 num_timesteps=1000,
                 num_classes=0,
                 sample_method='DDPM',
                 timestep_sampler='UniformTimeStepSampler',
                 train_cfg=None,
                 test_cfg=None):
        """Build the diffusion model: denoising net, timestep sampler,
        losses, and the diffusion schedule variables."""
        super().__init__()
        # fp16 is off by default; training hooks may toggle this flag
        self.fp16_enable = False
        # build denoising module in this function
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps
        self.sample_method = sample_method
        # keep a copy of the raw config; used below to read 'in_channels'
        self._denoising_cfg = deepcopy(denoising)
        self.denoising = build_module(
            denoising,
            default_args=dict(
                num_classes=num_classes, num_timesteps=num_timesteps))

        # get output-related configs from denoising
        self.denoising_var_mode = self.denoising.var_mode
        self.denoising_mean_mode = self.denoising.mean_mode

        # output_channels in denoising may be double, therefore we
        # get number of channels from config
        image_channels = self._denoising_cfg['in_channels']
        # image_size should be the attribute of denoising network
        image_size = self.denoising.image_size
        image_shape = torch.Size([image_channels, image_size, image_size])
        self.image_shape = image_shape
        # partials bound to this model's image shape / timestep count so
        # callers only pass batch-specific arguments
        self.get_noise = partial(
            _get_noise_batch,
            image_shape=image_shape,
            num_timesteps=self.num_timesteps)
        self.get_label = partial(
            _get_label_batch, num_timesteps=self.num_timesteps)

        # build sampler
        if timestep_sampler is not None:
            self.sampler = build_module(
                timestep_sampler,
                default_args=dict(num_timesteps=num_timesteps))
        else:
            self.sampler = None

        # build losses
        if ddpm_loss is not None:
            self.ddpm_loss = build_module(
                ddpm_loss, default_args=dict(sampler=self.sampler))
            # normalize to a ModuleList so _get_loss can always iterate
            if not isinstance(self.ddpm_loss, nn.ModuleList):
                self.ddpm_loss = nn.ModuleList([self.ddpm_loss])
        else:
            self.ddpm_loss = None

        self.betas_cfg = deepcopy(betas_cfg)

        # NOTE(review): falsy configs (e.g. empty dict) are coerced to None
        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

        # precompute the betas/alphas schedule tables (defined elsewhere
        # in this class)
        self.prepare_diffusion_vars()
def _parse_train_cfg(self):
"""Parsing train config and set some attributes for training."""
if self.train_cfg is None:
self.train_cfg = dict()
self.use_ema = self.train_cfg.get('use_ema', False)
if self.use_ema:
self.denoising_ema = deepcopy(self.denoising)
self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')
def _parse_test_cfg(self):
"""Parsing test config and set some attributes for testing."""
if self.test_cfg is None:
self.test_cfg = dict()
# whether to use exponential moving average for testing
self.use_ema = self.test_cfg.get('use_ema', False)
if self.use_ema:
self.denoising_ema = deepcopy(self.denoising)
def _get_loss(self, outputs_dict):
losses_dict = {}
# forward losses
for loss_fn in self.ddpm_loss:
losses_dict[loss_fn.loss_name()] = loss_fn(outputs_dict)
loss, log_vars = self._parse_losses(losses_dict)
# update collected log_var from loss_fn
for loss_fn in self.ddpm_loss:
if hasattr(loss_fn, 'log_vars'):
log_vars.update(loss_fn.log_vars)
return loss, log_vars
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
which may be a weighted sum of all losses, log_vars contains \
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensor')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
    def train_step(self,
                   data,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repo in **MM** series, we allow the back propagation and
        optimizer updating to directly follow the iterative training schedule
        of DDPMs.

        Of course, we will show that you can also move the back
        propagation outside of this method, and then optimize the parameters
        in the optimizer hook. But this will cause extra GPU memory cost as a
        result of retaining computational graph. Otherwise, the training
        schedule should be modified in the detailed implementation.

        Args:
            data (dict): Input data batch from the dataloader; real images
                are read from ``self.real_img_key``.
            optimizer (dict): Dict contains optimizer for denoising network.
            ddp_reducer (Reducer | None, optional): DDP gradient reducer;
                when given, it is prepared for backward before the loss
                backward call. Defaults to None.
            loss_scaler (GradScaler | None, optional): Gradient scaler for
                fp16 training. Defaults to None.
            use_apex_amp (bool, optional): Whether to scale the loss with
                apex amp. Defaults to False.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains ``log_vars``, ``num_samples`` and visualization
            ``results``.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        real_imgs = data[self.real_img_key]
        # denoising training
        optimizer['denoising'].zero_grad()
        denoising_dict_ = self.reconstruction_step(
            data,
            timesteps=self.sampler,
            sample_model='orig',
            return_noise=True)
        # enrich the outputs with everything the loss functions may need
        denoising_dict_['iteration'] = curr_iter
        denoising_dict_['real_imgs'] = real_imgs
        denoising_dict_['loss_scaler'] = loss_scaler

        loss, log_vars = self._get_loss(denoising_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss, optimizer['denoising'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['denoising'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['denoising'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['denoising'].step()

        # images used for visualization
        results = dict(
            real_imgs=real_imgs,
            x_0_pred=denoising_dict_['x_0_pred'],
            x_t=denoising_dict_['diffusion_batches'],
            x_t_1=denoising_dict_['fake_img'])
        outputs = dict(
            log_vars=log_vars, num_samples=real_imgs.shape[0], results=results)

        # only advance the internal counter in the no-running-status mode
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
    def reconstruction_step(self,
                            data_batch,
                            noise=None,
                            label=None,
                            timesteps=None,
                            sample_model='orig',
                            return_noise=False,
                            **kwargs):
        """Reconstruction step at corresponding ``timesteps``.

        To be noted that, the denoising target ``x_t`` for each timestep is
        always generated from the real images (via ``q_sample``), not from
        the denoising result of the previous timestep.
        ``sample_from_noise`` focuses on generating samples starting from
        **random (or given) noise**; this function instead realizes a
        reconstruction process for the given real images.
        If ``timesteps`` is None, reconstruction is automatically performed
        at all timesteps.
        Args:
            data_batch (dict): Input data from dataloader. Must contain
                ``self.real_img_key``; may optionally contain ``'noise'``
                and ``'label'``.
            noise (torch.Tensor | callable | None): Noise used in the
                diffusion process. You can directly give a batch of noise
                through a ``torch.Tensor`` or offer a callable function to
                sample a batch of noise data. Otherwise, ``None`` indicates
                to use the default noise sampler. Defaults to None.
            label (torch.Tensor | None , optional): The conditional label of
                the input image. Defaults to None.
            timesteps (int | list | torch.Tensor | callable | None): Target
                timesteps to perform reconstruction. Defaults to None.
            sample_model (str, optional): Which model to sample fake
                images with, ``'orig'`` or ``'ema'``. Defaults to `'orig'`.
            return_noise (bool, optional): If True, ``noise_batch``, ``label``
                and all other intermedia variables will be returned together
                with ``fake_img`` in a dict. Defaults to False.
        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with required
                data , including generated images, will be returned.
        """
        assert sample_model in [
            'orig', 'ema'
        ], ('We only support \'orig\' and \'ema\' for '
            f'\'reconstruction_step\', but receive \'{sample_model}\'.')
        denoising_model = self.denoising if sample_model == 'orig' \
            else self.denoising_ema
        # 0. prepare for timestep, noise and label
        device = get_module_device(self)
        real_imgs = data_batch[self.real_img_key]
        num_batches = real_imgs.shape[0]
        if timesteps is None:
            # default to performing the whole reconstruction process:
            # build a [num_timesteps, num_batches] grid so the loop below
            # visits every timestep once for the whole batch
            timesteps = torch.LongTensor([
                t for t in range(self.num_timesteps)
            ]).view(self.num_timesteps, 1)
            timesteps = timesteps.repeat([1, num_batches])
        if isinstance(timesteps, (int, list)):
            timesteps = torch.LongTensor(timesteps)
        elif callable(timesteps):
            # a sampler (e.g. ``self.sampler``) draws timesteps per batch
            timestep_generator = timesteps
            timesteps = timestep_generator(num_batches)
        else:
            assert isinstance(timesteps, torch.Tensor), (
                'we only support int list tensor or a callable function')
        # ensure a leading axis so iteration below always walks groups of
        # per-batch timesteps
        if timesteps.ndim == 1:
            timesteps = timesteps.unsqueeze(0)
        timesteps = timesteps.to(get_module_device(self))
        # noise/label may come either from arguments or from the data batch,
        # but never from both
        if noise is not None:
            assert 'noise' not in data_batch, (
                'Receive \'noise\' in both data_batch and passed arguments.')
        if noise is None:
            noise = data_batch['noise'] if 'noise' in data_batch else None
        if self.num_classes > 0:
            if label is not None:
                assert 'label' not in data_batch, (
                    'Receive \'label\' in both data_batch '
                    'and passed arguments.')
            if label is None:
                label = data_batch['label'] if 'label' in data_batch else None
            label_batches = self.get_label(
                label, num_batches=num_batches).to(device)
        else:
            label_batches = None
        output_dict = defaultdict(list)
        # loop all timesteps
        for timestep in timesteps:
            # 1. get diffusion results and parameters
            # note: noise is (re)sampled for every timestep group
            noise_batches = self.get_noise(
                noise, num_batches=num_batches).to(device)
            diffusion_batches = self.q_sample(real_imgs, timestep,
                                              noise_batches)
            # 2. get denoising results.
            denoising_batches = self.denoising_step(
                denoising_model,
                diffusion_batches,
                timestep,
                label=label_batches,
                return_noise=return_noise,
                clip_denoised=not self.training)
            # 3. get ground truth by q_posterior
            target_batches = self.q_posterior_mean_variance(
                real_imgs, diffusion_batches, timestep, logvar=True)
            if return_noise:
                output_dict_ = dict(
                    timesteps=timestep,
                    noise=noise_batches,
                    diffusion_batches=diffusion_batches)
                if self.num_classes > 0:
                    output_dict_['label'] = label_batches
                output_dict_.update(denoising_batches)
                output_dict_.update(target_batches)
            else:
                output_dict_ = dict(fake_img=denoising_batches)
            # update output of `timestep` to output_dict
            for k, v in output_dict_.items():
                if k in output_dict:
                    output_dict[k].append(v)
                else:
                    output_dict[k] = [v]
        # 4. concentrate list to tensor
        for k, v in output_dict.items():
            output_dict[k] = torch.cat(v, dim=0)
        # 5. return results
        if return_noise:
            return output_dict
        return output_dict['fake_img']
    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          label=None,
                          **kwargs):
        """Sample images from noises by using the denoising model.
        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): The model to sample. If ``ema/orig``
                is passed, this method will try to sample from ema (if
                ``self.use_ema == True``) and orig model. Defaults to
                'ema/orig'.
            label (torch.Tensor | None , optional): The conditional label.
                Defaults to None.
        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        # get sample function by name, e.g. ``sample_method='ddpm'`` resolves
        # to ``self.DDPM_sample``
        sample_fn_name = f'{self.sample_method.upper()}_sample'
        if not hasattr(self, sample_fn_name):
            raise AttributeError(
                f'Cannot find sample method [{sample_fn_name}] correspond '
                f'to [{self.sample_method}].')
        sample_fn = getattr(self, sample_fn_name)
        # choose which denoising network performs the (first) sampling pass
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.denoising_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.denoising_ema
        else:
            _model = self.denoising
        outputs = sample_fn(
            _model,
            noise=noise,
            num_batches=num_batches,
            label=label,
            **kwargs)
        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            # return_noise is True: reuse the same starting noise, label and
            # per-timestep noise for the second (orig) sampling pass below so
            # both passes are comparable
            noise = outputs['x_t']
            label = outputs['label']
            kwargs['timesteps_noise'] = outputs['noise_batch']
            fake_img = outputs['fake_img']
        else:
            fake_img = outputs
        # when 'ema/orig' is requested, additionally sample with the orig
        # model and concatenate both results along the batch axis
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.denoising
            outputs_ = sample_fn(
                _model, noise=noise, num_batches=num_batches, **kwargs)
            if isinstance(outputs_, dict) and 'noise_batch' in outputs_:
                # return_noise is True
                fake_img_ = outputs_['fake_img']
            else:
                fake_img_ = outputs_
            if isinstance(fake_img, dict):
                # save_intermedia is True: concatenate per-timestep results
                fake_img = {
                    k: torch.cat([fake_img[k], fake_img_[k]], dim=0)
                    for k in fake_img.keys()
                }
            else:
                fake_img = torch.cat([fake_img, fake_img_], dim=0)
        return fake_img
    @torch.no_grad()
    def DDPM_sample(self,
                    model,
                    noise=None,
                    num_batches=0,
                    label=None,
                    save_intermedia=False,
                    timesteps_noise=None,
                    return_noise=False,
                    show_pbar=False,
                    **kwargs):
        """DDPM ancestral sampling, starting from random (or given) noise.
        Args:
            model (torch.nn.Module): Denoising model used to sample images.
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            label (torch.Tensor | None , optional): The conditional label.
                Defaults to None.
            save_intermedia (bool, optional): Whether to save denoising result
                of intermedia timesteps. If set as True, will return a dict
                whose keys and values are denoising timesteps and denoising
                results. Otherwise, only the final denoising result will be
                returned. Defaults to False.
            timesteps_noise (torch.Tensor, optional): Noise term used in each
                denoising timestep. If given, the input noise will be shaped
                to [num_timesteps, b, c, h, w]. If set as None, noise of each
                denoising timestep will be randomly sampled. Default as None.
            return_noise (bool, optional): If True, a dict contains
                ``noise_batch``, ``x_t`` and ``label`` will be returned
                together with the denoising results, and the key of denoising
                results is ``fake_img``. To be noted that ``noise_batches``
                will shape as [num_timesteps, b, c, h, w]. Defaults to False.
            show_pbar (bool, optional): If True, a progress bar will be
                displayed. Defaults to False.
        Returns:
            torch.Tensor | dict: If ``save_intermedia``, a dict contains
                denoising results of each timestep will be returned.
                Otherwise, only the final denoising result will be returned.
        """
        device = get_module_device(self)
        noise = self.get_noise(noise, num_batches=num_batches).to(device)
        # ``noise`` is kept unmodified for the ``return_noise`` output;
        # ``x_t`` is the running denoising state
        x_t = noise.clone()
        if save_intermedia:
            # save input under key ``num_timesteps`` (the fully-noised state)
            intermedia = {self.num_timesteps: x_t.clone()}
        # use timesteps noise if defined, shaped to [n, bz, c, h, w]
        if timesteps_noise is not None:
            timesteps_noise = self.get_noise(
                timesteps_noise, num_batches=num_batches,
                timesteps_noise=True).to(device)
        # iterate timesteps from T-1 down to 0
        batched_timesteps = torch.arange(self.num_timesteps - 1, -1,
                                         -1).long().to(device)
        if show_pbar:
            pbar = mmcv.ProgressBar(self.num_timesteps)
        for t in batched_timesteps:
            # every sample in the batch is denoised at the same timestep
            batched_t = t.expand(x_t.shape[0])
            # pick the pre-sampled noise of this timestep if provided
            step_noise = timesteps_noise[t, ...] \
                if timesteps_noise is not None else None
            x_t = self.denoising_step(
                model, x_t, batched_t, noise=step_noise, label=label, **kwargs)
            if save_intermedia:
                # intermediate results are stored on cpu
                intermedia[int(t)] = x_t.cpu().clone()
            if show_pbar:
                pbar.update()
        denoising_results = intermedia if save_intermedia else x_t
        if show_pbar:
            sys.stdout.write('\n')
        if return_noise:
            return dict(
                noise_batch=timesteps_noise,
                x_t=noise,
                label=label,
                fake_img=denoising_results)
        return denoising_results
def prepare_diffusion_vars(self):
"""Prepare for variables used in the diffusion process."""
self.betas = self.get_betas()
self.alphas = 1.0 - self.betas
self.alphas_bar = np.cumproduct(self.alphas, axis=0)
self.alphas_bar_prev = np.append(1.0, self.alphas_bar[:-1])
self.alphas_bar_next = np.append(self.alphas_bar[1:], 0.0)
# calculations for diffusion q(x_t | x_0) and others
self.sqrt_alphas_bar = np.sqrt(self.alphas_bar)
self.sqrt_one_minus_alphas_bar = np.sqrt(1.0 - self.alphas_bar)
self.log_one_minus_alphas_bar = np.log(1.0 - self.alphas_bar)
self.sqrt_recip_alplas_bar = np.sqrt(1.0 / self.alphas_bar)
self.sqrt_recipm1_alphas_bar = np.sqrt(1.0 / self.alphas_bar - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.tilde_betas_t = self.betas * (1 - self.alphas_bar_prev) / (
1 - self.alphas_bar)
# clip log var for tilde_betas_0 = 0
self.log_tilde_betas_t_clipped = np.log(
np.append(self.tilde_betas_t[1], self.tilde_betas_t[1:]))
self.tilde_mu_t_coef1 = np.sqrt(
self.alphas_bar_prev) / (1 - self.alphas_bar) * self.betas
self.tilde_mu_t_coef2 = np.sqrt(
self.alphas) * (1 - self.alphas_bar_prev) / (1 - self.alphas_bar)
def get_betas(self):
"""Get betas by defined schedule method in diffusion process."""
self.betas_schedule = self.betas_cfg.pop('type')
if self.betas_schedule == 'linear':
return self.linear_beta_schedule(self.num_timesteps,
**self.betas_cfg)
elif self.betas_schedule == 'cosine':
return self.cosine_beta_schedule(self.num_timesteps,
**self.betas_cfg)
else:
raise AttributeError(f'Unknown method name {self.beta_schedule}'
'for beta schedule.')
@staticmethod
def linear_beta_schedule(diffusion_timesteps, beta_0=1e-4, beta_T=2e-2):
r"""Linear schedule from Ho et al, extended to work for any number of
diffusion steps.
Args:
diffusion_timesteps (int): The number of betas to produce.
beta_0 (float, optional): `\beta` at timestep 0. Defaults to 1e-4.
beta_T (float, optional): `\beta` at timestep `T` (the final
diffusion timestep). Defaults to 2e-2.
Returns:
np.ndarray: Betas used in diffusion process.
"""
scale = 1000 / diffusion_timesteps
beta_0 = scale * beta_0
beta_T = scale * beta_T
return np.linspace(
beta_0, beta_T, diffusion_timesteps, dtype=np.float64)
@staticmethod
def cosine_beta_schedule(diffusion_timesteps, max_beta=0.999, s=0.008):
r"""Create a beta schedule that discretizes the given alpha_t_bar
function, which defines the cumulative product of `(1-\beta)` over time
from `t = [0, 1]`.
Args:
diffusion_timesteps (int): The number of betas to produce.
max_beta (float, optional): The maximum beta to use; use values
lower than 1 to prevent singularities. Defaults to 0.999.
s (float, optional): Small offset to prevent `\beta` from being too
small near `t = 0` Defaults to 0.008.
Returns:
np.ndarray: Betas used in diffusion process.
"""
def f(t, T, s):
return np.cos((t / T + s) / (1 + s) * np.pi / 2)**2
betas = []
for t in range(diffusion_timesteps):
alpha_bar_t = f(t + 1, diffusion_timesteps, s)
alpha_bar_t_1 = f(t, diffusion_timesteps, s)
betas_t = 1 - alpha_bar_t / alpha_bar_t_1
betas.append(min(betas_t, max_beta))
return np.array(betas)
def q_sample(self, x_0, t, noise=None):
r"""Get diffusion result at timestep `t` by `q(x_t | x_0)`.
Args:
x_0 (torch.Tensor): Original image without diffusion.
t (torch.Tensor): Target diffusion timestep.
noise (torch.Tensor, optional): Noise used in reparameteration
trick. Default to None.
Returns:
torch.tensor: Diffused image `x_t`.
"""
device = get_module_device(self)
num_batches = x_0.shape[0]
tar_shape = x_0.shape
noise = self.get_noise(noise, num_batches=num_batches)
mean = var_to_tensor(self.sqrt_alphas_bar, t, tar_shape, device)
std = var_to_tensor(self.sqrt_one_minus_alphas_bar, t, tar_shape,
device)
return x_0 * mean + noise * std
def q_mean_log_variance(self, x_0, t):
r"""Get mean and log_variance of diffusion process `q(x_t | x_0)`.
Args:
x_0 (torch.tensor): The original image before diffusion, shape as
[bz, ch, H, W].
t (torch.tensor): Target timestep, shape as [bz, ].
Returns:
Tuple(torch.tensor): Tuple contains mean and log variance.
"""
device = get_module_device(self)
tar_shape = x_0.shape
mean = var_to_tensor(self.sqrt_alphas_bar, t, tar_shape, device) * x_0
logvar = var_to_tensor(self.log_one_minus_alphas_bar, t, tar_shape,
device)
return mean, logvar
def q_posterior_mean_variance(self,
x_0,
x_t,
t,
need_var=True,
logvar=False):
r"""Get mean and variance of diffusion posterior
`q(x_{t-1} | x_t, x_0)`.
Args:
x_0 (torch.tensor): The original image before diffusion, shape as
[bz, ch, H, W].
t (torch.tensor): Target timestep, shape as [bz, ].
need_var (bool, optional): If set as ``True``, this function will
return a dict contains ``var``. Otherwise, only mean will be
returned, ``logvar`` will be ignored. Defaults to True.
logvar (bool, optional): If set as ``True``, the returned dict
will additionally contain ``logvar``. This argument will be
considered only if ``var == True``. Defaults to False.
Returns:
torch.Tensor | dict: If ``var``, will return a dict contains
``mean`` and ``var``. Otherwise, only mean will be returned.
If ``var`` and ``logvar`` set at as True simultaneously, the
returned dict will additional contain ``logvar``.
"""
device = get_module_device(self)
tar_shape = x_0.shape
tilde_mu_t_coef1 = var_to_tensor(self.tilde_mu_t_coef1, t, tar_shape,
device)
tilde_mu_t_coef2 = var_to_tensor(self.tilde_mu_t_coef2, t, tar_shape,
device)
posterior_mean = tilde_mu_t_coef1 * x_0 + tilde_mu_t_coef2 * x_t
# do not need variance, just return mean
if not need_var:
return posterior_mean
posterior_var = var_to_tensor(self.tilde_betas_t, t, tar_shape, device)
out_dict = dict(
mean_posterior=posterior_mean, var_posterior=posterior_var)
if logvar:
posterior_logvar = var_to_tensor(self.log_tilde_betas_t_clipped, t,
tar_shape, device)
out_dict['logvar_posterior'] = posterior_logvar
return out_dict
    def p_mean_variance(self,
                        denoising_output,
                        x_t,
                        t,
                        clip_denoised=True,
                        denoised_fn=None):
        r"""Get mean, variance, log variance of denoising process
        `p(x_{t-1} | x_{t})` and predicted `x_0`.
        Args:
            denoising_output (dict[torch.Tensor]): The output from denoising
                model. Which keys are read depends on
                ``self.denoising_var_mode`` (``'logvar'`` or ``'factor'``)
                and ``self.denoising_mean_mode`` (``'eps_t_pred'``,
                ``'x_0_pred'`` or ``'x_tm1_pred'``).
            x_t (torch.Tensor): Diffused image at timestep `t` to denoising.
            t (torch.Tensor): Current timestep.
            clip_denoised (bool, optional): Whether cliped sample results into
                [-1, 1]. Defaults to True.
            denoised_fn (callable, optional): If not None, a function which
                applies to the predicted ``x_0`` before it is passed to the
                following sampling procedure. Noted that this function will be
                applies before ``clip_denoised``. Defaults to None.
        Returns:
            dict: A dict contains ``var_pred``, ``logvar_pred``, ``mean_pred``
                and ``x_0_pred`` (keys already present in
                ``denoising_output`` are omitted).
        """
        target_shape = x_t.shape
        device = get_module_device(self)
        # prepare for var and logvar
        if self.denoising_var_mode.upper() == 'LEARNED':
            # NOTE: the output actually LEARNED_LOG_VAR
            logvar_pred = denoising_output['logvar']
            varpred = torch.exp(logvar_pred)
        elif self.denoising_var_mode.upper() == 'LEARNED_RANGE':
            # NOTE: the output actually LEARNED_FACTOR
            # interpolate (in log space) between the posterior variance
            # (lower bound) and beta_t (upper bound), as in Improved DDPM
            var_factor = denoising_output['factor']
            lower_bound_logvar = var_to_tensor(self.log_tilde_betas_t_clipped,
                                               t, target_shape, device)
            upper_bound_logvar = var_to_tensor(
                np.log(self.betas), t, target_shape, device)
            logvar_pred = var_factor * upper_bound_logvar + (
                1 - var_factor) * lower_bound_logvar
            varpred = torch.exp(logvar_pred)
        elif self.denoising_var_mode.upper() == 'FIXED_LARGE':
            # use betas as var
            varpred = var_to_tensor(
                np.append(self.tilde_betas_t[1], self.betas), t, target_shape,
                device)
            logvar_pred = torch.log(varpred)
        elif self.denoising_var_mode.upper() == 'FIXED_SMALL':
            # use posterior (tilde_betas) as var
            varpred = var_to_tensor(self.tilde_betas_t, t, target_shape,
                                    device)
            logvar_pred = var_to_tensor(self.log_tilde_betas_t_clipped, t,
                                        target_shape, device)
        else:
            raise AttributeError('Unknown denoising var output type '
                                 f'[{self.denoising_var_mode}].')
        # helper: apply ``denoised_fn`` (if any) and then clamp the
        # predicted x_0 to [-1, 1] when ``clip_denoised`` is set
        def process_x_0(x):
            if denoised_fn is not None and callable(denoised_fn):
                x = denoised_fn(x)
            return x.clamp(-1, 1) if clip_denoised else x
        # prepare for mean and x_0
        if self.denoising_mean_mode.upper() == 'EPS':
            eps_pred = denoising_output['eps_t_pred']
            # We can get x_{t-1} with eps in two following approaches:
            # 1. eps --(Eq 15)--> \hat{x_0} --(Eq 7)--> \tilde_mu --> x_{t-1}
            # 2. eps --(Eq 11)--> \mu_{\theta} --(Eq 7)--> x_{t-1}
            # We can verify \tilde_mu in method 1 and \mu_{\theta} in method 2
            # are almost same (error of 1e-4) with the same eps input.
            # In our implementation, we use method (1) to consistent with
            # the official ones.
            # If you want to calculate \mu_{\theta} with method 2, you can
            # use the following code:
            # coef1 = var_to_tensor(
            #     np.sqrt(1.0 / self.alphas), t, tar_shape)
            # coef2 = var_to_tensor(
            #     self.betas / self.sqrt_one_minus_alphas_bar, t, tar_shape)
            # mu_theta = coef1 * (x_t - coef2 * eps)
            x_0_pred = process_x_0(self.pred_x_0_from_eps(eps_pred, x_t, t))
            mean_pred = self.q_posterior_mean_variance(
                x_0_pred, x_t, t, need_var=False)
        elif self.denoising_mean_mode.upper() == 'START_X':
            # the network predicts x_0 directly
            x_0_pred = process_x_0(denoising_output['x_0_pred'])
            mean_pred = self.q_posterior_mean_variance(
                x_0_pred, x_t, t, need_var=False)
        elif self.denoising_mean_mode.upper() == 'PREVIOUS_X':
            # NOTE: the output actually PREVIOUS_X_MEAN (MU_THETA)
            # because this actually predict \mu_{\theta}
            mean_pred = denoising_output['x_tm1_pred']
            x_0_pred = process_x_0(self.pred_x_0_from_x_tm1(mean_pred, x_t, t))
        else:
            raise AttributeError('Unknown denoising mean output type '
                                 f'[{self.denoising_mean_mode}].')
        output_dict = dict(
            var_pred=varpred,
            logvar_pred=logvar_pred,
            mean_pred=mean_pred,
            x_0_pred=x_0_pred)
        # avoid return duplicate variables
        return {
            k: output_dict[k]
            for k in output_dict.keys() if k not in denoising_output
        }
    def denoising_step(self,
                       model,
                       x_t,
                       t,
                       noise=None,
                       label=None,
                       clip_denoised=True,
                       denoised_fn=None,
                       model_kwargs=None,
                       return_noise=False):
        """Single denoising step. Get `x_{t-1}` from ``x_t`` and ``t``.
        Args:
            model (torch.nn.Module): Denoising model used to sample images.
            x_t (torch.Tensor): Input diffused image.
            t (torch.Tensor): Current timestep.
            noise (torch.Tensor | callable | None): Noise for
                reparameterization trick. You can directly give a batch of
                noise through a ``torch.Tensor`` or offer a callable function
                to sample a batch of noise data. Otherwise, the ``None``
                indicates to use the default noise sampler.
            label (torch.Tensor | callable | None): You can directly give a
                batch of label through a ``torch.Tensor`` or offer a callable
                function to sample a batch of label data. Otherwise, the
                ``None`` indicates to use the default label sampler.
            clip_denoised (bool, optional): Whether to clip sample results
                into [-1, 1]. Defaults to True.
            denoised_fn (callable, optional): If not None, a function which
                applies to the predicted ``x_0`` prediction before it is used
                to sample. Applies before ``clip_denoised``. Defaults to None.
            model_kwargs (dict, optional): Arguments passed to denoising
                model. Defaults to None.
            return_noise (bool, optional): If True, ``noise_batch``, outputs
                from denoising model and ``p_mean_variance`` will be returned
                in a dict with ``fake_img``. Defaults to False.
        Return:
            torch.Tensor | dict: If not ``return_noise``, only the denoising
                image will be returned. Otherwise, the dict contains
                ``fake_image``, ``noise_batch`` and outputs from denoising
                model and ``p_mean_variance`` will be returned.
        """
        # init model_kwargs as dict if not passed
        if model_kwargs is None:
            model_kwargs = dict()
        model_kwargs.update(dict(return_noise=return_noise))
        # forward the denoising network, then convert its raw output into
        # mean/variance of p(x_{t-1} | x_t)
        denoising_output = model(x_t, t, label=label, **model_kwargs)
        p_output = self.p_mean_variance(denoising_output, x_t, t,
                                        clip_denoised, denoised_fn)
        mean_pred = p_output['mean_pred']
        var_pred = p_output['var_pred']
        num_batches = x_t.shape[0]
        device = get_module_device(self)
        # get noise for reparameterization
        noise = self.get_noise(noise, num_batches=num_batches).to(device)
        # mask is 0 where t == 0, so no noise is added at the final step;
        # shaped [bz, 1, 1, ...] to broadcast over the image dimensions
        nonzero_mask = ((t != 0).float().view(-1,
                                              *([1] * (len(x_t.shape) - 1))))
        # Here we directly use var_pred instead logvar_pred,
        # only error of 1e-12.
        # logvar_pred = p_output['logvar_pred']
        # sample = mean_pred + \
        #     nonzero_mask * torch.exp(0.5 * logvar_pred) * noise
        sample = mean_pred + nonzero_mask * torch.sqrt(var_pred) * noise
        if return_noise:
            return dict(
                fake_img=sample,
                noise_repar=noise,
                **denoising_output,
                **p_output)
        return sample
def pred_x_0_from_eps(self, eps, x_t, t):
r"""Predict x_0 from eps by Equ 15 in DDPM paper:
.. math::
x_0 = \frac{(x_t - \sqrt{(1-\bar{\alpha}_t)} * eps)}
{\sqrt{\bar{\alpha}_t}}
Args:
eps (torch.Tensor)
x_t (torch.Tensor)
t (torch.Tensor)
Returns:
torch.tensor: Predicted ``x_0``.
"""
device = get_module_device(self)
tar_shape = x_t.shape
coef1 = var_to_tensor(self.sqrt_recip_alplas_bar, t, tar_shape, device)
coef2 = var_to_tensor(self.sqrt_recipm1_alphas_bar, t, tar_shape,
device)
return x_t * coef1 - eps * coef2
def pred_x_0_from_x_tm1(self, x_tm1, x_t, t):
r"""
Predict `x_0` from `x_{t-1}`. (actually from `\mu_{\theta}`).
`(\mu_{\theta} - coef2 * x_t) / coef1`, where `coef1` and `coef2`
are from Eq 6 of the DDPM paper.
NOTE: This function actually predict ``x_0`` from ``mu_theta`` (mean
of ``x_{t-1}``).
Args:
x_tm1 (torch.Tensor): `x_{t-1}` used to predict `x_0`.
x_t (torch.Tensor): `x_{t}` used to predict `x_0`.
t (torch.Tensor): Current timestep.
Returns:
torch.Tensor: Predicted `x_0`.
"""
device = get_module_device(self)
tar_shape = x_t.shape
coef1 = var_to_tensor(self.tilde_mu_t_coef1, t, tar_shape, device)
coef2 = var_to_tensor(self.tilde_mu_t_coef2, t, tar_shape, device)
x_0 = (x_tm1 - coef2 * x_t) / coef1
return x_0
def forward_train(self, data, **kwargs):
"""Deprecated forward function in training."""
raise NotImplementedError(
'In MMGeneration, we do NOT recommend users to call'
'this function, because the train_step function is designed for '
'the training process.')
def forward_test(self, data, **kwargs):
"""Testing function for Diffusion Denosing Probability Models.
Args:
data (torch.Tensor | dict | None): Input data. This data will be
passed to different methods.
"""
mode = kwargs.pop('mode', 'sampling')
if mode == 'sampling':
return self.sample_from_noise(data, **kwargs)
elif mode == 'reconstruction':
# this mode is design for evaluation likelood metrics
return self.reconstruction_step(data, **kwargs)
raise NotImplementedError('Other specific testing functions should'
' be implemented by the sub-classes.')
def forward(self, data, return_loss=False, **kwargs):
"""Forward function.
Args:
data (dict | torch.Tensor): Input data dictionary.
return_loss (bool, optional): Whether in training or testing.
Defaults to False.
Returns:
dict: Output dictionary.
"""
if return_loss:
return self.forward_train(data, **kwargs)
return self.forward_test(data, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from ..builder import MODULES
@MODULES.register_module()
class UniformTimeStepSampler:
    """Uniform timestep sampler for DDPM-based models.

    Every timestep in ``[0, num_timesteps)`` is drawn with the same
    probability.

    Args:
        num_timesteps (int): Total timesteps of the diffusion process.
    """

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        # equal probability mass on every timestep
        self.prob = [1 / num_timesteps] * num_timesteps

    def sample(self, batch_size):
        """Sample a batch of timesteps.

        Args:
            batch_size (int): The desired batch size of the sampled
                timesteps.

        Returns:
            torch.Tensor: Sampled timesteps with shape ``(batch_size, )``.
        """
        # numpy keeps our sampling consistent with the official
        # implementation
        drawn = np.random.choice(
            self.num_timesteps, size=(batch_size, ), p=self.prob)
        return torch.from_numpy(drawn).long()

    def __call__(self, batch_size):
        """Alias of :meth:`sample`."""
        return self.sample(batch_size)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def _get_noise_batch(noise,
image_shape,
num_timesteps=0,
num_batches=0,
timesteps_noise=False):
"""Get noise batch. Support get sequeue of noise along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If timesteps_noise is True, we output noise which dimension is 5.
- Input is [bz, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n*bz, c, h, w]: View to [n, bz, c, h, w]
- Dim of the input is 5: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate noise shape as [n, bz, c, h, w]
- Otherwise: Raise error
If timestep_noise is False, we output noise which dimension is 4 and
ignore ``num_timesteps``.
- Dim of the input is 3: Unsqueeze to [1, c, h, w], ignore ``num_batches``
- Dim of the input is 4: Return input, ignore ``num_batches``
- Callable or None: Generate noise shape as [bz, c, h, w]
- Otherwise: Raise error
It's to be noted that, we do not move the generated label to target device
in this function because we can not get which device the noise should move
to.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
image_shape (torch.Size): Size of images in the diffusion process.
num_timesteps (int, optional): Total timestpes of the diffusion and
denoising process. Defaults to 0.
num_batches (int, optional): The number of batch size. To be noted that
this argument only work when the input ``noise`` is callable or
``None``. Defaults to 0.
timesteps_noise (bool, optional): If True, returned noise will shape
as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
Defaults to False.
device (str, optional): If not ``None``, move the generated noise to
corresponding device.
Returns:
torch.Tensor: Generated noise with desired shape.
"""
if isinstance(noise, torch.Tensor):
# conduct sanity check for the last three dimension
assert noise.shape[-3:] == image_shape
if timesteps_noise:
if noise.ndim == 4:
assert num_batches > 0 and num_timesteps > 0
# noise shape as [n, c, h, w], expand to [n, bz, c, h, w]
if noise.shape[0] == num_timesteps:
noise_batch = noise.view(num_timesteps, 1, *image_shape)
noise_batch = noise_batch.expand(-1, num_batches, -1, -1,
-1)
# noise shape as [bz, c, h, w], expand to [n, bz, c, h, w]
elif noise.shape[0] == num_batches:
noise_batch = noise.view(1, num_batches, *image_shape)
noise_batch = noise_batch.expand(num_timesteps, -1, -1, -1,
-1)
# noise shape as [n*bz, c, h, w], reshape to [b, bz, c, h, w]
elif noise.shape[0] == num_timesteps * num_batches:
noise_batch = noise.view(num_timesteps, -1, *image_shape)
else:
raise ValueError(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f'(n, bz, c, h, w). But receive {noise.shape}.')
elif noise.ndim == 5:
# direct return noise
noise_batch = noise
else:
raise ValueError(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f'(n, bz, c, h, w). But receive {noise.shape}.')
else:
if noise.ndim == 3:
# reshape noise to [1, c, h, w]
noise_batch = noise[None, ...]
elif noise.ndim == 4:
# do nothing
noise_batch = noise
else:
raise ValueError(
'The noise should be in shape of (n, c, h, w) or'
f'(c, h, w), but got {noise.shape}')
# receive a noise generator and sample noise.
elif callable(noise):
assert num_batches > 0
noise_generator = noise
if timesteps_noise:
assert num_timesteps > 0
# generate noise shape as [n, bz, c, h, w]
noise_batch = noise_generator(
(num_timesteps, num_batches, *image_shape))
else:
# generate noise shape as [bz, c, h, w]
noise_batch = noise_generator((num_batches, *image_shape))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
if timesteps_noise:
assert num_timesteps > 0
# generate noise shape as [n, bz, c, h, w]
noise_batch = torch.randn(
(num_timesteps, num_batches, *image_shape))
else:
# generate noise shape as [bz, c, h, w]
noise_batch = torch.randn((num_batches, *image_shape))
return noise_batch
def _get_label_batch(label,
num_timesteps=0,
num_classes=0,
num_batches=0,
timesteps_noise=False):
"""Get label batch. Support get sequeue of label along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If num_classes <= 0, return None.
If timesteps_noise is True, we output label which dimension is 2.
- Input is [bz, ]: Expand to [n, bz]
- Input is [n, ]: Expand to [n, bz]
- Input is [n*bz, ]: View to [n, bz]
- Dim of the input is 2: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate label shape as [n, bz]
- Otherwise: Raise error
If timesteps_noise is False, we output label which dimension is 1 and
ignore ``num_timesteps``.
- Dim of the input is 1: Unsqueeze to [1, ], ignore ``num_batches``
- Dim of the input is 2: Return the input. ignore ``num_batches``
- Callable or None: Generate label shape as [bz, ]
- Otherwise: Raise error
It's to be noted that, we do not move the generated label to target device
in this function because we can not get which device the noise should move
to.
Args:
label (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_timesteps (int, optional): Total timestpes of the diffusion and
denoising process. Defaults to 0.
num_batches (int, optional): The number of batch size. To be noted that
this argument only work when the input ``noise`` is callable or
``None``. Defaults to 0.
timesteps_noise (bool, optional): If True, returned noise will shape
as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
Defaults to False.
Returns:
torch.Tensor: Generated label with desired shape.
"""
# no labels output if num_classes is 0
if num_classes == 0:
assert label is None, ('\'label\' should be None '
'if \'num_classes == 0\'.')
return None
# receive label and conduct sanity check.
if isinstance(label, torch.Tensor):
if timesteps_noise:
if label.ndim == 1:
assert num_batches > 0 and num_timesteps > 0
# [n, ] to [n, bz]
if label.shape[0] == num_timesteps:
label_batch = label.view(num_timesteps, 1)
label_batch = label_batch.expand(-1, num_batches)
# [bz, ] to [n, bz]
elif label.shape[0] == num_batches:
label_batch = label.view(1, num_batches)
label_batch = label_batch.expand(num_timesteps, -1)
# [n*bz, ] to [n, bz]
elif label.shape[0] == num_timesteps * num_batches:
label_batch = label.view(num_timesteps, -1)
else:
raise ValueError(
'The timesteps label should be in shape of '
'(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
f'{label.shape}.')
elif label.ndim == 2:
# dimension is 2, direct return
label_batch = label
else:
raise ValueError(
'The timesteps label should be in shape of '
'(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
f'{label.shape}.')
else:
# dimension is 0, expand to [1, ]
if label.ndim == 0:
label_batch = label[None, ...]
# dimension is 1, do nothing
elif label.ndim == 1:
label_batch = label
else:
raise ValueError(
'The label should be in shape of (bz, ) or'
f'zero-dimension tensor, but got {label.shape}')
# receive a noise generator and sample noise.
elif callable(label):
assert num_batches > 0
label_generator = label
if timesteps_noise:
assert num_timesteps > 0
# generate label shape as [n, bz]
label_batch = label_generator((num_timesteps, num_batches))
else:
# generate label shape as [bz, ]
label_batch = label_generator((num_batches, ))
# otherwise, we will adopt default label sampler.
else:
assert num_batches > 0
if timesteps_noise:
assert num_timesteps > 0
# generate label shape as [n, bz]
label_batch = torch.randint(0, num_classes,
(num_timesteps, num_batches))
else:
# generate label shape as [bz, ]
label_batch = torch.randint(0, num_classes, (num_batches, ))
return label_batch
def var_to_tensor(var, index, target_shape=None, device=None):
    """Extract entries of ``var`` by ``index`` and convert them to a float
    tensor, optionally padded with trailing singleton dims for broadcasting.

    Args:
        var (np.ndarray): Variables to be extracted.
        index (torch.Tensor): Target index to extract.
        target_shape (torch.Size, optional): If given, trailing singleton
            dimensions are appended until the indexed variable has the same
            number of dimensions as ``target_shape`` (so it broadcasts
            against a tensor of that shape). Defaults to None.
        device (str, optional): If given, the indexed variable is moved to
            the target device; otherwise it stays on CPU. Defaults to None.

    Returns:
        torch.Tensor: Converted variable.
    """
    # `var` is a numpy array, so the gather has to happen on CPU; move the
    # index there first (it may live on GPU).
    var_indexed = torch.from_numpy(var)[index.cpu()].float()
    if device is not None:
        var_indexed = var_indexed.to(device)
    # Bugfix: guard against the documented default. The original called
    # `len(target_shape)` unconditionally, raising `TypeError` whenever
    # `target_shape` was None.
    if target_shape is not None:
        while len(var_indexed.shape) < len(target_shape):
            var_indexed = var_indexed[..., None]
    return var_indexed
# Copyright (c) OpenMMLab. All rights reserved.
from .base_gan import BaseGAN
from .basic_conditional_gan import BasicConditionalGAN
from .mspie_stylegan2 import MSPIEStyleGAN2
from .progressive_growing_unconditional_gan import ProgressiveGrowingGAN
from .singan import PESinGAN, SinGAN
from .static_unconditional_gan import StaticUnconditionalGAN
# Public API of the GAN models subpackage; keep in sync with the imports
# in this module.
__all__ = [
    'BaseGAN', 'StaticUnconditionalGAN', 'ProgressiveGrowingGAN', 'SinGAN',
    'MSPIEStyleGAN2', 'PESinGAN', 'BasicConditionalGAN'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
import torch.nn as nn
class BaseGAN(nn.Module, metaclass=ABCMeta):
    """Base class for GAN models.

    Provides shared machinery for concrete GANs: capability properties
    (``with_disc`` etc.), loss construction helpers (``_get_disc_loss`` /
    ``_get_gen_loss``), loss aggregation and logging (``_parse_losses``),
    generic sampling (``sample_from_noise``) and forward dispatching.
    Subclasses must implement ``train_step``.
    """
    def __init__(self):
        super().__init__()
        # Flag consumed by mmcv's fp16 utilities; toggled externally when
        # mixed-precision training is enabled.
        self.fp16_enabled = False
    @property
    def with_disc(self):
        """bool: Whether this model owns a discriminator module."""
        return hasattr(self,
                       'discriminator') and self.discriminator is not None
    @property
    def with_ema_gen(self):
        """bool: whether the GAN adopts exponential moving average."""
        # NOTE(review): this checks `gen_ema`, while `sample_from_noise`
        # uses `generator_ema` — presumably set by subclasses; confirm.
        return hasattr(self, 'gen_ema') and self.gen_ema is not None
    @property
    def with_gen_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the generator."""
        return hasattr(self,
                       'gen_auxiliary_losses') and (self.gen_auxiliary_losses
                                                    is not None)
    @property
    def with_disc_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the discriminator."""
        return (hasattr(self, 'disc_auxiliary_losses')
                ) and self.disc_auxiliary_losses is not None
    def _get_disc_loss(self, outputs_dict):
        """Compute the total discriminator loss from ``outputs_dict``.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars) produced by
                ``_parse_losses``.
        """
        # Construct losses dict. If you hope some items to be included in the
        # computational graph, you have to add 'loss' in its name. Otherwise,
        # items without 'loss' in their name will just be used to print
        # information.
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake'] = self.gan_loss(
            outputs_dict['disc_pred_fake'], target_is_real=False, is_disc=True)
        losses_dict['loss_disc_real'] = self.gan_loss(
            outputs_dict['disc_pred_real'], target_is_real=True, is_disc=True)
        # disc auxiliary loss
        if self.with_disc_auxiliary_loss:
            for loss_module in self.disc_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                # Auxiliary losses may return None (e.g. losses computed at
                # a fixed interval); simply skip them this iteration.
                if loss_ is None:
                    continue
                # the `loss_name()` function return name as 'loss_xxx'
                # Accumulate when several modules report the same name.
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)
        return loss, log_var
    def _get_gen_loss(self, outputs_dict):
        """Compute the total generator loss from ``outputs_dict``.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars) produced by
                ``_parse_losses``.
        """
        # Construct losses dict. If you hope some items to be included in the
        # computational graph, you have to add 'loss' in its name. Otherwise,
        # items without 'loss' in their name will just be used to print
        # information.
        losses_dict = {}
        # gan loss; `target_is_real=True` because the generator tries to
        # fool the discriminator into labeling fakes as real.
        losses_dict['loss_disc_fake_g'] = self.gan_loss(
            outputs_dict['disc_pred_fake_g'],
            target_is_real=True,
            is_disc=False)
        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                if loss_ is None:
                    continue
                # the `loss_name()` function return name as 'loss_xxx'
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)
        return loss, log_var
    @abstractmethod
    def train_step(self, data, optimizer, ddp_reducer=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repo in **MM** series, we allow the back propagation and
        optimizer updating to directly follow the iterative training schedule
        of GAN. Of course, we will show that you can also move the back
        propagation outside of this method, and then optimize the parameters
        in the optimizer hook. But this will cause extra GPU memory cost as a
        result of retaining computational graph. Otherwise, the training
        schedule should be modified in the detailed implementation.

        TODO: Give an example of removing bp outside ``train_step``.
        TODO: Try the synchronized back propagation.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.
            ddp_reducer (:obj:`Reducer` | None, optional): This reducer is used
                to dynamically collect used parameters in the distributed
                training. If given an initialized ``Reducer``, we will call its
                ``prepare_for_backward()`` function just before calling
                ``.backward()``.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.
                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """
    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): Which generator to sample from:
                ``'ema'`` (EMA generator only, requires ``self.use_ema``),
                ``'ema/orig'`` (EMA if available, and additionally the
                original generator's output concatenated after it), or any
                other value for the original generator.
                Defaults to ``'ema/orig'``.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator
        outputs = _model(noise, num_batches=num_batches, **kwargs)
        # Reuse the noise the first generator actually consumed so that the
        # second (original) generator below sees identical inputs.
        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(noise, num_batches=num_batches, **kwargs)
            # Concatenate EMA results and original results along the batch
            # dimension: [ema_imgs, orig_imgs].
            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)
        return outputs
    def forward_train(self, data, **kwargs):
        """Deprecated forward function in training."""
        raise NotImplementedError(
            'In MMGeneration, we do NOT recommend users to call'
            'this function, because the train_step function is designed for '
            'the training process.')
    def forward_test(self, data, **kwargs):
        """Testing function for GANs.

        Args:
            data (torch.Tensor | dict | None): Input data. This data will be
                passed to different methods.
        """
        # Only the default 'sampling' mode is implemented here.
        if kwargs.pop('mode', 'sampling') == 'sampling':
            return self.sample_from_noise(data, **kwargs)
        raise NotImplementedError('Other specific testing functions should'
                                  ' be implemented by the sub-classes.')
    def forward(self, data, return_loss=False, **kwargs):
        """Forward function.

        Args:
            data (dict | torch.Tensor): Input data dictionary.
            return_loss (bool, optional): Whether in training or testing.
                Defaults to False.

        Returns:
            dict: Output dictionary.
        """
        if return_loss:
            return self.forward_train(data, **kwargs)
        return self.forward_test(data, **kwargs)
    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            # Allow setting None for some loss item.
            # This is to support dynamic loss module, where the loss is
            # calculated with a fixed frequency.
            elif loss_value is None:
                continue
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')
        # Note that you have to add 'loss' in name of the items that will be
        # included in back propagation.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)
        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training: average each logged
            # value across all ranks so every worker logs the same numbers
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()
        return loss, log_vars
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
# Bugfix: the alias was registered as the misspelled 'BasiccGAN'.
@MODELS.register_module('BasicGAN')
@MODELS.register_module()
class BasicConditionalGAN(BaseGAN):
    """Basic conditional GANs.

    This is the conditional GAN model containing standard adversarial training
    schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss, and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` aims at setting up
    the training schedule.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
        num_classes (int | None, optional): The number of conditional classes.
            Defaults to None.
    """
    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 num_classes=None):
        super().__init__()
        self.num_classes = num_classes
        self._gen_cfg = deepcopy(generator)
        # Generator and discriminator both receive the class count so they
        # can build their conditional embeddings.
        self.generator = build_module(
            generator, default_args=dict(num_classes=num_classes))
        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(
                discriminator, default_args=dict(num_classes=num_classes))
        else:
            self.discriminator = None
        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # Bugfix: was `self.disc_auxiliary_loss = None` (singular),
            # which left `disc_auxiliary_losses` undefined. The inherited
            # `with_disc_auxiliary_loss` property and the sibling GAN
            # classes use the plural attribute name.
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None
        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None
        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()
    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        self.gen_steps = self.train_cfg.get('gen_steps', 1)
        # add support for accumulating gradients within multiple steps. This
        # feature aims to simulate large `batch_sizes` (but may have some
        # detailed differences in BN). Note that `self.disc_steps` should be
        # set according to the batch accumulation strategy.
        # In addition, in the detailed implementation, there is a difference
        # between the batch accumulation in the generator and discriminator.
        self.batch_accumulation_steps = self.train_cfg.get(
            'batch_accumulation_steps', 1)
        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)
    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()
        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)
        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use apex.amp. Defaults to
                ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['img']
        # get the ground-truth label, torch.Tensor (N, )
        gt_label = data_batch['gt_label']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration
        # disc training
        set_requires_grad(self.discriminator, True)
        # do not `zero_grad` during batch accumulation
        if curr_iter % self.batch_accumulation_steps == 0:
            optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_data = self.generator(
                None, num_batches=batch_size, label=None, return_noise=True)
        # fake_label should be in the same data type with the gt_label
        fake_imgs, fake_label = fake_data['fake_img'], fake_data['label']
        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs, label=fake_label)
        disc_pred_real = self.discriminator(real_imgs, label=gt_label)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gt_label=gt_label,
            fake_label=fake_label,
            loss_scaler=loss_scaler)
        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
        # scale down so the accumulated gradient matches a single large batch
        loss_disc = loss_disc / float(self.batch_accumulation_steps)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()
        # only step the optimizer at the end of an accumulation window
        if (curr_iter + 1) % self.batch_accumulation_steps == 0:
            if loss_scaler:
                loss_scaler.unscale_(optimizer['discriminator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['discriminator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['discriminator'].step()
        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs
        # generator training
        set_requires_grad(self.discriminator, False)
        # allow for training the generator with multiple steps
        for _ in range(self.gen_steps):
            optimizer['generator'].zero_grad()
            for _ in range(self.batch_accumulation_steps):
                # TODO: add noise sampler to customize noise sampling
                fake_data = self.generator(
                    None, num_batches=batch_size, return_noise=True)
                # fake_label should be in the same data type with the gt_label
                fake_imgs, fake_label = fake_data['fake_img'], fake_data[
                    'label']
                disc_pred_fake_g = self.discriminator(
                    fake_imgs, label=fake_label)
                data_dict_ = dict(
                    gen=self.generator,
                    disc=self.discriminator,
                    fake_imgs=fake_imgs,
                    disc_pred_fake_g=disc_pred_fake_g,
                    iteration=curr_iter,
                    batch_size=batch_size,
                    fake_label=fake_label,
                    loss_scaler=loss_scaler)
                loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
                loss_gen = loss_gen / float(self.batch_accumulation_steps)
                # prepare for backward in ddp. If you do not call this function
                # before back propagation, the ddp will not dynamically find
                # the used params in current computation.
                if ddp_reducer is not None:
                    ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
                if loss_scaler:
                    loss_scaler.scale(loss_gen).backward()
                elif use_apex_amp:
                    from apex import amp
                    with amp.scale_loss(
                            loss_gen, optimizer['generator'],
                            loss_id=1) as scaled_loss_disc:
                        scaled_loss_disc.backward()
                else:
                    loss_gen.backward()
            if loss_scaler:
                loss_scaler.unscale_(optimizer['generator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['generator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['generator'].step()
        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          label=None,
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): Use which model to sample fake
                images. Defaults to `'ema/orig'`.
            label (torch.Tensor | None , optional): The conditional label.
                Defaults to None.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator
        outputs = _model(noise, num_batches=num_batches, label=label, **kwargs)
        # Reuse the sampled noise AND label so the second generator below
        # receives identical conditional inputs.
        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']
            label = outputs['label']
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise, num_batches=num_batches, label=label, **kwargs)
            # Concatenate EMA and original results along the batch dimension.
            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS
from ..common import set_requires_grad
from .static_unconditional_gan import StaticUnconditionalGAN
@MODELS.register_module()
class MSPIEStyleGAN2(StaticUnconditionalGAN):
    """MS-PIE StyleGAN2.

    In this GAN, we adopt the MS-PIE training schedule so that multi-scale
    images can be generated with a single generator. Details can be found in:
    Positional Encoding as Spatial Inductive Bias in GANs, CVPR2021.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """
    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        super()._parse_train_cfg()
        # set the number of upsampling blocks. This value will be used to
        # calculate the current result size according to the size of the input
        # feature map, e.g., positional encoding map
        self.num_upblocks = self.train_cfg.get('num_upblocks', 6)
        # multiple input scales (a list of int) that will be added to the
        # original starting scale.
        # NOTE(review): no defaults are provided — `train_cfg` must supply
        # both keys for distributed training; confirm against the configs.
        self.multi_input_scales = self.train_cfg.get('multi_input_scales')
        self.multi_scale_probability = self.train_cfg.get(
            'multi_scale_probability')
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration
        if dist.is_initialized():
            # randomly sample a scale for current training iteration.
            # Bugfix: the distribution must be passed as the keyword `p`;
            # the third positional argument of `np.random.choice` is
            # `replace`, so the original call silently ignored
            # `multi_scale_probability` and sampled scales uniformly.
            chosen_scale = np.random.choice(
                self.multi_input_scales, 1,
                p=self.multi_scale_probability)[0]
            # broadcast from rank 0 so every worker trains on the same scale
            chosen_scale = torch.tensor(chosen_scale, dtype=torch.int).cuda()
            dist.broadcast(chosen_scale, 0)
            chosen_scale = int(chosen_scale.item())
        else:
            mmcv.print_log(
                'Distributed training has not been initialized. Degrade to '
                'the standard stylegan2',
                logger='mmgen',
                level=logging.WARN)
            chosen_scale = 0
        # input feature size is (4 + chosen_scale); each up-block doubles it
        curr_size = (4 + chosen_scale) * (2**self.num_upblocks)
        # adjust the shape of images
        if real_imgs.shape[-2:] != (curr_size, curr_size):
            real_imgs = F.interpolate(
                real_imgs,
                size=(curr_size, curr_size),
                mode='bilinear',
                align_corners=True)
        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, chosen_scale=chosen_scale)
        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))
        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()
        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            log_vars_disc['curr_size'] = curr_size
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs
        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None, num_batches=batch_size, chosen_scale=chosen_scale)
        disc_pred_fake_g = self.discriminator(fake_imgs)
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))
        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
        loss_gen.backward()
        optimizer['generator'].step()
        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars['curr_size'] = curr_size
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors
from mmgen.core.optimizer import build_optimizers
from mmgen.models.builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
@MODELS.register_module('StyleGANV1')
@MODELS.register_module('PGGAN')
@MODELS.register_module()
class ProgressiveGrowingGAN(BaseGAN):
    """Progressive Growing Unconditional GAN.

    In this GAN model, we implement progressive growing training schedule,
    which is proposed in Progressive Growing of GANs for improved Quality,
    Stability and Variation, ICLR 2018.

    We highly recommend to use ``GrowScaleImgDataset`` for saving computational
    load in data pre-processing.

    Notes for **using PGGAN**:

    #. In official implementation, Tero uses gradient penalty with
       ``norm_mode="HWC"``
    #. We do not implement ``minibatch_repeats`` where has been used in
       official Tensorflow implementation.

    Notes for resuming progressive growing GANs:
        Users should specify the ``prev_stage`` in ``train_cfg``. Otherwise,
        the model is possible to reset the optimizer status, which will bring
        inferior performance. For example, if your model is resumed from the
        `256` stage, you should set ``train_cfg=dict(prev_stage=256)``.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        # keep a copy of the generator config for later reference
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        # auxiliary losses are normalized into an ``nn.ModuleList`` so that
        # single-loss and multi-loss configs are handled uniformly downstream
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        # register necessary training status as buffers so that they are
        # saved to / restored from checkpoints automatically
        self.register_buffer('shown_nkimg', torch.tensor(0.))
        self.register_buffer('_curr_transition_weight', torch.tensor(1.))

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None
        self._parse_train_cfg()

        # this buffer is used to resume model easily
        self.register_buffer(
            '_next_scale_int',
            torch.tensor(self.scales[0][0], dtype=torch.int32))
        # TODO: init it with the same value as `_next_scale_int`
        # a dirty workaround for testing
        self.register_buffer(
            '_curr_scale_int',
            torch.tensor(self.scales[-1][0], dtype=torch.int32))
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

        # setup interpolation operation at the beginning of training iter;
        # real images are resized to the current scale with this op
        interp_real_cfg = deepcopy(self.train_cfg.get('interp_real', None))
        if interp_real_cfg is None:
            interp_real_cfg = dict(mode='bilinear', align_corners=True)
        self.interp_real_to = partial(F.interpolate, **interp_real_cfg)

        # parsing the training schedule: scales : kimg
        assert isinstance(self.train_cfg['nkimgs_per_scale'],
                          dict), ('Please provide "nkimgs_per_'
                                  'scale" to schedule the training procedure.')
        nkimgs_per_scale = deepcopy(self.train_cfg['nkimgs_per_scale'])
        self.scales = []
        self.nkimgs = []
        for k, v in nkimgs_per_scale.items():
            # support for different data types: '64', 64 and (64, 64) are
            # all normalized to an (h, w) tuple
            if isinstance(k, str):
                k = (int(k), int(k))
            elif isinstance(k, int):
                k = (k, k)
            else:
                assert mmcv.is_tuple_of(k, int)
            # sanity check for the order of scales
            assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
            self.scales.append(k)
            self.nkimgs.append(v)
        # cumulative kimg budget; used to decide the current stage from the
        # total number of shown images
        self.cum_nkimgs = np.cumsum(self.nkimgs)
        self.curr_stage = 0
        self.prev_stage = 0
        # actually nkimgs shown at the end of per training stage
        self._actual_nkimgs = []
        # In each scale, transit from previous torgb layer to newer torgb
        # layer with `transition_kimgs` imgs
        self.transition_kimgs = self.train_cfg.get('transition_kimgs', 600)

        # setup optimizer
        self.optimizer = build_optimizers(
            self, deepcopy(self.train_cfg['optimizer_cfg']))
        # get lr schedule
        self.g_lr_base = self.train_cfg['g_lr_base']
        self.d_lr_base = self.train_cfg['d_lr_base']
        # example for lr schedule: {'32': 0.001, '64': 0.0001}
        self.g_lr_schedule = self.train_cfg.get('g_lr_schedule', dict())
        self.d_lr_schedule = self.train_cfg.get('d_lr_schedule', dict())
        # reset the states for optimizers, e.g. momentum in Adam
        self.reset_optim_for_new_scale = self.train_cfg.get(
            'reset_optim_for_new_scale', True)
        # dirty walkround for avoiding optimizer bug in resuming
        self.prev_stage = self.train_cfg.get('prev_stage', self.prev_stage)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()
        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)
        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          curr_scale=None,
                          transition_weight=None,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            curr_scale (int | None, optional): Output scale. If ``None``, the
                scale is taken from the training attribute or the
                ``_curr_scale_int`` buffer. Defaults to None.
            transition_weight (float | None, optional): Weight for
                interpolating the two ``torgb`` layers. If ``None``, the
                buffered ``_curr_transition_weight`` is used. Defaults to
                None.
            sample_model (str, optional): Which generator to sample from:
                ``'ema'``, ``'orig'`` or ``'ema/orig'`` (both, concatenated).
                Defaults to ``'ema/orig'``.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized \
                images in ``torch.Tensor``. Otherwise, a dict with queried \
                data, including generated images, will be returned.
        """
        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            # in training, 'curr_scale' will be set as attribute
            if hasattr(self, 'curr_scale'):
                curr_scale = self.curr_scale[0]
            # in testing, adopt '_curr_scale_int' from buffer as testing scale
            else:
                curr_scale = self._curr_scale_int.item()
        # use `self._curr_transition_weight` if `transition_weight` is None
        if transition_weight is None:
            transition_weight = self._curr_transition_weight.item()

        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(
            noise,
            num_batches=num_batches,
            curr_scale=curr_scale,
            transition_weight=transition_weight,
            **kwargs)
        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']

        # in 'ema/orig' mode, additionally run the original generator with
        # the *same* noise batch and concatenate both results
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise,
                num_batches=num_batches,
                curr_scale=curr_scale,
                transition_weight=transition_weight,
                **kwargs)
            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the
        newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # check if optimizer from model
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage: the number of exhausted kimg budgets decides
        # the stage index, clipped to the last configured scale
        self.curr_stage = int(
            min(
                sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()

        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())
            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            # linear fade-in measured in kimgs since the stage started
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                None,
                num_batches=batch_size,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(
                self.generator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight),
            disc_partial=partial(
                self.discriminator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        if dist.is_initialized():
            # global batch size is the sum of local batch sizes over ranks
            _batch_size = batch_size * dist.get_world_size()
        else:
            # Fix: ``running_status`` is allowed to be ``None`` (see the
            # workaround above); test it before membership lookup so that the
            # intended ``RuntimeError`` is raised instead of a ``TypeError``.
            if running_status is None or 'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for PGGAN'
                )
            _batch_size = running_status['batch_size']
        self.shown_nkimg += (_batch_size / 1000.)

        log_vars_disc.update(
            dict(
                shown_nkimg=self.shown_nkimg.item(),
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None,
            num_batches=batch_size,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(
                sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODELS, build_module
from mmgen.models.gans.base_gan import BaseGAN
from ..common import set_requires_grad
@MODELS.register_module()
class SinGAN(BaseGAN):
    """SinGAN.

    This model implement the single image generative adversarial model
    proposed in: Singan: Learning a Generative Model from a Single Natural
    Image, ICCV'19.

    Notes for training:

    - This model should be trained with our dataset ``SinGANDataset``.
    - In training, the ``total_iters`` arguments is related to the number of
      scales in the image pyramid and ``iters_per_scale`` in the
      ``train_cfg``. You should set it carefully in the training config file.

    Notes for model architectures:

    - The generator and discriminator need ``num_scales`` in initialization.
      However, this arguments is generated by ``create_real_pyramid`` function
      from the ``singan_dataset.py``. The last element in the returned list
      (``stop_scale``) is the value for ``num_scales``. Pay attention that
      this scale is counted from zero. Please see our tutorial for SinGAN to
      obtain more details or our standard config for reference.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        # keep a copy of the generator config for later reference
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)
        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None
        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None
        # auxiliary losses are normalized into an ``nn.ModuleList`` so that
        # single-loss and multi-loss configs are handled uniformly downstream
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None
        # register necessary training status
        # ``curr_stage`` starts at -1; it is bumped to 0 in the first
        # ``train_step`` call (see the scale-init branch there)
        self.curr_stage = -1
        # per-scale noise amplitudes; the first scale always uses weight 1
        self.noise_weights = [1]
        # one fixed noise map per scale, built lazily from the real pyramid
        self.fixed_noises = []
        # real-image pyramid, filled from the data batch in ``train_step``
        self.reals = []
        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None
        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing.

        If ``pkl_data`` is given in ``test_cfg``, the fixed noises, noise
        weights and current stage saved during training are restored so that
        sampling can run without retraining.
        """
        if self.test_cfg.get('pkl_data', None) is not None:
            # NOTE(review): attribute access ``self.test_cfg.pkl_data``
            # assumes ``test_cfg`` is an mmcv ``Config``-like object rather
            # than a plain dict -- confirm with callers.
            with open(self.test_cfg.pkl_data, 'rb') as f:
                data = pickle.load(f)
            self.fixed_noises = self._from_numpy(data['fixed_noises'])
            self.noise_weights = self._from_numpy(data['noise_weights'])
            self.curr_stage = data['curr_stage']
            mmcv.print_log(f'Load pkl data from {self.test_cfg.pkl_data}',
                           'mmgen')

    def _from_numpy(self, data):
        """Recursively convert numpy arrays (possibly nested in lists) into
        tensors placed on the generator's device.

        Args:
            data (list | np.ndarray | object): Data to convert. Non-array
                inputs are returned unchanged.

        Returns:
            list | torch.Tensor | object: Converted data.
        """
        if isinstance(data, list):
            return [self._from_numpy(x) for x in data]
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
            device = get_module_device(self.generator)
            data = data.to(device)
            return data
        return data

    def get_module(self, model, module_name):
        """Get an inner module from model.

        Since we will wrapper DDP for some model, we have to judge whether the
        module can be indexed directly.

        Args:
            model (nn.Module): This model may wrapped with DDP or not.
            module_name (str): The name of specific module.

        Return:
            nn.Module: Returned sub module.
        """
        # unwrap (D)DP containers before attribute lookup
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return getattr(model.module, module_name)
        return getattr(model, module_name)

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          curr_scale=None,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler. Note
                that the generator is always called with ``None`` here and
                samples noise internally in ``'rand'`` mode.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            curr_scale (int | None, optional): Scale to synthesize at. If
                ``None``, ``self.curr_stage`` is used. Defaults to None.
            sample_model (str, optional): Which generator to use: ``'ema'``,
                ``'orig'`` or ``'ema/orig'``. Defaults to ``'ema/orig'``.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized \
                images in ``torch.Tensor``. Otherwise, a dict with queried \
                data, including generated images, will be returned.
        """
        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            curr_scale = self.curr_stage
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator
        # lazily move the fixed noises to GPU the first time sampling runs
        # on a CUDA-enabled machine
        if not self.fixed_noises[0].is_cuda and torch.cuda.is_available():
            self.fixed_noises = [
                x.to(get_module_device(self)) for x in self.fixed_noises
            ]
        outputs = _model(
            None,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights,
            rand_mode='rand',
            num_batches=num_batches,
            curr_scale=curr_scale,
            **kwargs)
        return outputs

    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN.

        The coarsest scale gets a random single-channel noise map; all finer
        scales get zero maps matching the real image at that scale.
        """
        for i, real in enumerate(self.reals):
            h, w = real.shape[-2:]
            if i == 0:
                noise = torch.randn(1, 1, h, w).to(real)
                self.fixed_noises.append(noise)
            else:
                noise = torch.zeros_like(real)
                self.fixed_noises.append(noise)

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the
        newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # init each scale: at the first iteration of every scale, advance the
        # stage counter, warm-start the new blocks from the previous scale
        # and rebuild per-scale optimizers and schedulers
        if curr_iter % self.train_cfg['iters_per_scale'] == 0:
            self.curr_stage += 1
            # load weights from prev scale
            self.get_module(self.generator, 'check_and_load_prev_weight')(
                self.curr_stage)
            self.get_module(self.discriminator, 'check_and_load_prev_weight')(
                self.curr_stage)
            # build optimizer for each scale; only the current scale's block
            # parameters are optimized
            g_module = self.get_module(self.generator, 'blocks')
            param_list = g_module[self.curr_stage].parameters()
            self.g_optim = torch.optim.Adam(
                param_list, lr=self.train_cfg['lr_g'], betas=(0.5, 0.999))
            d_module = self.get_module(self.discriminator, 'blocks')
            self.d_optim = torch.optim.Adam(
                d_module[self.curr_stage].parameters(),
                lr=self.train_cfg['lr_d'],
                betas=(0.5, 0.999))
            self.optimizer = dict(
                generator=self.g_optim, discriminator=self.d_optim)
            self.g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.g_optim, **self.train_cfg['lr_scheduler_args'])
            self.d_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.d_optim, **self.train_cfg['lr_scheduler_args'])
        optimizer = self.optimizer

        # setup fixed noises and reals pyramid once, from the multi-scale
        # real images carried in the data batch
        if curr_iter == 0 or len(self.reals) == 0:
            keys = [k for k in data_batch.keys() if 'real_scale' in k]
            scales = len(keys)
            self.reals = [data_batch[f'real_scale{s}'] for s in range(scales)]
            # here we do not padding fixed noises
            self.construct_fixed_noises()

        # disc training
        set_requires_grad(self.discriminator, True)
        for _ in range(self.train_cfg['disc_steps']):
            optimizer['discriminator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            with torch.no_grad():
                fake_imgs = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='rand',
                    curr_scale=self.curr_stage)
            # disc pred for fake imgs and real_imgs
            disc_pred_fake = self.discriminator(fake_imgs.detach(),
                                                self.curr_stage)
            disc_pred_real = self.discriminator(self.reals[self.curr_stage],
                                                self.curr_stage)
            # get data dict to compute losses for disc
            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                disc_pred_fake=disc_pred_fake,
                disc_pred_real=disc_pred_real,
                fake_imgs=fake_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_partial=partial(
                    self.discriminator, curr_scale=self.curr_stage))
            loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
            loss_disc.backward()
            optimizer['discriminator'].step()
        log_vars_disc.update(dict(curr_stage=self.curr_stage))

        # generator training
        set_requires_grad(self.discriminator, False)
        for _ in range(self.train_cfg['generator_steps']):
            optimizer['generator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            fake_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='rand',
                curr_scale=self.curr_stage)
            disc_pred_fake_g = self.discriminator(
                fake_imgs, curr_scale=self.curr_stage)
            # reconstruction pass with the fixed noises ('recon' mode) for
            # the reconstruction loss
            recon_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='recon',
                curr_scale=self.curr_stage)
            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                fake_imgs=fake_imgs,
                recon_imgs=recon_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_pred_fake_g=disc_pred_fake_g)
            loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
            loss_gen.backward()
            optimizer['generator'].step()

        # end of each scale
        # calculate noise weight for next scale: the RMSE between the
        # upsampled reconstruction and the next real image scales the noise
        # amplitude injected at the next scale
        if (curr_iter % self.train_cfg['iters_per_scale']
                == 0) and (self.curr_stage < len(self.reals) - 1):
            with torch.no_grad():
                g_recon = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='recon',
                    curr_scale=self.curr_stage)
                if isinstance(g_recon, dict):
                    g_recon = g_recon['fake_img']
                g_recon = F.interpolate(
                    g_recon, self.reals[self.curr_stage + 1].shape[-2:])
            mse = F.mse_loss(g_recon.detach(), self.reals[self.curr_stage + 1])
            rmse = torch.sqrt(mse)
            self.noise_weights.append(
                self.train_cfg.get('noise_weight_init', 0.1) * rmse.item())
            # try to release GPU memory.
            torch.cuda.empty_cache()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        results = dict(
            fake_imgs=fake_imgs.cpu(),
            real_imgs=self.reals[self.curr_stage].cpu(),
            recon_imgs=recon_imgs.cpu(),
            curr_stage=self.curr_stage,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights)
        outputs = dict(log_vars=log_vars, num_samples=1, results=results)

        # update lr scheduler
        self.d_scheduler.step()
        self.g_scheduler.step()

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
@MODELS.register_module()
class PESinGAN(SinGAN):
    """Positional Encoding in SinGAN.

    This modified SinGAN is used to reimplement the experiments in: Positional
    Encoding as Spatial Inductive Bias in GANs, CVPR2021.
    """

    def _parse_train_cfg(self):
        """Parse the train config and cache the PE-SinGAN specific options."""
        super()._parse_train_cfg()
        cfg = self.train_cfg
        # whether the fixed noise maps should cover the generator's padding
        self.fixed_noise_with_pad = cfg.get('fixed_noise_with_pad', False)
        # channel number of the random noise used at the coarsest scale
        self.first_fixed_noises_ch = cfg.get('first_fixed_noises_ch', 1)

    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN.

        The coarsest scale receives a random noise map (with a configurable
        channel number); every finer scale receives a zero map. When
        ``fixed_noise_with_pad`` is enabled, each map is enlarged by the
        generator's head padding on every side.
        """
        for scale_idx, real in enumerate(self.reals):
            height, width = real.shape[-2:]
            if self.fixed_noise_with_pad:
                pad = self.get_module(self, 'generator').pad_head
                height, width = height + 2 * pad, width + 2 * pad
            if scale_idx == 0:
                shape = (1, self.first_fixed_noises_ch, height, width)
                noise = torch.randn(*shape).to(real)
            else:
                noise = torch.zeros((1, 1, height, width)).to(real)
            self.fixed_noises.append(noise)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment