Commit 1401de15 authored by dongchy920's avatar dongchy920
Browse files

stylegan2_mmcv

parents
Pipeline #1274 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from .base_gan import BaseGAN
from .basic_conditional_gan import BasicConditionalGAN
from .mspie_stylegan2 import MSPIEStyleGAN2
from .progressive_growing_unconditional_gan import ProgressiveGrowingGAN
from .singan import PESinGAN, SinGAN
from .static_unconditional_gan import StaticUnconditionalGAN
# Explicit public API of the GAN models subpackage; consumed by
# ``from ... import *`` and kept in sync with the imports above.
__all__ = [
    'BaseGAN', 'StaticUnconditionalGAN', 'ProgressiveGrowingGAN', 'SinGAN',
    'MSPIEStyleGAN2', 'PESinGAN', 'BasicConditionalGAN'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
import torch.nn as nn
class BaseGAN(nn.Module, metaclass=ABCMeta):
    """Abstract base class for GAN models.

    Provides the shared plumbing used by the concrete GANs in this file:
    property helpers that query optional sub-modules, loss construction for
    the generator and discriminator (including auxiliary losses), sampling
    from the (optionally EMA) generator, and the ``forward`` /
    ``_parse_losses`` glue expected by the MM-series runners. Subclasses
    must implement :meth:`train_step`.
    """

    def __init__(self):
        super().__init__()
        # Toggled by fp16 wrappers elsewhere in the project; fp32 default.
        self.fp16_enabled = False

    @property
    def with_disc(self):
        """bool: Whether the GAN has a discriminator."""
        return hasattr(self,
                       'discriminator') and self.discriminator is not None

    @property
    def with_ema_gen(self):
        """bool: Whether the GAN adopts exponential moving average."""
        # Bug fix: subclasses in this file store the averaged generator as
        # ``generator_ema`` (see their ``_parse_train_cfg``), so checking
        # only the legacy ``gen_ema`` name always returned False. Accept
        # either attribute name for backward compatibility.
        return any(
            getattr(self, name, None) is not None
            for name in ('gen_ema', 'generator_ema'))

    @property
    def with_gen_auxiliary_loss(self):
        """bool: Whether the GAN adopts auxiliary loss in the generator."""
        return hasattr(self, 'gen_auxiliary_losses') and (
            self.gen_auxiliary_losses is not None)

    @property
    def with_disc_auxiliary_loss(self):
        """bool: Whether the GAN adopts auxiliary loss in the discriminator."""
        return hasattr(self, 'disc_auxiliary_losses') and (
            self.disc_auxiliary_losses is not None)

    @staticmethod
    def _accumulate_aux_losses(losses_dict, loss_modules, outputs_dict):
        """Evaluate ``loss_modules`` and merge the results into ``losses_dict``.

        Modules may return ``None`` (e.g. losses computed at a fixed
        frequency) and are then skipped. Losses that share the same
        ``loss_name()`` are summed together.

        Args:
            losses_dict (dict): Mapping of loss name to loss tensor,
                updated in place.
            loss_modules (Iterable): Auxiliary loss modules to evaluate.
            outputs_dict (dict): Data dict passed to every loss module.
        """
        for loss_module in loss_modules:
            loss_ = loss_module(outputs_dict)
            if loss_ is None:
                continue
            # the `loss_name()` function returns a name like 'loss_xxx'
            name = loss_module.loss_name()
            if name in losses_dict:
                losses_dict[name] = losses_dict[name] + loss_
            else:
                losses_dict[name] = loss_

    def _get_disc_loss(self, outputs_dict):
        """Build and parse the total discriminator loss.

        Construct the losses dict. If you hope some items to be included
        in the computational graph, you have to add 'loss' in its name.
        Otherwise, items without 'loss' in their name will just be used to
        print information.

        Args:
            outputs_dict (dict): Must contain 'disc_pred_fake' and
                'disc_pred_real'; also forwarded to auxiliary losses.

        Returns:
            tuple[Tensor, dict]: The total loss and the log variables.
        """
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake'] = self.gan_loss(
            outputs_dict['disc_pred_fake'], target_is_real=False,
            is_disc=True)
        losses_dict['loss_disc_real'] = self.gan_loss(
            outputs_dict['disc_pred_real'], target_is_real=True,
            is_disc=True)
        # disc auxiliary loss
        if self.with_disc_auxiliary_loss:
            self._accumulate_aux_losses(losses_dict,
                                        self.disc_auxiliary_losses,
                                        outputs_dict)
        loss, log_var = self._parse_losses(losses_dict)
        return loss, log_var

    def _get_gen_loss(self, outputs_dict):
        """Build and parse the total generator loss.

        Construct the losses dict. If you hope some items to be included
        in the computational graph, you have to add 'loss' in its name.
        Otherwise, items without 'loss' in their name will just be used to
        print information.

        Args:
            outputs_dict (dict): Must contain 'disc_pred_fake_g'; also
                forwarded to auxiliary losses.

        Returns:
            tuple[Tensor, dict]: The total loss and the log variables.
        """
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake_g'] = self.gan_loss(
            outputs_dict['disc_pred_fake_g'],
            target_is_real=True,
            is_disc=False)
        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            self._accumulate_aux_losses(losses_dict,
                                        self.gen_auxiliary_losses,
                                        outputs_dict)
        loss, log_var = self._parse_losses(losses_dict)
        return loss, log_var

    @abstractmethod
    def train_step(self, data, optimizer, ddp_reducer=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different
        from other repos in **MM** series, we allow the back propagation
        and optimizer updating to directly follow the iterative training
        schedule of GAN. Of course, we will show that you can also move
        the back propagation outside of this method, and then optimize the
        parameters in the optimizer hook. But this will cause extra GPU
        memory cost as a result of retaining computational graph.
        Otherwise, the training schedule should be modified in the
        detailed implementation.

        TODO: Give an example of removing bp outside ``train_step``.
        TODO: Try the synchronized back propagation.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer
                of runner is passed to ``train_step()``. This argument is
                unused and reserved.
            ddp_reducer (:obj:`Reducer` | None, optional): This reducer is
                used to dynamically collect used parameters in the
                distributed training. If given an initialized ``Reducer``,
                we will call its ``prepare_for_backward()`` function just
                before calling ``.backward()``.

        Returns:
            dict: It should contain at least 3 keys: ``loss``,
                ``log_vars``, ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be
                  a weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to
                  the logger.
                - ``num_samples`` indicates the batch size (when the model
                  is DDP, it means the batch size on each GPU), which is
                  used for averaging the logs.
        """

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give
                a batch of noise through a ``torch.Tensor`` or offer a
                callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): Which model to sample from:
                'ema', 'orig', or 'ema/orig' (both, results concatenated).
                Defaults to 'ema/orig'.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(noise, num_batches=num_batches, **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']

        # 'ema/orig' additionally samples from the original generator with
        # the SAME noise batch, so both outputs are directly comparable.
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(noise, num_batches=num_batches, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs

    def forward_train(self, data, **kwargs):
        """Deprecated forward function in training."""
        # Bug fix: added the missing space between 'call' and 'this' that
        # made the error message read "callthis function".
        raise NotImplementedError(
            'In MMGeneration, we do NOT recommend users to call '
            'this function, because the train_step function is designed for '
            'the training process.')

    def forward_test(self, data, **kwargs):
        """Testing function for GANs.

        Args:
            data (torch.Tensor | dict | None): Input data. This data will
                be passed to different methods.
        """
        if kwargs.pop('mode', 'sampling') == 'sampling':
            return self.sample_from_noise(data, **kwargs)

        raise NotImplementedError('Other specific testing functions should'
                                  ' be implemented by the sub-classes.')

    def forward(self, data, return_loss=False, **kwargs):
        """Forward function.

        Args:
            data (dict | torch.Tensor): Input data dictionary.
            return_loss (bool, optional): Whether in training or testing.
                Defaults to False.

        Returns:
            dict: Output dictionary.
        """
        if return_loss:
            return self.forward_train(data, **kwargs)

        return self.forward_test(data, **kwargs)

    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually
                contain losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars
                contains all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            # Allow setting None for some loss item.
            # This is to support dynamic loss module, where the loss is
            # calculated with a fixed frequency.
            elif loss_value is None:
                continue
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Note that you have to add 'loss' in name of the items that will
        # be included in back propagation.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
@MODELS.register_module('BasiccGAN')
@MODELS.register_module()
class BasicConditionalGAN(BaseGAN):
    """Basic conditional GANs.

    This is the conditional GAN model containing standard adversarial
    training schedule. To fulfill the requirements of various GAN
    algorithms, ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are
    provided to customize auxiliary losses, e.g., gradient penalty loss,
    and discriminator shift loss. In addition, ``train_cfg`` and
    ``test_cfg`` aim at setting up the training schedule.

    NOTE(review): the registry alias ``'BasiccGAN'`` looks like a typo of
    ``'BasicGAN'``; it is kept unchanged because renaming it would break
    any existing config referencing the alias.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict | None, optional): Config for auxiliary
            loss to discriminator. Defaults to None.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary
            loss to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
        num_classes (int | None, optional): The number of conditional
            classes. Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 num_classes=None):
        super().__init__()
        self.num_classes = num_classes
        self._gen_cfg = deepcopy(generator)
        # pass `num_classes` to the builders so conditional modules know
        # the size of the label space
        self.generator = build_module(
            generator, default_args=dict(num_classes=num_classes))

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(
                discriminator, default_args=dict(num_classes=num_classes))
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # Bug fix: this branch used to assign the misspelled
            # ``self.disc_auxiliary_loss`` (singular), which left
            # ``disc_auxiliary_losses`` undefined and a dangling attribute
            # behind. ``BaseGAN.with_disc_auxiliary_loss`` checks
            # ``disc_auxiliary_losses``, so define it explicitly (matching
            # the sibling GAN classes in this package).
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        self.gen_steps = self.train_cfg.get('gen_steps', 1)

        # add support for accumulating gradients within multiple steps.
        # This feature aims to simulate large `batch_sizes` (but may have
        # some detailed differences in BN). Note that `self.disc_steps`
        # should be set according to the batch accumulation strategy.
        # In addition, in the detailed implementation, there is a
        # difference between the batch accumulation in the generator and
        # discriminator.
        self.batch_accumulation_steps = self.train_cfg.get(
            'batch_accumulation_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update discriminator and then compute loss for generator
        with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from
                ddp. It is used to prepare for ``backward()`` in ddp.
                Defaults to None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None,
                optional): The loss/gradient scaler used for auto
                mixed-precision training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use apex.amp.
                Defaults to ``False``.
            running_status (dict | None, optional): Contains necessary
                basic information for training, e.g., iteration number.
                Defaults to None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['img']
        # get the ground-truth label, torch.Tensor (N, )
        gt_label = data_batch['gt_label']
        # If you adopt ddp, this batch size is local batch size for each
        # GPU. If you adopt dp, this batch size is the global batch size
        # as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)
        # do not `zero_grad` during batch accumulation
        if curr_iter % self.batch_accumulation_steps == 0:
            optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_data = self.generator(
                None, num_batches=batch_size, label=None, return_noise=True)
            # fake_label should be in the same data type with the gt_label
            fake_imgs, fake_label = fake_data['fake_img'], fake_data['label']

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs, label=fake_label)
        disc_pred_real = self.discriminator(real_imgs, label=gt_label)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gt_label=gt_label,
            fake_label=fake_label,
            loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
        # scale the loss so the gradients of accumulated batches sum to
        # the gradient of one large batch
        loss_disc = loss_disc / float(self.batch_accumulation_steps)

        # prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        # only step the optimizer at the end of an accumulation window
        if (curr_iter + 1) % self.batch_accumulation_steps == 0:
            if loss_scaler:
                loss_scaler.unscale_(optimizer['discriminator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['discriminator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)

        # allow for training the generator with multiple steps
        for _ in range(self.gen_steps):
            optimizer['generator'].zero_grad()

            for _ in range(self.batch_accumulation_steps):
                # TODO: add noise sampler to customize noise sampling
                fake_data = self.generator(
                    None, num_batches=batch_size, return_noise=True)
                # fake_label should be in the same data type with the
                # gt_label
                fake_imgs, fake_label = fake_data['fake_img'], fake_data[
                    'label']
                disc_pred_fake_g = self.discriminator(
                    fake_imgs, label=fake_label)
                data_dict_ = dict(
                    gen=self.generator,
                    disc=self.discriminator,
                    fake_imgs=fake_imgs,
                    disc_pred_fake_g=disc_pred_fake_g,
                    iteration=curr_iter,
                    batch_size=batch_size,
                    fake_label=fake_label,
                    loss_scaler=loss_scaler)

                loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
                loss_gen = loss_gen / float(self.batch_accumulation_steps)

                # prepare for backward in ddp. If you do not call this
                # function before back propagation, the ddp will not
                # dynamically find the used params in current computation.
                if ddp_reducer is not None:
                    ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

                if loss_scaler:
                    loss_scaler.scale(loss_gen).backward()
                elif use_apex_amp:
                    from apex import amp
                    with amp.scale_loss(
                            loss_gen, optimizer['generator'],
                            loss_id=1) as scaled_loss_disc:
                        scaled_loss_disc.backward()
                else:
                    loss_gen.backward()

            if loss_scaler:
                loss_scaler.unscale_(optimizer['generator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['generator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          label=None,
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give
                a batch of noise through a ``torch.Tensor`` or offer a
                callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): Use which model to sample fake
                images. Defaults to `'ema/orig'`.
            label (torch.Tensor | None , optional): The conditional label.
                Defaults to None.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(noise, num_batches=num_batches, label=label,
                         **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            # reuse the exact noise/label batch for the second model so
            # the 'ema/orig' outputs are directly comparable
            noise = outputs['noise_batch']
            label = outputs['label']
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise, num_batches=num_batches, label=label, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS
from ..common import set_requires_grad
from .static_unconditional_gan import StaticUnconditionalGAN
@MODELS.register_module()
class MSPIEStyleGAN2(StaticUnconditionalGAN):
    """MS-PIE StyleGAN2.

    In this GAN, we adopt the MS-PIE training schedule so that multi-scale
    images can be generated with a single generator. Details can be found
    in: Positional Encoding as Spatial Inductive Bias in GANs, CVPR2021.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary
            loss to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
    """

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        super(MSPIEStyleGAN2, self)._parse_train_cfg()
        # set the number of upsampling blocks. This value will be used to
        # calculate the current result size according to the size of the
        # input feature map, e.g., positional encoding map
        self.num_upblocks = self.train_cfg.get('num_upblocks', 6)

        # multiple input scales (a list of int) that will be added to the
        # original starting scale.
        self.multi_input_scales = self.train_cfg.get('multi_input_scales')
        # sampling probabilities for `multi_input_scales`; consumed as the
        # `p` argument of `np.random.choice` in `train_step`
        self.multi_scale_probability = self.train_cfg.get(
            'multi_scale_probability')

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update discriminator and then compute loss for generator
        with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from
                ddp. It is used to prepare for ``backward()`` in ddp.
                Defaults to None.
            running_status (dict | None, optional): Contains necessary
                basic information for training, e.g., iteration number.
                Defaults to None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is local batch size for each
        # GPU. If you adopt dp, this batch size is the global batch size
        # as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        if dist.is_initialized():
            # randomly sample a scale for current training iteration.
            # Bug fix: the probabilities used to be passed as the third
            # positional argument of ``np.random.choice``, which is
            # ``replace``, not ``p`` — so the configured
            # ``multi_scale_probability`` was silently ignored and scales
            # were sampled uniformly. Pass it as the ``p`` keyword.
            chosen_scale = np.random.choice(
                self.multi_input_scales, 1,
                p=self.multi_scale_probability)[0]

            # broadcast rank-0's choice so every process trains at the
            # same resolution this iteration
            chosen_scale = torch.tensor(chosen_scale, dtype=torch.int).cuda()
            dist.broadcast(chosen_scale, 0)
            chosen_scale = int(chosen_scale.item())
        else:
            mmcv.print_log(
                'Distributed training has not been initialized. Degrade to '
                'the standard stylegan2',
                logger='mmgen',
                level=logging.WARN)
            chosen_scale = 0

        # input feature map is (4 + chosen_scale) and each up-block
        # doubles the resolution
        curr_size = (4 + chosen_scale) * (2**self.num_upblocks)

        # adjust the shape of images
        if real_imgs.shape[-2:] != (curr_size, curr_size):
            real_imgs = F.interpolate(
                real_imgs,
                size=(curr_size, curr_size),
                mode='bilinear',
                align_corners=True)

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, chosen_scale=chosen_scale)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        loss_disc.backward()
        optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            log_vars_disc['curr_size'] = curr_size
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None, num_batches=batch_size, chosen_scale=chosen_scale)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars['curr_size'] = curr_size

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors
from mmgen.core.optimizer import build_optimizers
from mmgen.models.builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
@MODELS.register_module('StyleGANV1')
@MODELS.register_module('PGGAN')
@MODELS.register_module()
class ProgressiveGrowingGAN(BaseGAN):
"""Progressive Growing Unconditional GAN.
In this GAN model, we implement progressive growing training schedule,
which is proposed in Progressive Growing of GANs for improved Quality,
Stability and Variation, ICLR 2018.
We highly recommend to use ``GrowScaleImgDataset`` for saving computational
load in data pre-processing.
Notes for **using PGGAN**:
#. In official implementation, Tero uses gradient penalty with
``norm_mode="HWC"``
#. We do not implement ``minibatch_repeats`` where has been used in
official Tensorflow implementation.
Notes for resuming progressive growing GANs:
Users should specify the ``prev_stage`` in ``train_cfg``. Otherwise, the
model is possible to reset the optimizer status, which will bring
inferior performance. For example, if your model is resumed from the
`256` stage, you should set ``train_cfg=dict(prev_stage=256)``.
Args:
generator (dict): Config for generator.
discriminator (dict): Config for discriminator.
gan_loss (dict): Config for generative adversarial loss.
disc_auxiliary_loss (dict): Config for auxiliary loss to
discriminator.
gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
to generator. Defaults to None.
train_cfg (dict | None, optional): Config for training schedule.
Defaults to None.
test_cfg (dict | None, optional): Config for testing schedule. Defaults
to None.
"""
def __init__(self,
generator,
discriminator,
gan_loss,
disc_auxiliary_loss,
gen_auxiliary_loss=None,
train_cfg=None,
test_cfg=None):
super().__init__()
self._gen_cfg = deepcopy(generator)
self.generator = build_module(generator)
# support no discriminator in testing
if discriminator is not None:
self.discriminator = build_module(discriminator)
else:
self.discriminator = None
# support no gan_loss in testing
if gan_loss is not None:
self.gan_loss = build_module(gan_loss)
else:
self.gan_loss = None
if disc_auxiliary_loss:
self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
self.disc_auxiliary_losses = nn.ModuleList(
[self.disc_auxiliary_losses])
else:
self.disc_auxiliary_losses = None
if gen_auxiliary_loss:
self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
self.gen_auxiliary_losses = nn.ModuleList(
[self.gen_auxiliary_losses])
else:
self.gen_auxiliary_losses = None
# register necessary training status
self.register_buffer('shown_nkimg', torch.tensor(0.))
self.register_buffer('_curr_transition_weight', torch.tensor(1.))
self.train_cfg = deepcopy(train_cfg) if train_cfg else None
self.test_cfg = deepcopy(test_cfg) if test_cfg else None
self._parse_train_cfg()
# this buffer is used to resume model easily
self.register_buffer(
'_next_scale_int',
torch.tensor(self.scales[0][0], dtype=torch.int32))
# TODO: init it with the same value as `_next_scale_int`
# a dirty workaround for testing
self.register_buffer(
'_curr_scale_int',
torch.tensor(self.scales[-1][0], dtype=torch.int32))
if test_cfg is not None:
self._parse_test_cfg()
def _parse_train_cfg(self):
"""Parsing train config and set some attributes for training."""
if self.train_cfg is None:
self.train_cfg = dict()
# control the work flow in train step
self.disc_steps = self.train_cfg.get('disc_steps', 1)
# whether to use exponential moving average for training
self.use_ema = self.train_cfg.get('use_ema', False)
if self.use_ema:
# use deepcopy to guarantee the consistency
self.generator_ema = deepcopy(self.generator)
# setup interpolation operation at the beginning of training iter
interp_real_cfg = deepcopy(self.train_cfg.get('interp_real', None))
if interp_real_cfg is None:
interp_real_cfg = dict(mode='bilinear', align_corners=True)
self.interp_real_to = partial(F.interpolate, **interp_real_cfg)
# parsing the training schedule: scales : kimg
assert isinstance(self.train_cfg['nkimgs_per_scale'],
dict), ('Please provide "nkimgs_per_'
'scale" to schedule the training procedure.')
nkimgs_per_scale = deepcopy(self.train_cfg['nkimgs_per_scale'])
self.scales = []
self.nkimgs = []
for k, v in nkimgs_per_scale.items():
# support for different data types
if isinstance(k, str):
k = (int(k), int(k))
elif isinstance(k, int):
k = (k, k)
else:
assert mmcv.is_tuple_of(k, int)
# sanity check for the order of scales
assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
self.scales.append(k)
self.nkimgs.append(v)
self.cum_nkimgs = np.cumsum(self.nkimgs)
self.curr_stage = 0
self.prev_stage = 0
# actually nkimgs shown at the end of per training stage
self._actual_nkimgs = []
# In each scale, transit from previous torgb layer to newer torgb layer
# with `transition_kimgs` imgs
self.transition_kimgs = self.train_cfg.get('transition_kimgs', 600)
# setup optimizer
self.optimizer = build_optimizers(
self, deepcopy(self.train_cfg['optimizer_cfg']))
# get lr schedule
self.g_lr_base = self.train_cfg['g_lr_base']
self.d_lr_base = self.train_cfg['d_lr_base']
# example for lr schedule: {'32': 0.001, '64': 0.0001}
self.g_lr_schedule = self.train_cfg.get('g_lr_schedule', dict())
self.d_lr_schedule = self.train_cfg.get('d_lr_schedule', dict())
# reset the states for optimizers, e.g. momentum in Adam
self.reset_optim_for_new_scale = self.train_cfg.get(
'reset_optim_for_new_scale', True)
# dirty walkround for avoiding optimizer bug in resuming
self.prev_stage = self.train_cfg.get('prev_stage', self.prev_stage)
def _parse_test_cfg(self):
"""Parsing train config and set some attributes for testing."""
if self.test_cfg is None:
self.test_cfg = dict()
# basic testing information
self.batch_size = self.test_cfg.get('batch_size', 1)
# whether to use exponential moving average for testing
self.use_ema = self.test_cfg.get('use_ema', False)
# TODO: finish ema part
def sample_from_noise(self,
noise,
num_batches=0,
curr_scale=None,
transition_weight=None,
sample_model='ema/orig',
**kwargs):
"""Sample images from noises by using the generator.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
Returns:
torch.Tensor | dict: The output may be the direct synthesized \
images in ``torch.Tensor``. Otherwise, a dict with queried \
data, including generated images, will be returned.
"""
# use `self.curr_scale` if curr_scale is None
if curr_scale is None:
# in training, 'curr_scale' will be set as attribute
if hasattr(self, 'curr_scale'):
curr_scale = self.curr_scale[0]
# in testing, adopt '_curr_scale_int' from buffer as testing scale
else:
curr_scale = self._curr_scale_int.item()
# use `self._curr_transition_weight` if `transition_weight` is None
if transition_weight is None:
transition_weight = self._curr_transition_weight.item()
if sample_model == 'ema':
assert self.use_ema
_model = self.generator_ema
elif sample_model == 'ema/orig' and self.use_ema:
_model = self.generator_ema
else:
_model = self.generator
outputs = _model(
noise,
num_batches=num_batches,
curr_scale=curr_scale,
transition_weight=transition_weight,
**kwargs)
if isinstance(outputs, dict) and 'noise_batch' in outputs:
noise = outputs['noise_batch']
if sample_model == 'ema/orig' and self.use_ema:
_model = self.generator
outputs_ = _model(
noise,
num_batches=num_batches,
curr_scale=curr_scale,
transition_weight=transition_weight,
**kwargs)
if isinstance(outputs_, dict):
outputs['fake_img'] = torch.cat(
[outputs['fake_img'], outputs_['fake_img']], dim=0)
else:
outputs = torch.cat([outputs, outputs_], dim=0)
return outputs
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator. Note that PGGAN rebuilds its own per-scale
                optimizers, which take precedence over this argument.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # check if optimizer from model
        # PGGAN owns per-scale optimizers, so a model-held optimizer always
        # overrides the one passed in by the runner.
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage
        # the stage index is the number of cumulative-kimg thresholds already
        # passed, clamped to the last configured scale
        self.curr_stage = int(
            min(
                sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()

        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            # record the kimgs actually shown when entering this stage;
            # used below to compute the torgb transition weight
            self._actual_nkimgs.append(self.shown_nkimg.item())
            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                # per-scale learning rates fall back to the base lr when the
                # scale is absent from the schedule
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            # linear ramp over `transition_kimgs` kimgs since the stage began
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        # keep the buffered copy in sync (same device/dtype as the buffer)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        # no_grad: generator is frozen while the discriminator updates
        with torch.no_grad():
            fake_imgs = self.generator(
                None,
                num_batches=batch_size,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        # `gen_partial`/`disc_partial` let auxiliary losses (e.g. gradient
        # penalty) re-run the networks with the current scale settings
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(
                self.generator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight),
            disc_partial=partial(
                self.discriminator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        # NOTE(review): if training is non-distributed AND `running_status`
        # was not provided, `running_status` is None here and the membership
        # test below raises TypeError rather than the intended RuntimeError —
        # confirm callers always pass `running_status` in non-dist runs.
        if dist.is_initialized():
            _batch_size = batch_size * dist.get_world_size()
        else:
            if 'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for PGGAN'
                )
            _batch_size = running_status['batch_size']
        # accumulate the number of shown images in kimg units
        self.shown_nkimg += (_batch_size / 1000.)

        log_vars_disc.update(
            dict(
                shown_nkimg=self.shown_nkimg.item(),
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None,
            num_batches=batch_size,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(
                sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from copy import deepcopy
from functools import partial
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.architectures.common import get_module_device
from mmgen.models.builder import MODELS, build_module
from mmgen.models.gans.base_gan import BaseGAN
from ..common import set_requires_grad
@MODELS.register_module()
class SinGAN(BaseGAN):
    """SinGAN.

    This model implement the single image generative adversarial model proposed
    in: Singan: Learning a Generative Model from a Single Natural Image,
    ICCV'19.

    Notes for training:

    - This model should be trained with our dataset ``SinGANDataset``.
    - In training, the ``total_iters`` arguments is related to the number of
      scales in the image pyramid and ``iters_per_scale`` in the ``train_cfg``.
      You should set it carefully in the training config file.

    Notes for model architectures:

    - The generator and discriminator need ``num_scales`` in initialization.
      However, this arguments is generated by ``create_real_pyramid`` function
      from the ``singan_dataset.py``. The last element in the returned list
      (``stop_scale``) is the value for ``num_scales``. Pay attention that this
      scale is counted from zero. Please see our tutorial for SinGAN to obtain
      more details or our standard config for reference.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        # auxiliary losses are normalized to ``nn.ModuleList`` so that
        # multiple losses can be iterated uniformly
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        # register necessary training status
        # `curr_stage` starts at -1; it is incremented to 0 in the first
        # `train_step` call
        self.curr_stage = -1
        self.noise_weights = [1]
        self.fixed_noises = []
        self.reals = []

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

    def _parse_test_cfg(self):
        """Parsing test config; optionally restore noise states from a pkl
        file produced by a previous training run."""
        # NOTE(review): the attribute access ``self.test_cfg.pkl_data`` below
        # implies `test_cfg` is an mmcv ``Config``-like object rather than a
        # plain dict — confirm against the configs that set `pkl_data`.
        if self.test_cfg.get('pkl_data', None) is not None:
            # NOTE(review): pickle is only safe for trusted, locally-produced
            # data files; do not point `pkl_data` at untrusted input.
            with open(self.test_cfg.pkl_data, 'rb') as f:
                data = pickle.load(f)
            self.fixed_noises = self._from_numpy(data['fixed_noises'])
            self.noise_weights = self._from_numpy(data['noise_weights'])
            self.curr_stage = data['curr_stage']
            mmcv.print_log(f'Load pkl data from {self.test_cfg.pkl_data}',
                           'mmgen')

    def _from_numpy(self, data):
        """Recursively convert numpy arrays (possibly nested in lists) to
        tensors placed on the generator's device; other objects pass
        through unchanged."""
        if isinstance(data, list):
            return [self._from_numpy(x) for x in data]

        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
            device = get_module_device(self.generator)
            data = data.to(device)
            return data

        return data

    def get_module(self, model, module_name):
        """Get an inner module from model.

        Since we will wrapper DDP for some model, we have to judge whether the
        module can be indexed directly.

        Args:
            model (nn.Module): This model may wrapped with DDP or not.
            module_name (str): The name of specific module.

        Return:
            nn.Module: Returned sub module.
        """
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return getattr(model.module, module_name)

        return getattr(model, module_name)

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          curr_scale=None,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized \
                images in ``torch.Tensor``. Otherwise, a dict with queried \
                data, including generated images, will be returned.
        """
        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            curr_scale = self.curr_stage
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        # lazily move the cached fixed noises onto GPU once it is available
        if not self.fixed_noises[0].is_cuda and torch.cuda.is_available():
            self.fixed_noises = [
                x.to(get_module_device(self)) for x in self.fixed_noises
            ]

        outputs = _model(
            None,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights,
            rand_mode='rand',
            num_batches=num_batches,
            curr_scale=curr_scale,
            **kwargs)

        return outputs

    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN."""
        # only the coarsest scale gets random noise; finer scales use zeros
        # (their randomness is injected through the noise weights)
        for i, real in enumerate(self.reals):
            h, w = real.shape[-2:]
            if i == 0:
                noise = torch.randn(1, 1, h, w).to(real)
                self.fixed_noises.append(noise)
            else:
                noise = torch.zeros_like(real)
                self.fixed_noises.append(noise)

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator. Note that SinGAN builds its own per-scale
                optimizers, which take precedence over this argument.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # init each scale
        # every `iters_per_scale` iterations, move on to the next scale:
        # inherit weights from the previous block and rebuild the per-scale
        # optimizers and lr schedulers
        if curr_iter % self.train_cfg['iters_per_scale'] == 0:
            self.curr_stage += 1
            # load weights from prev scale
            self.get_module(self.generator, 'check_and_load_prev_weight')(
                self.curr_stage)
            self.get_module(self.discriminator, 'check_and_load_prev_weight')(
                self.curr_stage)

            # build optimizer for each scale
            g_module = self.get_module(self.generator, 'blocks')
            param_list = g_module[self.curr_stage].parameters()
            self.g_optim = torch.optim.Adam(
                param_list, lr=self.train_cfg['lr_g'], betas=(0.5, 0.999))
            d_module = self.get_module(self.discriminator, 'blocks')
            self.d_optim = torch.optim.Adam(
                d_module[self.curr_stage].parameters(),
                lr=self.train_cfg['lr_d'],
                betas=(0.5, 0.999))
            self.optimizer = dict(
                generator=self.g_optim, discriminator=self.d_optim)

            self.g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.g_optim, **self.train_cfg['lr_scheduler_args'])
            self.d_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.d_optim, **self.train_cfg['lr_scheduler_args'])

        optimizer = self.optimizer

        # setup fixed noises and reals pyramid
        if curr_iter == 0 or len(self.reals) == 0:
            keys = [k for k in data_batch.keys() if 'real_scale' in k]
            scales = len(keys)
            self.reals = [data_batch[f'real_scale{s}'] for s in range(scales)]
            # here we do not padding fixed noises
            self.construct_fixed_noises()

        # disc training
        set_requires_grad(self.discriminator, True)
        for _ in range(self.train_cfg['disc_steps']):
            optimizer['discriminator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            with torch.no_grad():
                fake_imgs = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='rand',
                    curr_scale=self.curr_stage)

            # disc pred for fake imgs and real_imgs
            disc_pred_fake = self.discriminator(fake_imgs.detach(),
                                                self.curr_stage)
            disc_pred_real = self.discriminator(self.reals[self.curr_stage],
                                                self.curr_stage)
            # get data dict to compute losses for disc
            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                disc_pred_fake=disc_pred_fake,
                disc_pred_real=disc_pred_real,
                fake_imgs=fake_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_partial=partial(
                    self.discriminator, curr_scale=self.curr_stage))

            loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
            loss_disc.backward()
            optimizer['discriminator'].step()

        log_vars_disc.update(dict(curr_stage=self.curr_stage))

        # generator training
        set_requires_grad(self.discriminator, False)
        for _ in range(self.train_cfg['generator_steps']):
            optimizer['generator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            fake_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='rand',
                curr_scale=self.curr_stage)
            disc_pred_fake_g = self.discriminator(
                fake_imgs, curr_scale=self.curr_stage)
            # reconstruction pass drives SinGAN's reconstruction loss
            recon_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='recon',
                curr_scale=self.curr_stage)

            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                fake_imgs=fake_imgs,
                recon_imgs=recon_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_pred_fake_g=disc_pred_fake_g)

            loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
            loss_gen.backward()
            optimizer['generator'].step()

        # end of each scale
        # calculate noise weight for next scale
        # the next-scale noise weight is proportional to the RMSE between the
        # upsampled reconstruction and the next real image in the pyramid
        if (curr_iter % self.train_cfg['iters_per_scale']
                == 0) and (self.curr_stage < len(self.reals) - 1):
            with torch.no_grad():
                g_recon = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='recon',
                    curr_scale=self.curr_stage)
                if isinstance(g_recon, dict):
                    g_recon = g_recon['fake_img']
                g_recon = F.interpolate(
                    g_recon, self.reals[self.curr_stage + 1].shape[-2:])

            mse = F.mse_loss(g_recon.detach(), self.reals[self.curr_stage + 1])
            rmse = torch.sqrt(mse)
            self.noise_weights.append(
                self.train_cfg.get('noise_weight_init', 0.1) * rmse.item())

        # try to release GPU memory.
        torch.cuda.empty_cache()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(
            fake_imgs=fake_imgs.cpu(),
            real_imgs=self.reals[self.curr_stage].cpu(),
            recon_imgs=recon_imgs.cpu(),
            curr_stage=self.curr_stage,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights)
        outputs = dict(log_vars=log_vars, num_samples=1, results=results)

        # update lr scheduler
        self.d_scheduler.step()
        self.g_scheduler.step()

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
@MODELS.register_module()
class PESinGAN(SinGAN):
    """Positional Encoding in SinGAN.

    This modified SinGAN is used to reimplement the experiments in: Positional
    Encoding as Spatial Inductive Bias in GANs, CVPR2021.
    """

    def _parse_train_cfg(self):
        """Parse train config and set the extra PE-SinGAN attributes."""
        super(PESinGAN, self)._parse_train_cfg()
        # whether the fixed noise maps should also cover the generator padding
        self.fixed_noise_with_pad = self.train_cfg.get(
            'fixed_noise_with_pad', False)
        # channel number of the fixed noise at the coarsest scale
        self.first_fixed_noises_ch = self.train_cfg.get(
            'first_fixed_noises_ch', 1)

    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN."""
        for scale_idx, real in enumerate(self.reals):
            height, width = real.shape[-2:]
            if self.fixed_noise_with_pad:
                # enlarge the noise map so it covers the generator padding
                pad = self.get_module(self, 'generator').pad_head
                height = height + 2 * pad
                width = width + 2 * pad
            if scale_idx == 0:
                # coarsest scale: random noise with a configurable channel
                # number
                fixed = torch.randn(1, self.first_fixed_noises_ch, height,
                                    width).to(real)
            else:
                # finer scales: zero maps on the same device/dtype as `real`
                fixed = torch.zeros((1, 1, height, width)).to(real)
            self.fixed_noises.append(fixed)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
# _SUPPORT_METHODS_ = ['DCGAN', 'STYLEGANv2']
# @MODELS.register_module(_SUPPORT_METHODS_)
@MODELS.register_module()
class StaticUnconditionalGAN(BaseGAN):
    """Unconditional GANs with static architecture in training.

    This is the standard GAN model containing standard adversarial training
    schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss, and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` aims at setuping
    training schedule.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to discriminator. Defaults to None.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        # auxiliary losses are normalized to ``nn.ModuleList`` so that
        # multiple losses can be iterated uniformly
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # BUGFIX: assign the plural ``disc_auxiliary_losses`` attribute
            # queried by `BaseGAN.with_disc_auxiliary_loss` and the loss
            # computation. The original code set the singular
            # `disc_auxiliary_loss` here, leaving the expected attribute
            # undefined when no auxiliary loss is configured (cf. the
            # symmetric generator branch below and `SinGAN.__init__`).
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

        # key under which real images are found in `data_batch`
        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional). Whether to use apex.amp. Defaults to
                ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch[self.real_img_key]
        # If you adopt ddp, this batch size is local batch size for each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling

        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='disc'))

        # no_grad: generator is frozen while the discriminator updates
        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, **g_training_kwargs)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['discriminator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['discriminator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='gen'))
        fake_imgs = self.generator(
            None, num_batches=batch_size, **g_training_kwargs)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        if loss_scaler:
            loss_scaler.scale(loss_gen).backward()
        elif use_apex_amp:
            from apex import amp
            # renamed from the copy-pasted `scaled_loss_disc`: this context
            # variable scales the *generator* loss
            with amp.scale_loss(
                    loss_gen, optimizer['generator'],
                    loss_id=1) as scaled_loss_gen:
                scaled_loss_gen.backward()
        else:
            loss_gen.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['generator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['generator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['generator'].step()

        # update ada p
        # adaptive discriminator augmentation bookkeeping (StyleGAN2-ADA)
        if hasattr(self.discriminator,
                   'with_ada') and self.discriminator.with_ada:
            self.discriminator.ada_aug.log_buffer[0] += batch_size
            self.discriminator.ada_aug.log_buffer[1] += disc_pred_real.sign(
            ).sum()
            self.discriminator.ada_aug.update(
                iteration=curr_iter, num_batches=batch_size)
            log_vars_disc['augment'] = (
                self.discriminator.ada_aug.aug_pipeline.p.data.cpu())

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from .ddpm_loss import DDPMVLBLoss
from .disc_auxiliary_loss import (DiscShiftLoss, GradientPenaltyLoss,
R1GradientPenalty, disc_shift_loss,
gradient_penalty_loss,
r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
GeneratorPathRegularizer, PerceptualLoss,
gen_path_regularizer)
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
GaussianKLDLoss, L1Loss, MSELoss,
discretized_gaussian_log_likelihood, gaussian_kld)
# Public interface of the losses subpackage.
# NOTE(review): ``DDPMMSELoss`` is defined in ``ddpm_loss.py`` but neither
# imported nor exported here — confirm whether this is intentional.
__all__ = [
    'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
    'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
    'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
    'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
    'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
    'CLIPLoss', 'PerceptualLoss'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from copy import deepcopy
from functools import partial
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import digit_version
from mmgen.models.builder import MODULES
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
GaussianKLDLoss, _reduction_modes, mse_loss)
from .utils import reduce_loss
class DDPMLoss(nn.Module):
    """Base module for DDPM losses. We support loss weight rescale and log
    collection for DDPM models in this module.

    We support two kinds of loss rescale methods, which can be controlled by
    ``rescale_mode`` and ``rescale_cfg``:

    1. ``rescale_mode == 'constant'``: ``constant_rescale`` would be called,
       and ``rescale_cfg`` should be passed as ``dict(scale=SCALE)``,
       e.g., ``dict(scale=1.2)``. Then, all loss terms would be rescaled by
       multiplying with ``SCALE``.
    2. ``rescale_mode == 'timestep_weight'``: ``timestep_weight_rescale``
       would be called, and ``weight`` or a ``sampler`` which contains a
       weight attribute must be passed. Then, loss at timestep ``t`` would be
       multiplied with ``weight[t]``. We also support users further applying
       a constant rescale factor to all loss terms, e.g.
       ``rescale_cfg=dict(scale=SCALE)``. The overall rescale function for
       loss at timestep ``t`` can be formulated as
       ``loss[t] := weight[t] * loss[t] * SCALE``. To be noted that,
       ``weight`` or ``sampler.weight`` would be inplace modified in the
       outer code, e.g.,

       .. code-block:: python
          :linenos:

          # 1. define weight
          weight = torch.randn(10, )
          # 2. define loss function
          loss_fn = DDPMLoss(rescale_mode='timestep_weight', weight=weight)
          # 3. update weight
          # wrong usage: `weight` in `loss_fn` is not accessible from now
          # because you assign a new tensor to variable `weight`
          # weight = torch.randn(10, )
          # correct usage: update `weight` inplace
          weight[2] = 2

    If ``rescale_mode`` is not passed, ``rescale_cfg`` would be ignored, and
    all loss terms would not be rescaled.

    For loss log collection, we support users to pass a list of (or single)
    configs by the ``log_cfgs`` argument to define how they want to collect
    loss terms and show them in the log. Each log collection returns a dict
    whose keys and values are the names and values of collected loss terms.
    The dict will be merged into ``log_vars`` after the loss used for
    parameter optimization is calculated.

    Each log config must contain the ``type`` keyword, and may contain the
    ``prefix`` and ``reduction`` keywords.

    ``type``: Used to get the corresponding collection function. Functions
        would be named as ``f'{type}_log_collect'``. In ``DDPMLoss``, we only
        support ``type == 'quartile'``, but users may define their own log
        collection functions and use them in this way.
    ``prefix``: Set to avoid the names of displayed loss terms being too
        long. The name of each loss term will be set as
        ``'{prefix}_{log_coll_fn_spec_name}'``, where
        ``{log_coll_fn_spec_name}`` is specific to the log collection
        function. If passed, it must start with ``'loss_'``. If not passed,
        ``'loss_'`` would be used.
    ``reduction``: Controls the reduction method of the collected loss terms.

    We implement ``quartile_log_collect`` in this module. In detail, we
    divide total timesteps into four parts and collect the loss in the
    corresponding timestep intervals. To use this collection method, users
    may pass ``log_cfgs`` as the following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='quartile', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        weight (torch.Tensor, optional): Weight used to rescale losses.
            Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to None.
    """

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 log_cfgs=None,
                 weight=None,
                 sampler=None,
                 reduction='mean',
                 loss_name=None):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.reduction = reduction
        self._loss_name = loss_name

        # Build the list of log collection callables from `log_cfgs`; each
        # entry is a `partial` of a `f'{type}_log_collect'` method.
        self.log_fn_list = []
        log_cfgs_ = deepcopy(log_cfgs)
        if log_cfgs_ is not None:
            if not isinstance(log_cfgs_, list):
                log_cfgs_ = [log_cfgs_]
            assert mmcv.is_list_of(log_cfgs_, dict)
            for log_cfg_ in log_cfgs_:
                log_type = log_cfg_.pop('type')
                log_collect_fn = f'{log_type}_log_collect'
                assert hasattr(self, log_collect_fn)
                log_collect_fn = getattr(self, log_collect_fn)
                # displayed names must carry the 'loss' prefix
                log_cfg_.setdefault('prefix_name', 'loss')
                assert log_cfg_['prefix_name'].startswith('loss')
                log_cfg_.setdefault('reduction', reduction)
                self.log_fn_list.append(partial(log_collect_fn, **log_cfg_))
        self.log_vars = dict()

        # Resolve the rescale function. Without `rescale_mode`, losses pass
        # through unchanged.
        if not rescale_mode:
            self.rescale_fn = lambda loss, t: loss
        else:
            rescale_fn_name = f'{rescale_mode}_rescale'
            assert hasattr(self, rescale_fn_name)
            if rescale_mode == 'timestep_weight':
                # prefer the sampler's (live, externally updated) weight
                if sampler is not None and hasattr(sampler, 'weight'):
                    weight = sampler.weight
                else:
                    assert weight is not None and isinstance(
                        weight, torch.Tensor), (
                            '\'weight\' or a \'sampler\' contains weight '
                            'attribute is must be \'torch.Tensor\' for '
                            '\'timestep_weight\' rescale_mode.')
                mmcv.print_log(
                    'Apply \'timestep_weight\' rescale_mode for '
                    f'{self._loss_name}. Please make sure the passed weight '
                    'can be updated by external functions.', 'mmgen')
                rescale_cfg = dict(weight=weight)
            self.rescale_fn = partial(
                getattr(self, rescale_fn_name), **rescale_cfg)

    @staticmethod
    def constant_rescale(loss, timesteps, scale):
        """Rescale losses at all timesteps with a constant factor.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item (unused
                here, kept for a uniform rescale-function signature).
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * scale

    @staticmethod
    def timestep_weight_rescale(loss, timesteps, weight, scale=1):
        """Rescale losses corresponding to timestep.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item.
            weight (torch.Tensor): Weight corresponding to each timestep.
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * weight[timesteps] * scale

    @torch.no_grad()
    def collect_log(self, loss, timesteps):
        """Collect logs and store them into ``self.log_vars``.

        In distributed training, per-sample losses and timesteps are gathered
        from all ranks first, and the collection is performed on rank 0 only.

        Args:
            loss (torch.Tensor): Losses to collect.
            timesteps (torch.Tensor): Timesteps of each loss item.
        """
        if not self.log_fn_list:
            return

        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder_l = [torch.zeros_like(loss) for _ in range(ws)]
            placeholder_t = [torch.zeros_like(timesteps) for _ in range(ws)]
            dist.all_gather(placeholder_l, loss)
            dist.all_gather(placeholder_t, timesteps)
            loss = torch.cat(placeholder_l, dim=0)
            timesteps = torch.cat(placeholder_t, dim=0)

        log_vars = dict()
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            for log_fn in self.log_fn_list:
                log_vars.update(log_fn(loss, timesteps))
        self.log_vars = log_vars

    @torch.no_grad()
    def quartile_log_collect(self,
                             loss,
                             timesteps,
                             total_timesteps,
                             prefix_name,
                             reduction='mean'):
        """Collect loss logs by quartile timesteps.

        Args:
            loss (torch.Tensor): Loss value of each input, shaped as [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each loss,
                shaped as [bz, ].
            total_timesteps (int): Total timesteps of the diffusion process.
            prefix_name (str): Prefix to show in logs.
            reduction (str, optional): Specifies the reduction to apply to
                the output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        if digit_version(torch.__version__) <= digit_version('1.6.0'):
            # use true_divide in older torch versions where `/` on int
            # tensors performs floor division
            quartile = torch.true_divide(timesteps, total_timesteps) * 4
        else:
            quartile = (timesteps / total_timesteps * 4)
        quartile = quartile.type(torch.LongTensor)

        log_vars = dict()
        for idx in range(4):
            if not (quartile == idx).any():
                # no sample fell into this quartile; log zero
                loss_quartile = torch.zeros((1, ))
            else:
                loss_quartile = reduce_loss(loss[quartile == idx], reduction)
            log_vars[f'{prefix_name}_quartile_{idx}'] = loss_quartile.item()

        return log_vars

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules should
        be passed into this function. If this dictionary is given as a
        non-keyword argument, it should be offered as the first argument. If
        you are using keyword arguments, please name it as ``outputs_dict``.

        Returns:
            torch.Tensor: Rescaled and reduced loss.
        """
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            output_dict = args[0]
        # BUGFIX: check the same key ('outputs_dict') that is popped below;
        # the previous check on 'output_dict' never matched the pop and made
        # the keyword path unusable.
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            output_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # check keys in output_dict
        assert 'timesteps' in output_dict, (
            '\'timesteps\' is must for DDPM-based losses, but found'
            f'{output_dict.keys()} in \'output_dict\'')

        timesteps = output_dict['timesteps']
        loss = self._forward_loss(output_dict)

        # update log_vars of this class (no-grad; logging only)
        self.collect_log(loss, timesteps=timesteps)

        loss_rescaled = self.rescale_fn(loss, timesteps)
        return reduce_loss(loss_rescaled, self.reduction)

    @abstractmethod
    def _forward_loss(self, output_dict):
        """Forward function for loss calculation. This method should be
        implemented by each subclass.

        Args:
            output_dict (dict): Outputs of the model used to calculate
                losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        raise NotImplementedError(
            '\'self._forward_loss\' must be implemented.')

    def loss_name(self):
        """Loss Name.

        Returns the name of this loss function. This name will be used to
        combine different loss items by simple sum operation. In addition, if
        you want this loss item to be included into the backward graph,
        `loss_` must be the prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class DDPMVLBLoss(DDPMLoss):
    """Variational lower-bound loss for DDPM-based models.

    In this loss, we calculate VLB of different timesteps with different
    methods. In detail, ``DiscretizedGaussianLogLikelihoodLoss`` is used at
    timesteps = 0 and ``GaussianKLDLoss`` at other timesteps.

    To control the data flow for loss calculation, users should define
    ``data_info`` and ``data_info_t_0`` for ``GaussianKLDLoss`` and
    ``DiscretizedGaussianLogLikelihoodLoss`` respectively. If not passed,
    ``_default_data_info`` and ``_default_data_info_t_0`` would be used.
    To be noted that, we only penalize 'variance' in this loss term, and
    tensors in the output dict corresponding to 'mean' would be detached.

    Additionally, we support another log collection function called
    ``name_log_collect``. In this collection method, we directly collect
    loss terms calculated by the different methods. To use this collection
    method, users may pass ``log_cfgs`` as the following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='name', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary for ``timesteps != 0``.
            Defaults to None.
        data_info_t_0 (dict, optional): Dictionary contains the mapping
            between loss input args and data dictionary for
            ``timesteps == 0``. Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_vlb'.
    """
    # default arg-name -> outputs-dict-key mapping for ``timesteps != 0``
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')
    # default mapping for ``timesteps == 0``
    _default_data_info_t_0 = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 data_info=None,
                 data_info_t_0=None,
                 log_cfgs=None,
                 reduction='mean',
                 loss_name='loss_ddpm_vlb'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)
        self.data_info = self._default_data_info \
            if data_info is None else data_info
        self.data_info_t_0 = self._default_data_info_t_0 \
            if data_info_t_0 is None else data_info_t_0
        # Sub-loss for t == 0 (negated log-likelihood via loss_weight=-1) and
        # for t != 0 (KLD). 'flatmean' keeps a per-sample loss vector so the
        # results can be scattered back by the timestep masks below.
        self.loss_list = [
            DiscretizedGaussianLogLikelihoodLoss(
                reduction='flatmean',
                data_info=self.data_info_t_0,
                base='2',
                loss_weight=-1,
                only_update_var=True),
            GaussianKLDLoss(
                reduction='flatmean',
                data_info=self.data_info,
                base='2',
                only_update_var=True)
        ]
        # Predicates selecting which timesteps each sub-loss applies to;
        # ordered to match ``self.loss_list``.
        self.loss_select_fn_list = [lambda t: t == 0, lambda t: t != 0]

    @torch.no_grad()
    def name_log_collect(self, loss, timesteps, prefix_name, reduction='mean'):
        """Collect loss logs by name (GaussianKLD and
        DiscGaussianLogLikelihood).

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be in the shape of [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each loss.
                Each loss tensor should be in the shape of [bz, ].
            prefix_name (str): Prefix to show in logs.
            reduction (str, optional): Specifies the reduction to apply to
                the output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        log_vars = dict()
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            if not mask.any():
                # no sample handled by this sub-loss in the batch; log zero
                loss_reduced = torch.zeros((1, ))
            else:
                loss_reduced = reduce_loss(loss[mask], reduction)
            # remove original prefix in loss names
            loss_term_name = loss_fn.loss_name().replace('loss_', '')
            log_vars[f'{prefix_name}_{loss_term_name}'] = loss_reduced.item()
        return log_vars

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Each sub-loss is evaluated only on the batch entries selected by its
        timestep predicate; the per-sample results are scattered back into a
        single loss vector.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate
                losses.

        Returns:
            torch.Tensor: Calculated per-sample loss, shaped as [bz, ].
        """
        # use `zeros` instead of `zeros_like` to avoid get int tensor
        timesteps = outputs_dict['timesteps']
        loss = torch.zeros_like(timesteps).float()
        # loss = torch.zeros(*timesteps.shape).to(timesteps.device)
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            # Build a sub-batch view of the outputs dict: tensors and lists
            # are filtered by the mask; other values are passed unchanged.
            outputs_dict_ = {}
            for k, v in outputs_dict.items():
                if v is None or not isinstance(v, (torch.Tensor, list)):
                    outputs_dict_[k] = v
                elif isinstance(v, list):
                    outputs_dict_[k] = [
                        v[idx] for idx, m in enumerate(mask) if m
                    ]
                else:
                    outputs_dict_[k] = v[mask]
            loss[mask] = loss_fn(outputs_dict_)
        return loss
@MODULES.register_module()
class DDPMMSELoss(DDPMLoss):
    """Mean square loss for DDPM-based models.

    By default, the predicted noise ``eps_t_pred`` is regressed against the
    ground-truth ``noise`` (see ``_default_data_info``).

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. Defaults to None.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_mse'.
    """
    # default arg-name -> outputs-dict-key mapping for ``mse_loss``
    _default_data_info = dict(pred='eps_t_pred', target='noise')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 log_cfgs=None,
                 reduction='mean',
                 data_info=None,
                 loss_name='loss_ddpm_mse'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)
        self.data_info = self._default_data_info \
            if data_info is None else data_info
        # 'flatmean' keeps a per-sample loss vector, as required by the
        # rescale / log-collection machinery in the base class
        self.loss_fn = partial(mse_loss, reduction='flatmean')

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate
                losses.

        Returns:
            torch.Tensor: Calculated per-sample loss.
        """
        # map outputs-dict entries onto the loss-function arguments
        loss_input_dict = {
            k: outputs_dict[v]
            for k, v in self.data_info.items()
        }
        loss = self.loss_fn(**loss_input_dict)
        return loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.nn as nn
from mmgen.models.builder import MODULES
from .utils import weighted_loss
@weighted_loss
def disc_shift_loss(pred):
    """Compute the disc shift loss.

    An auxiliary discriminator loss proposed in PGGAN that penalizes the
    magnitude of the discriminator output via its square.

    Args:
        pred (Tensor): Discriminator prediction to penalize.

    Returns:
        torch.Tensor: Element-wise squared prediction.
    """
    return pred * pred
@MODULES.register_module()
class DiscShiftLoss(nn.Module):
    """Disc Shift Loss module.

    This loss is proposed in PGGAN as an auxiliary loss for the
    discriminator.

    **Note for the design of ``data_info``:**

    As with most loss modules in ``MMGeneration``, ``data_info`` links the
    arguments of the underlying loss function to entries of the data
    dictionary collected during training, e.g.:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='disc_pred_fake')

    With this mapping in place, :meth:`forward` accepts the whole outputs
    dictionary (positionally or as the keyword ``outputs_dict``) and pulls
    out the required items automatically. In general, ``disc_shift_loss`` is
    applied over both real and fake data; simply add this module twice with
    different ``data_info`` and the two items will be summed automatically.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, input data
            is passed straight to the loss function. Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_disc_shift'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 loss_name='loss_disc_shift'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed in (as the first
        positional argument or as the keyword ``outputs_dict``). Otherwise,
        the arguments are forwarded directly to ``disc_shift_loss``.
        """
        if self.data_info is None:
            # No data mapping defined: behave like the functional loss.
            return disc_shift_loss(*args, weight=self.loss_weight, **kwargs)

        # Locate the outputs dictionary among the call arguments.
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # Map the outputs onto the loss-function arguments per data_info.
        kwargs.update(
            {arg: outputs[key]
             for arg, key in self.data_info.items()})
        kwargs['weight'] = self.loss_weight
        return disc_shift_loss(**kwargs)

    def loss_name(self):
        """Loss Name.

        Returns the name of this loss item, which is used to combine
        different loss items by simple sum operation. Use the `loss_` prefix
        to include the item in the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@weighted_loss
def gradient_penalty_loss(discriminator,
                          real_data,
                          fake_data,
                          mask=None,
                          norm_mode='pixel'):
    """Calculate gradient penalty for wgan-gp.

    In the detailed implementation, there are two streams where one uses the
    pixel-wise gradient norm, but the other adopts normalization along
    instance (HWC) dimensions. Thus, ``norm_mode`` is offered to define which
    mode you want.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data. Assumed 4D (NCHW) — the
            interpolation coefficient is sampled per-sample with shape
            ``(N, 1, 1, 1)``.
        fake_data (Tensor): Fake input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # `autograd.Variable` is deprecated since PyTorch 0.4; detaching yields a
    # leaf tensor (sharing storage) on which gradient tracking is enabled --
    # the same semantics as the legacy `Variable(t, requires_grad=True)`.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if mask is not None:
        gradients = gradients * mask

    if norm_mode == 'pixel':
        gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    elif norm_mode == 'HWC':
        gradients_penalty = ((
            gradients.reshape(batch_size, -1).norm(2, dim=1) - 1)**2).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')

    if mask is not None:
        # renormalize so masked-out positions do not dilute the penalty
        gradients_penalty /= torch.mean(mask)

    return gradients_penalty
@MODULES.register_module()
class GradientPenaltyLoss(nn.Module):
    """Gradient Penalty for WGAN-GP.

    In the detailed implementation, there are two streams where one uses the
    pixel-wise gradient norm, but the other adopts normalization along
    instance (HWC) dimensions. Thus, ``norm_mode`` is offered to define which
    mode you want.

    **Note for the design of ``data_info``:**

    As with most loss modules in ``MMGeneration``, ``data_info`` links the
    arguments of the underlying loss function to entries of the data
    dictionary collected during training. This loss needs
    ``discriminator``, ``real_data`` and ``fake_data``, so a typical mapping
    is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs',
            fake_data='fake_imgs')

    With this mapping in place, :meth:`forward` accepts the whole outputs
    dictionary (positionally or as the keyword ``outputs_dict``) and pulls
    out the required items automatically.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, input data
            is passed straight to the loss function. Defaults to None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 data_info=None,
                 loss_name='loss_gp'):
        super().__init__()
        self.loss_weight = loss_weight
        self.norm_mode = norm_mode
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed in (as the first
        positional argument or as the keyword ``outputs_dict``). Otherwise,
        the arguments are forwarded directly to ``gradient_penalty_loss``.
        """
        if self.data_info is None:
            # No data mapping defined: behave like the functional loss.
            return gradient_penalty_loss(
                *args, weight=self.loss_weight, **kwargs)

        # Locate the outputs dictionary among the call arguments.
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # Map the outputs onto the loss-function arguments per data_info.
        kwargs.update(
            {arg: outputs[key]
             for arg, key in self.data_info.items()})
        kwargs['weight'] = self.loss_weight
        kwargs['norm_mode'] = self.norm_mode
        return gradient_penalty_loss(**kwargs)

    def loss_name(self):
        """Loss Name.

        Returns the name of this loss item, which is used to combine
        different loss items by simple sum operation. Use the `loss_` prefix
        to include the item in the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@weighted_loss
def r1_gradient_penalty_loss(discriminator,
                             real_data,
                             mask=None,
                             norm_mode='pixel',
                             loss_scaler=None,
                             use_apex_amp=False):
    """Calculate R1 gradient penalty for WGAN-GP.

    R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from the original gradient penalty, this regularizer only
    penalizes the gradient w.r.t. real data.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        loss_scaler (:obj:`torch.cuda.amp.GradScaler`, optional): Scaler used
            with native AMP; the prediction is scaled before taking
            gradients and the gradients are unscaled afterwards.
            Defaults to None.
        use_apex_amp (bool, optional): Whether apex AMP is in use; if so,
            apex's loss scaler is applied in the same scale/unscale fashion.
            Defaults to False.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    num_samples = real_data.shape[0]
    data = real_data.clone().requires_grad_()
    pred = discriminator(data)

    # Scale the prediction before differentiation so fp16 gradients do not
    # underflow; they are unscaled again right after `autograd.grad`.
    if loss_scaler:
        pred = loss_scaler.scale(pred)
    elif use_apex_amp:
        from apex.amp._amp_state import _amp_state
        _loss_scaler = _amp_state.loss_scalers[0]
        pred = _loss_scaler.loss_scale() * pred.float()

    grad = autograd.grad(
        outputs=pred,
        inputs=data,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if loss_scaler:
        # unscale the gradient
        grad = grad * (1. / loss_scaler.get_scale())
    elif use_apex_amp:
        grad = grad * (1. / _loss_scaler.loss_scale())

    if mask is not None:
        grad = grad * mask

    if norm_mode == 'pixel':
        penalty = ((grad.norm(2, dim=1))**2).mean()
    elif norm_mode == 'HWC':
        penalty = grad.pow(2).reshape(num_samples, -1).sum(1).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')

    if mask is not None:
        # renormalize so masked-out positions do not dilute the penalty
        penalty /= torch.mean(mask)

    return penalty
@MODULES.register_module()
class R1GradientPenalty(nn.Module):
    """R1 gradient penalty regularizer.

    The R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from the original (WGAN-GP style) gradient penalty, this
    regularizer only penalizes the gradient w.r.t. real data.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of GAN model, we will collect all of
    important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``discriminator`` and
    ``real_data`` as input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        interval (int, optional): The interval of calculating this loss. Used
            for lazy regularization. Defaults to 1.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        use_apex_amp (bool, optional): Whether the discriminator is wrapped by
            apex AMP, so that the gradients must be unscaled accordingly.
            Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_r1_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_r1_gp'):
        super().__init__()
        self.loss_weight = loss_weight
        self.norm_mode = norm_mode
        self.interval = interval
        self.data_info = data_info
        self.use_apex_amp = use_apex_amp
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``r1_gradient_penalty_loss``.
        """
        # lazy regularization needs the current iteration, which can only be
        # fetched through the data_info mapping
        if self.interval > 1:
            assert self.data_info is not None
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # lazy regularization: skip iterations that are not multiples of
            # the interval
            if self.interval > 1 and outputs_dict[
                    'iteration'] % self.interval != 0:
                return None
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(
                dict(
                    weight=self.loss_weight,
                    norm_mode=self.norm_mode,
                    use_apex_amp=self.use_apex_amp))
            return r1_gradient_penalty_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            # NOTE: ``use_apex_amp`` is forwarded here as well so that this
            # fallback path behaves consistently with the data_info path.
            return r1_gradient_penalty_loss(
                *args,
                weight=self.loss_weight,
                norm_mode=self.norm_mode,
                use_apex_amp=self.use_apex_amp,
                **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from ..builder import MODULES
@MODULES.register_module()
class GANLoss(nn.Module):
    """Adversarial (GAN) loss supporting several objectives.

    Args:
        gan_type (str): One of 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): Value used as the real label. Default: 1.0.
        fake_label_val (float): Value used as the fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that ``loss_weight`` is only applied for generators; the
            discriminator loss is always weighted by 1.0.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # criteria that are plain torch modules
        module_criteria = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'hinge': nn.ReLU,
        }
        if gan_type in module_criteria:
            self.loss = module_criteria[gan_type]()
        elif gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """WGAN loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        # critic maximizes real scores and minimizes fake scores
        sign = -1. if target else 1.
        return sign * input.mean()

    def _wgan_logistic_ns_loss(self, input, target):
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        logits = -input if target else input
        return F.softplus(logits).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise, \
                return Tensor.
        """
        # the wgan variants consume the raw boolean label directly
        if self.gan_type in ('wgan', 'wgan-logistic-ns'):
            return target_is_real
        label_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * label_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for a discriminator.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type != 'hinge':
            loss = self.loss(input, target_label)
        elif is_disc:
            # discriminator side of the hinge loss
            signed = -input if target_is_real else input
            loss = self.loss(1 + signed).mean()
        else:
            # generator side of the hinge loss
            loss = -input.mean()
        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.autograd as autograd
import torch.distributed as dist
import torch.nn as nn
import torchvision.models.vgg as vgg
from mmcv.runner import load_checkpoint
from mmgen.models.builder import MODULES, build_module
from mmgen.utils import get_root_logger
from .pixelwise_loss import l1_loss, mse_loss
def gen_path_regularizer(generator,
                         num_batches,
                         mean_path_length,
                         pl_batch_shrink=1,
                         decay=0.01,
                         weight=1.,
                         pl_batch_size=None,
                         sync_mean_buffer=False,
                         loss_scaler=None,
                         use_apex_amp=False):
    """Generator Path Regularization.

    Path regularization is proposed in StyleGAN2, which can help improve
    the continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Args:
        generator (nn.Module): The generator module. Note that this loss
            requires that the generator contains ``return_latents`` interface,
            with which we can get the latent code of the current sample.
        num_batches (int): The number of samples used in calculating this loss.
        mean_path_length (Tensor): The mean path length, calculated by moving
            average.
        pl_batch_shrink (int, optional): The factor of shrinking the batch size
            for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path length.
            Defaults to 0.01.
        weight (float, optional): Weight of this loss item. Defaults to ``1.``.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, the ``num_batches`` will
            be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all of GPUs. Defaults to False.
        loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional): Loss
            scaler used in native AMP training. If given, the scaled loss is
            backpropagated through ``autograd.grad`` and the resulting
            gradients are unscaled afterwards. Defaults to None.
        use_apex_amp (bool, optional): Whether apex AMP is adopted. If True,
            the generator loss scaler (``loss_scalers[1]``) is used to scale
            and then unscale the gradients. Defaults to False.

    Returns:
        tuple[Tensor]: The penalty loss, detached mean path tensor, and \
            current path length.
    """
    # reduce batch size for conserving GPU memory
    if pl_batch_shrink > 1:
        num_batches = max(1, num_batches // pl_batch_shrink)
    # reset the batch size if pl_batch_size is not None
    if pl_batch_size is not None:
        num_batches = pl_batch_size
    # sample fake images together with their latent codes
    output_dict = generator(None, num_batches=num_batches, return_latents=True)
    fake_img, latents = output_dict['fake_img'], output_dict['latent']
    # random direction in image space, scaled by 1/sqrt(H*W) so that the
    # expected squared projection is independent of the resolution
    noise = torch.randn_like(fake_img) / np.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    if loss_scaler:
        # native AMP: scale the surrogate loss before differentiating
        loss = loss_scaler.scale((fake_img * noise).sum())[0]
        grad = autograd.grad(
            outputs=loss,
            inputs=latents,
            grad_outputs=torch.ones(()).to(loss),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # unscale the grad
        inv_scale = 1. / loss_scaler.get_scale()
        grad = grad * inv_scale
    elif use_apex_amp:
        from apex.amp._amp_state import _amp_state
        # by default, we use loss_scalers[0] for discriminator and
        # loss_scalers[1] for generator
        _loss_scaler = _amp_state.loss_scalers[1]
        loss = _loss_scaler.loss_scale() * ((fake_img * noise).sum()).float()
        grad = autograd.grad(
            outputs=loss,
            inputs=latents,
            grad_outputs=torch.ones(()).to(loss),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # unscale the grad
        inv_scale = 1. / _loss_scaler.loss_scale()
        grad = grad * inv_scale
    else:
        # no AMP: differentiate the unscaled surrogate loss directly;
        # create_graph=True keeps the graph so the penalty itself is
        # differentiable w.r.t. the generator parameters
        grad = autograd.grad(
            outputs=(fake_img * noise).sum(),
            inputs=latents,
            grad_outputs=torch.ones(()).to(fake_img),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
    # path length: L2 norm of the Jacobian-vector product, averaged over the
    # latent dimension (assumes latents shaped (n, num_latents, latent_dim))
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # update mean path via exponential moving average
    path_mean = mean_path_length + decay * (
        path_lengths.mean() - mean_path_length)
    if sync_mean_buffer and dist.is_initialized():
        dist.all_reduce(path_mean)
        path_mean = path_mean / float(dist.get_world_size())
    path_penalty = (path_lengths - path_mean).pow(2).mean() * weight
    return path_penalty, path_mean.detach(), path_lengths
@MODULES.register_module()
class GeneratorPathRegularizer(nn.Module):
    """Generator Path Regularizer.

    Path regularization is proposed in StyleGAN2, which can help improve
    the continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Users can achieve lazy regularization by setting the ``interval``
    argument here.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of GAN model, we will collect all of
    important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``generator`` and
    ``num_batches`` as input. Thus an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            generator='gen',
            num_batches='batch_size')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        pl_batch_shrink (int, optional): The factor of shrinking the batch size
            for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path length.
            Defaults to 0.01.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, the ``num_batches`` will
            be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all of GPUs. Defaults to False.
        interval (int, optional): The interval of calculating this loss. This
            argument is used to support lazy regularization. Defaults to 1.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        use_apex_amp (bool, optional): Whether apex AMP is adopted.
            Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_path_regular'.
    """

    def __init__(self,
                 loss_weight=1.,
                 pl_batch_shrink=1,
                 decay=0.01,
                 pl_batch_size=None,
                 sync_mean_buffer=False,
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_path_regular'):
        super().__init__()
        self.loss_weight = loss_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.decay = decay
        self.pl_batch_size = pl_batch_size
        self.sync_mean_buffer = sync_mean_buffer
        self.interval = interval
        self.data_info = data_info
        self.use_apex_amp = use_apex_amp
        self._loss_name = loss_name
        # running average of the path length, persisted with the module state
        self.register_buffer('mean_path_length', torch.tensor(0.))

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``gen_path_regularizer``.
        """
        # lazy regularization requires the iteration number from data_info
        if self.interval > 1:
            assert self.data_info is not None
        if self.data_info is None:
            # no computational-path description: forward everything as-is
            return gen_path_regularizer(
                *args, weight=self.loss_weight, **kwargs)
        # locate the outputs dict among the positional / keyword arguments
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert not args, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # lazy regularization: skip off-interval iterations entirely
        if self.interval > 1 and outputs_dict[
                'iteration'] % self.interval != 0:
            return None
        # translate outputs_dict entries into loss arguments via data_info
        kwargs.update(
            {key: outputs_dict[src]
             for key, src in self.data_info.items()})
        kwargs.update(
            weight=self.loss_weight,
            mean_path_length=self.mean_path_length,
            pl_batch_shrink=self.pl_batch_shrink,
            decay=self.decay,
            use_apex_amp=self.use_apex_amp,
            pl_batch_size=self.pl_batch_size,
            sync_mean_buffer=self.sync_mean_buffer)
        # the returned moving average replaces the registered buffer
        path_penalty, self.mean_path_length, _ = gen_path_regularizer(
            **kwargs)
        return path_penalty

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
def third_party_net_loss(net, weight=1.0, **kwargs):
    """Compute a loss with an external network and scale it.

    Args:
        net (callable): Module or function that returns the raw loss value.
        weight (float, optional): Scaling factor applied to the raw loss.
            Defaults to 1.0.
        **kwargs: Keyword arguments forwarded to ``net``.

    Returns:
        The value returned by ``net`` multiplied by ``weight``.
    """
    raw_loss = net(**kwargs)
    return raw_loss * weight
@MODULES.register_module()
class FaceIdLoss(nn.Module):
    """Face similarity loss. Generally this loss is used to keep the id
    consistency of the input face image and output face image.

    In this loss, we may need to provide ``gt``, ``pred`` and ``x``. Thus,
    an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            gt='real_imgs',
            pred='fake_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        facenet (dict, optional): Config dict for facenet. Defaults to
            dict(type='ArcFace', ir_se50_weights=None, device='cuda').
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_id'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 facenet=dict(
                     type='ArcFace', ir_se50_weights=None, device='cuda'),
                 loss_name='loss_id'):
        super(FaceIdLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self.net = build_module(facenet)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            # BUGFIX: do not forward ``*args`` here. When the outputs dict is
            # given positionally, it has already been consumed above; passing
            # it again would bind to ``weight`` and raise
            # "got multiple values for argument 'weight'".
            return third_party_net_loss(self.net, **kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
class CLIPLossModel(torch.nn.Module):
    """Wrapped clip model to calculate clip loss.

    Ref: https://github.com/orpatashnik/StyleCLIP/blob/main/criteria/clip_loss.py # noqa

    Args:
        in_size (int, optional): Input image size. Defaults to 1024.
        scale_factor (int, optional): Upsampling factor. Defaults to 7.
        pool_size (int, optional): Pooling output size. Defaults to 224.
        clip_type (str, optional): A model name listed by
            `clip.available_models()`, or the path to a model checkpoint
            containing the state_dict. For more details, you can refer to
            https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/clip.py#L91 # noqa
            Defaults to 'ViT-B/32'.
        device (str, optional): Model device. Defaults to 'cuda'.

    Raises:
        ImportError: If the ``clip`` package is not installed.
    """

    def __init__(self,
                 in_size=1024,
                 scale_factor=7,
                 pool_size=224,
                 clip_type='ViT-B/32',
                 device='cuda'):
        super(CLIPLossModel, self).__init__()
        try:
            import clip
        except ImportError:
            # BUGFIX: raising a plain string is invalid in Python 3 (it
            # raises "TypeError: exceptions must derive from BaseException");
            # raise a proper ImportError with the intended message instead.
            raise ImportError(
                'To use clip loss, openai clip need to be installed first')
        self.model, self.preprocess = clip.load(clip_type, device=device)
        # upsample then average-pool so that the image matches the input
        # resolution expected by the CLIP image encoder (``pool_size``)
        self.upsample = torch.nn.Upsample(scale_factor=scale_factor)
        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=(scale_factor * in_size // pool_size))

    def forward(self, image=None, text=None):
        """Forward function.

        Args:
            image (torch.Tensor): Image tensor fed to the CLIP image encoder.
            text (torch.Tensor): Tokenized text fed to the CLIP text encoder.

        Returns:
            torch.Tensor: ``1 - similarity / 100``, i.e. lower values mean
                the image matches the text better.
        """
        assert image is not None
        assert text is not None
        image = self.avg_pool(self.upsample(image))
        loss = 1 - self.model(image, text)[0] / 100
        return loss
@MODULES.register_module()
class CLIPLoss(nn.Module):
    """Clip loss. In styleclip, this loss is used to optimize the latent code
    to generate image that match the text.

    In this loss, we may need to provide ``image``, ``text``. Thus,
    an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            image='fake_imgs',
            text='descriptions')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        clip_model (dict, optional): Kwargs for clip loss model. Defaults to
            dict().
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_clip'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 clip_model=dict(),
                 loss_name='loss_clip'):
        super(CLIPLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self.net = CLIPLossModel(**clip_model)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            # BUGFIX: do not forward ``*args`` here. When the outputs dict is
            # given positionally, it has already been consumed above; passing
            # it again would bind to ``weight`` and raise
            # "got multiple values for argument 'weight'".
            return third_party_net_loss(self.net, **kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        # BUGFIX: this was a ``@staticmethod`` returning the hard-coded
        # 'clip_loss', ignoring the configurable ``loss_name`` argument
        # (default 'loss_clip') and breaking consistency with every other
        # loss module in this file.
        return self._loss_name
class PerceptualVGG(nn.Module):
    """VGG feature extractor used in calculating perceptual loss.

    This wrapper builds a torchvision VGG backbone, loads the pretrained
    weights, and returns the intermediate activations named in
    ``layer_name_list``. Users may choose whether to normalize the input and
    which VGG variant to build; the pretrained path must match the variant.

    Args:
        layer_name_list (list[str]): Names (string indices) of the layers in
            ``vgg.features`` whose outputs are returned by ``forward``. An
            example of this list is ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image.
            Importantly, the input feature must be in the range [0, 1].
            Default: True.
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 pretrained='torchvision://vgg19'):
        super().__init__()
        if pretrained.startswith('torchvision://'):
            assert vgg_type in pretrained
        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm

        # Build the backbone and load its weights. The full model is kept as
        # a local (not an attribute) so unused layers do not trigger the
        # `find_unused_parameters` issue in distributed training.
        backbone = getattr(vgg, vgg_type)()
        self.init_weights(backbone, pretrained)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(backbone.features) >= num_layers
        # keep only the prefix of layers that is actually needed
        self.vgg_layers = backbone.features[:num_layers]

        if self.use_input_norm:
            # ImageNet channel statistics, applied to inputs in [0, 1]
            self.register_buffer(
                'mean',
                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer(
                'std',
                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # freeze the extractor: it is never trained
        for param in self.vgg_layers.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Extract the requested VGG features.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Mapping from layer name to its activation.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        feats = {}
        for name, layer in self.vgg_layers.named_children():
            x = layer(x)
            if name in self.layer_name_list:
                feats[name] = x.clone()
        return feats

    def init_weights(self, model, pretrained):
        """Load pretrained weights into ``model``.

        Args:
            model (nn.Module): Model to be initialized.
            pretrained (str): Path for pretrained weights.
        """
        logger = get_root_logger()
        load_checkpoint(model, pretrained, logger=logger)
@MODULES.register_module()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as input.
    Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_perceptual'.
        layer_weights (dict): The weight for each layer of vgg feature for
            perceptual loss. Here is an example: {'4': 1., '9': 1., '18': 1.},
            which means the 5th, 10th and 19th feature layer will be
            extracted with weight 1.0 in calculating losses. Defaults to
            ``{'4': 1., '9': 1., '18': 1.}``.
        layer_weights_style (dict): The weight for each layer of vgg feature
            for style loss. If set to 'None', the weights are set equal to
            the weights for perceptual loss. Default: None.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 1.0.
        norm_img (bool): If True, the image will be normed to [0, 1]. Note
            that this is different from the `use_input_norm` which norms the
            input in the forward function of vgg according to the statistics
            of dataset. Importantly, the input image must be in range [-1, 1].
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'.
        criterion (str): Criterion type. Options are 'l1' and 'mse'.
            Default: 'l1'.
        split_style_loss (bool): Whether return a separate style loss item.
            Options are True and False. Default: False
    """

    def __init__(self,
                 data_info=None,
                 loss_name='loss_perceptual',
                 layer_weights={
                     '4': 1.,
                     '9': 1.,
                     '18': 1.
                 },
                 layer_weights_style=None,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 perceptual_weight=1.0,
                 style_weight=1.0,
                 norm_img=True,
                 pretrained='torchvision://vgg19',
                 criterion='l1',
                 split_style_loss=False):
        super().__init__()
        self.data_info = data_info
        self._loss_name = loss_name
        self.norm_img = norm_img
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.layer_weights_style = layer_weights_style
        self.split_style_loss = split_style_loss
        self.vgg = PerceptualVGG(
            layer_name_list=list(self.layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            pretrained=pretrained)
        # A second extractor is only needed when the style loss uses a
        # different set of layers than the perceptual loss.
        if self.layer_weights_style is not None and \
                self.layer_weights_style != self.layer_weights:
            self.vgg_style = PerceptualVGG(
                layer_name_list=list(self.layer_weights_style.keys()),
                vgg_type=vgg_type,
                use_input_norm=use_input_norm,
                pretrained=pretrained)
        else:
            self.layer_weights_style = self.layer_weights
            self.vgg_style = None
        criterion = criterion.lower()
        if criterion == 'l1':
            self.criterion = l1_loss
        elif criterion == 'mse':
            self.criterion = mse_loss
        else:
            raise NotImplementedError(
                f'{criterion} criterion has not been supported in'
                ' this version.')

    def forward(self, *args, **kwargs):
        """Forward function. If ``self.data_info`` is not ``None``, a
        dictionary containing all of the data and necessary modules should be
        passed into this function. If this dictionary is given as a
        non-keyword argument, it should be offered as the first argument. If
        you are using keyword argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to the loss function,
        ``perceptual_loss``.

        Args:
            pred (Tensor): Input tensor with shape (n, c, h, w).
            target (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            return self.perceptual_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return self.perceptual_loss(*args, **kwargs)

    def perceptual_loss(self, pred, target):
        """Compute the weighted perceptual and style losses.

        Args:
            pred (Tensor): Predicted image, expected in range [-1, 1] when
                ``norm_img`` is True.
            target (Tensor): Ground-truth image in the same range.

        Returns:
            Tensor | tuple[Tensor]: ``percep + style`` by default, or the
                pair ``(percep, style)`` when ``split_style_loss`` is True.
        """
        if self.norm_img:
            # map [-1, 1] -> [0, 1] as expected by PerceptualVGG
            pred = (pred + 1.) * 0.5
            target = (target + 1.) * 0.5
        # extract vgg features; target never needs gradients
        pred_features = self.vgg(pred)
        target_features = self.vgg(target.detach())
        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in pred_features.keys():
                percep_loss += self.criterion(
                    pred_features[k],
                    target_features[k],
                    weight=self.layer_weights[k])
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = 0.
        # calculate style loss
        if self.style_weight > 0:
            if self.vgg_style is not None:
                pred_features = self.vgg_style(pred)
                target_features = self.vgg_style(target.detach())
            style_loss = 0
            for k in pred_features.keys():
                style_loss += self.criterion(
                    self._gram_mat(pred_features[k]),
                    self._gram_mat(
                        target_features[k])) * self.layer_weights_style[k]
            style_loss *= self.style_weight
        else:
            style_loss = 0.
        if self.split_style_loss:
            return percep_loss, style_loss
        else:
            return percep_loss + style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        (n, c, h, w) = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.builder import MODULES
from .utils import weighted_loss
_reduction_modes = ['none', 'mean', 'sum', 'batchmean', 'flatmean']
@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 loss.

    Weighting and reduction are applied by the ``weighted_loss`` wrapper.

    Args:
        pred (Tensor): Prediction tensor with shape (n, c, h, w).
        target (Tensor): Target tensor with shape (n, c, h, w).

    Returns:
        Tensor: Element-wise L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
    """Element-wise MSE loss.

    Weighting and reduction are applied by the ``weighted_loss`` wrapper.

    Args:
        pred (Tensor): Prediction tensor with shape (n, c, h, w).
        target (Tensor): Target tensor with shape (n, c, h, w).

    Returns:
        Tensor: Element-wise MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def gaussian_kld(mean_target, mean_pred, logvar_target, logvar_pred, base='e'):
    r"""KL divergence between two diagonal Gaussian distributions.

    Computes ``KLD(p || q)`` element-wise, where ``p`` is the target
    distribution and ``q`` the predicted one:

    .. math::

        KLD(p||q) = \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}

    Args:
        mean_target (torch.Tensor): Mean of the target (or the first)
            distribution.
        mean_pred (torch.Tensor): Mean of the predicted (or the second)
            distribution.
        logvar_target (torch.Tensor): Log variance of the target (or the
            first) distribution.
        logvar_pred (torch.Tensor): Log variance of the predicted (or the
            second) distribution.
        base (str, optional): Log base of the result. We support ``'e'``
            (for ln) and ``'2'`` (for log_2). Defaults to ``'e'``.

    Returns:
        torch.Tensor: Element-wise KLD between the two distributions.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    var_ratio = torch.exp(logvar_target - logvar_pred)
    scaled_sq_diff = ((mean_target - mean_pred)**2) * torch.exp(-logvar_pred)
    kld = 0.5 * (-1.0 + logvar_pred - logvar_target + var_ratio +
                 scaled_sq_diff)
    # convert from nats to bits when requested
    return kld / np.log(2.0) if base == '2' else kld
def approx_gaussian_cdf(x):
    r"""Tanh-based approximation of the standard normal CDF.

    Refers to:
    Approximations to the Cumulative Normal Function and its Inverse for Use on a Pocket Calculator # noqa
    https://www.jstor.org/stable/2346872?origin=crossref

    .. math::
        :nowrap:

        \begin{eqnarray}
            \Phi(x) &\approx \frac{1}{2} \left ( 1 + \tanh(y) \right ) \\
            y &= \sqrt{\frac{2}{\pi}}(x+0.044715 x^3)
        \end{eqnarray}

    Args:
        x (torch.Tensor): Input data.

    Returns:
        torch.Tensor: Approximated cumulative distribution values in (0, 1).
    """
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x * x * x)
    return 0.5 * (1.0 + torch.tanh(inner))
@weighted_loss
def discretized_gaussian_log_likelihood(x, mean, logvar, base='e'):
    r"""Log-likelihood of a discretized Gaussian for image data.

    The input ``x`` is assumed to contain images discretized into 255 bins
    and rescaled to ``[-1, 1]``. For every element, the probability mass of
    the bin around ``x`` is evaluated through the (approximated) normal CDF
    at the bin edges ``x - 1/255`` and ``x + 1/255``, after normalizing by
    the predicted mean and variance. The two extreme bins are treated as
    open intervals extending to :math:`-\infty` and :math:`+\infty`.

    Args:
        x (torch.Tensor): Target ``x_0`` to be modeled. Range in [-1, 1].
        mean (torch.Tensor): Predicted mean of ``x_0``.
        logvar (torch.Tensor): Predicted log variance of ``x_0``.
        base (str, optional): The log base of the result. Support ``'e'``
            and ``'2'``. Defaults to ``'e'``.

    Returns:
        torch.Tensor: Element-wise log likelihood.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    eps = 1e-12
    half_bin = 1.0 / 255.0
    inv_std = torch.exp(-logvar * 0.5)
    centered = x - mean
    # CDF at the lower and upper edges of the bin containing x
    cdf_low = approx_gaussian_cdf((centered - half_bin) * inv_std)
    cdf_up = approx_gaussian_cdf((centered + half_bin) * inv_std)
    # leftmost bin integrates from -inf, rightmost bin up to +inf
    log_probs = torch.where(
        x < -0.999, torch.log(cdf_up.clamp(min=eps)),
        torch.where(x > 0.999, torch.log((1.0 - cdf_low).clamp(min=eps)),
                    torch.log((cdf_up - cdf_low).clamp(min=eps))))
    return log_probs / np.log(2.0) if base == '2' else log_probs
@MODULES.register_module()
class MSELoss(nn.Module):
    """MSE loss module.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of the loss modules accept a
    ``data_info`` argument that describes how to link the inputs required
    by the loss to the data produced by the generative model. During GAN
    training, the relevant data/modules are collected in a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    This loss needs ``pred`` and ``target`` as inputs, so a suitable
    ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    The module then builds this mapping from the input data dictionary
    automatically.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping
            between loss input args and the data dictionary. If ``None``,
            the input data is passed to the loss function directly.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_mse'.
    """

    def __init__(self, loss_weight=1.0, data_info=None, loss_name='loss_mse'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        When ``self.data_info`` is not ``None``, a dictionary containing
        all of the data and necessary modules should be passed into this
        function. If it is given as a non-keyword argument, it must be the
        first one; as a keyword argument, it must be named ``outputs_dict``.

        When ``self.data_info`` is ``None``, the arguments are passed
        straight through to the loss function, ``mse_loss``.
        """
        if self.data_info is None:
            # no mapping configured: behave like the plain functional loss
            return mse_loss(*args, weight=self.loss_weight, **kwargs)
        # parse the outputs dictionary from args/kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # map network outputs onto the loss arguments via data_info
        kwargs.update(
            {arg: outputs_dict[key] for arg, key in self.data_info.items()})
        kwargs['weight'] = self.loss_weight
        return mse_loss(**kwargs)

    def loss_name(self):
        """Return the name of this loss item.

        The name is used to combine different loss items by simple sum
        operation; only names with the ``loss_`` prefix are included in the
        backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of GAN model, we will
    collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as
    input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_l1'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 loss_name='loss_l1'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed into this
        function. If this dictionary is given as a non-keyword argument, it
        should be offered as the first argument. If you are using keyword
        argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function, ``l1_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network '
                    'outputs for building up computational graph of this '
                    'loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss '
                    'module. Please check the usage of this module')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            # bugfix: ``avg_factor`` used to be silently dropped in this
            # branch although it is documented and stored in ``__init__``
            kwargs.update(
                dict(
                    weight=self.loss_weight,
                    reduction=self.reduction,
                    avg_factor=self.avg_factor))
            return l1_loss(**kwargs)
        else:
            # if no mapping is configured, this module just forwards the
            # inputs to the functional loss as usual.
            return l1_loss(
                *args,
                weight=self.loss_weight,
                reduction=self.reduction,
                avg_factor=self.avg_factor,
                **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss
        items by simple sum operation. In addition, if you want this loss
        item to be included into the backward graph, `loss_` must be the
        prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class GaussianKLDLoss(nn.Module):
    """GaussianKLD loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of a diffusion model, we
    will collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean_pred``, ``mean_target``,
    ``logvar_pred`` and ``logvar_target`` as input. Thus, an example of the
    ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            mean_pred='mean_pred',
            mean_target='mean_target',
            logvar_pred='logvar_pred',
            logvar_target='logvar_target')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch. Noted
            that 'batchmean' mode given the correct KL divergence where
            losses are averaged over batch dimension only. Defaults to
            'mean'.
        avg_factor (float | None, optional): Average factor used instead of
            the element count when ``reduction='mean'``. Defaults to
            ``None``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If true, only `logvar_pred` will
            be updated and the variable in output_dict corresponding to
            `mean_pred` will be detached. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_GaussianKLD'.
    """

    # used when the user does not pass an explicit ``data_info``
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_GaussianKLD'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules
        should be passed into this function. If this dictionary is given as
        a non-keyword argument, it should be offered as the first argument.
        If you are using keyword argument, please name it as
        `outputs_dict`. The mapping stored in ``self.data_info`` (which
        defaults to ``_default_data_info``) links the entries of this
        dictionary to the arguments of ``gaussian_kld``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            if 'mean_pred' == k and self.only_update_var:
                # freeze the mean so that only the variance head is trained
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]
        kwargs.update(loss_input_dict)
        # bugfix: ``avg_factor`` used to be stored but never forwarded,
        # silently ignoring the constructor argument
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                avg_factor=self.avg_factor,
                base=self.base))
        return gaussian_kld(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss
        items by simple sum operation. In addition, if you want this loss
        item to be included into the backward graph, `loss_` must be the
        prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# TODO: shorten this overly long class name.
@MODULES.register_module()
class DiscretizedGaussianLogLikelihoodLoss(nn.Module):
    r"""Discretized-Gaussian-Log-Likelihood Loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of a diffusion model, we
    will collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean``, ``logvar`` and ``x``.
    Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            x='real_imgs',
            mean='mean_pred',
            logvar='logvar_pred')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor used instead of
            the element count when ``reduction='mean'``. Defaults to
            ``None``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of the calculated likelihood.
            Support ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If true, only `logvar_pred` will
            be updated and the variable in output_dict corresponding to
            `mean_pred` will be detached. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name.
            Defaults to 'loss_DiscGaussianLogLikelihood'.
    """

    # used when the user does not pass an explicit ``data_info``
    _default_data_info = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_DiscGaussianLogLikelihood'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules
        should be passed into this function. If this dictionary is given as
        a non-keyword argument, it should be offered as the first argument.
        If you are using keyword argument, please name it as
        `outputs_dict`. The mapping stored in ``self.data_info`` (which
        defaults to ``_default_data_info``) links the entries of this
        dictionary to the arguments of
        ``discretized_gaussian_log_likelihood``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            if k == 'mean' and self.only_update_var:
                # freeze the mean so that only the variance head is trained
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]
        kwargs.update(loss_input_dict)
        # bugfix: ``avg_factor`` used to be stored but never forwarded,
        # silently ignoring the constructor argument
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                avg_factor=self.avg_factor,
                base=self.base))
        return discretized_gaussian_log_likelihood(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss
        items by simple sum operation. In addition, if you want this loss
        item to be included into the backward graph, `loss_` must be the
        prefix of the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import torch.nn.functional as F
def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): One of "none", "mean", "sum", "batchmean" or
            "flatmean". "none" returns the input unchanged; "mean"/"sum"
            behave like the PyTorch built-ins; "batchmean" divides the total
            sum by the batch size; "flatmean" averages each sample over its
            own elements, yielding output of shape [bz, ].

    Return:
        Tensor: Reduced loss tensor.
    """
    # mmgen-specific modes handled before delegating to PyTorch's enum
    if reduction == 'batchmean':
        return loss.sum() / loss.shape[0]
    if reduction == 'flatmean':
        return loss.mean(dim=list(range(1, loss.ndim)))
    enum = F._Reduction.get_enum(reduction)
    if enum == 0:  # 'none'
        return loss
    if enum == 1:  # 'mean' (elementwise mean)
        return loss.mean()
    if enum == 2:  # 'sum'
        return loss.sum()
    raise ValueError(f'reduction type {reduction} not supported')
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor | None): Element-wise weights. Defaults to None.
        reduction (str): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None): Average factor used instead of the
            element count when ``reduction='mean'``. Only compatible with
            the 'mean' and 'none' reductions. Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is combined with a reduction other
            than 'mean' or 'none'.
    """
    # apply element-wise weight first so every reduction sees weighted values
    if weight is not None:
        loss = loss * weight
    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        return reduce_loss(loss, reduction)
    # with an explicit avg_factor only 'mean' (sum / avg_factor) and
    # 'none' (no-op) are well defined
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'none':
        return loss
    # bugfix: the old message blamed reduction="sum" specifically, although
    # this branch is reached for every reduction except 'mean' and 'none'
    raise ValueError(
        'avg_factor can only be used with reduction="mean" or "none"')
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The wrapped function must have the signature
    ``loss_func(pred, target, **kwargs)`` and compute the element-wise loss
    without any reduction. The decorated function additionally accepts
    ``weight``, ``reduction`` and ``avg_factor`` keyword arguments, giving
    it the signature ``loss_func(pred, target, weight=None,
    reduction='mean', avg_factor=None, **kwargs)``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(*args,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # element-wise loss from the wrapped function, then weight/reduce
        elementwise = loss_func(*args, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction, avg_factor)

    return wrapper
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torchvision.utils import make_grid
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to (min, max), image values are normalized to [0, 1].
    Behavior depends on the tensor shape:

    1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
        images in the batch are tiled with ``make_grid`` and converted to
        a single numpy array.
    2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
        converted directly.

    The channels of input tensors should be in RGB order; 3-channel
    outputs follow the cv2 convention, i.e., (H x W x C) with BGR order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output type. If ``np.uint8``, outputs are
            rescaled to [0, 255] and rounded; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): min and max values for clamp.

    Returns:
        (ndarray | list[ndarray]): 3D array of shape (H x W x C) or 2D
            array of shape (H x W); a list when a list was given.
    """
    is_single = torch.is_tensor(tensor)
    if not (is_single or (isinstance(tensor, list)
                          and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')
    tensors = [tensor] if is_single else tensor

    result = []
    for t in tensors:
        # drop up to two leading singleton dims:
        # (1, 1, h, w) -> (h, w); (1, 3, h, w) -> (3, h, w);
        # (n>1, 3/1, h, w) stays unchanged
        t = t.squeeze(0).squeeze(0)
        t = t.float().detach().cpu().clamp_(*min_max)
        t = (t - min_max[0]) / (min_max[1] - min_max[0])
        ndim = t.dim()
        if ndim == 4:
            img = make_grid(
                t, nrow=int(np.sqrt(t.size(0))), normalize=False).numpy()
            img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))  # RGB->BGR
        elif ndim == 3:
            img = np.transpose(t.numpy()[[2, 1, 0], :, :], (1, 2, 0))
        elif ndim == 2:
            img = t.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {ndim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy's astype does NOT round by default.
            img = (img * 255.0).round()
        result.append(img.astype(out_type))
    return result[0] if len(result) == 1 else result
# Copyright (c) OpenMMLab. All rights reserved.
from .base_translation_model import BaseTranslationModel
from .cyclegan import CycleGAN
from .pix2pix import Pix2Pix
from .static_translation_gan import StaticTranslationGAN
# Public API of the translation-models subpackage.
__all__ = [
    'Pix2Pix', 'CycleGAN', 'BaseTranslationModel', 'StaticTranslationGAN'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from copy import deepcopy
import torch.nn as nn
from ..builder import MODELS
@MODELS.register_module()
class BaseTranslationModel(nn.Module, metaclass=ABCMeta):
"""Base Translation Model.
Translation models can transfer images from one domain to
another. Domain information like `default_domain`,
`reachable_domains` are needed to initialize the class.
And we also provide query functions like `is_domain_reachable`,
`get_other_domains`.
You can get a specific generator based on the domain,
and by specifying `target_domain` in the forward function,
you can decide the domain of generated images.
Considering the difference among different image translation models,
we only provide the external interfaces mentioned above.
When you implement image translation with a specific method,
you can inherit both `BaseTranslationModel`
and the method (e.g BaseGAN) and implement abstract methods.
Args:
default_domain (str): Default output domain.
reachable_domains (list[str]): Domains that can be generated by
the model.
related_domains (list[str]): Domains involved in training and
testing. `reachable_domains` must be contained in
`related_domains`. However, related_domains may contain
source domains that are used to retrieve source images from
data_batch but not in reachable_domains.
train_cfg (dict): Config for training. Default: None.
test_cfg (dict): Config for testing. Default: None.
"""
def __init__(self,
default_domain,
reachable_domains,
related_domains,
train_cfg=None,
test_cfg=None):
self._default_domain = default_domain
self._reachable_domains = reachable_domains
self._related_domains = related_domains
assert self._default_domain in self._reachable_domains
assert set(self._reachable_domains) <= set(self._related_domains)
self.train_cfg = deepcopy(train_cfg) if train_cfg else None
self.test_cfg = deepcopy(test_cfg) if test_cfg else None
self._parse_train_cfg()
if test_cfg is not None:
self._parse_test_cfg()
@abstractmethod
def _parse_train_cfg(self):
"""Parsing train config and set some attributes for training."""
@abstractmethod
def _parse_test_cfg(self):
"""Parsing test config and set some attributes for testing."""
def forward(self, img, test_mode=False, **kwargs):
"""Forward function.
Args:
img (tensor): Input image tensor.
test_mode (bool): Whether in test mode or not. Default: False.
kwargs (dict): Other arguments.
"""
if not test_mode:
return self.forward_train(img, **kwargs)
return self.forward_test(img, **kwargs)
def forward_train(self, img, target_domain, **kwargs):
"""Forward function for training.
Args:
img (tensor): Input image tensor.
target_domain (str): Target domain of output image.
kwargs (dict): Other arguments.
Returns:
dict: Forward results.
"""
target = self.translation(img, target_domain=target_domain, **kwargs)
results = dict(source=img, target=target)
return results
def forward_test(self, img, target_domain, **kwargs):
"""Forward function for testing.
Args:
img (tensor): Input image tensor.
target_domain (str): Target domain of output image.
kwargs (dict): Other arguments.
Returns:
dict: Forward results.
"""
target = self.translation(img, target_domain=target_domain, **kwargs)
results = dict(source=img.cpu(), target=target.cpu())
return results
def is_domain_reachable(self, domain):
"""Whether image of this domain can be generated."""
return domain in self._reachable_domains
def get_other_domains(self, domain):
"""get other domains."""
return list(set(self._related_domains) - set([domain]))
@abstractmethod
def _get_target_generator(self, domain):
"""get target generator."""
def translation(self, image, target_domain=None, **kwargs):
    """Translate ``image`` into the style of ``target_domain``.

    Args:
        image (tensor): Image tensor with a shape of (N, C, H, W).
        target_domain (str, optional): Target domain of the output
            image; falls back to the model's default domain when None.
            Default: None.

    Returns:
        Generator output for the requested domain.
    """
    domain = self._default_domain if target_domain is None else target_domain
    generator = self._get_target_generator(domain)
    return generator(image, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.builder import MODELS
from ..common import GANImageBuffer, set_requires_grad
from .static_translation_gan import StaticTranslationGAN
@MODELS.register_module()
class CycleGAN(StaticTranslationGAN):
    """CycleGAN model for unpaired image-to-image translation.

    Ref:
    Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
    Networks
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # GAN image buffers: one fake-image history buffer per reachable
        # domain. Discriminators are trained on buffered (possibly stale)
        # fakes rather than only the latest generator output.
        self.image_buffers = dict()
        self.buffer_size = (50 if self.train_cfg is None else
                            self.train_cfg.get('buffer_size', 50))
        for domain in self._reachable_domains:
            self.image_buffers[domain] = GANImageBuffer(self.buffer_size)
        # no exponential moving average of the generator is used
        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results (``source`` and ``target`` on CPU).
        """
        # This is a trick for CycleGAN: stay in train mode at test time,
        # matching the reference implementation.
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def _get_disc_loss(self, outputs):
        """Compute the discriminators' losses.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Total discriminator loss and a log dict.
        """
        discriminators = self.get_module(self.discriminators)
        log_vars_d = dict()
        loss_d = 0
        # GAN loss for discriminators['a']
        for domain in self._reachable_domains:
            losses = dict()
            # draw a (possibly historical) fake image from the buffer
            fake_img = self.image_buffers[domain].query(
                outputs[f'fake_{domain}'])
            # detach: no gradient flows back into the generator here
            fake_pred = discriminators[domain](fake_img.detach())
            losses[f'loss_gan_d_{domain}_fake'] = self.gan_loss(
                fake_pred, target_is_real=False, is_disc=True)
            real_pred = discriminators[domain](outputs[f'real_{domain}'])
            losses[f'loss_gan_d_{domain}_real'] = self.gan_loss(
                real_pred, target_is_real=True, is_disc=True)
            _loss_d, _log_vars_d = self._parse_losses(losses)
            # 0.5 factor slows the discriminator relative to the generator
            _loss_d *= 0.5
            loss_d += _loss_d
            log_vars_d[f'loss_gan_d_{domain}'] = _log_vars_d['loss'] * 0.5
        return loss_d, log_vars_d

    def _get_gen_loss(self, outputs):
        """Compute the generators' losses.

        Args:
            outputs (dict): Dict of forward results. Extended in place
                with ``identity_{domain}`` entries.

        Returns:
            tuple: Total generator loss and a log dict.
        """
        generators = self.get_module(self.generators)
        discriminators = self.get_module(self.discriminators)
        losses = dict()
        for domain in self._reachable_domains:
            # Identity reconstruction for generators: feed a real image of
            # the target domain through its own generator (consumed by the
            # identity auxiliary loss when configured).
            outputs[f'identity_{domain}'] = generators[domain](
                outputs[f'real_{domain}'])
            # GAN loss for generators
            fake_pred = discriminators[domain](outputs[f'fake_{domain}'])
            losses[f'loss_gan_g_{domain}'] = self.gan_loss(
                fake_pred, target_is_real=True, is_disc=False)
        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue
                # the `loss_name()` function return name as 'loss_xxx';
                # accumulate when several modules share the same loss name
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name(
                    )] = losses[loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_
        loss_g, log_vars_g = self._parse_losses(losses)
        return loss_g, log_vars_g

    def _get_opposite_domain(self, domain):
        """Return the other reachable domain, or None if there is none."""
        for item in self._reachable_domains:
            if item != domain:
                return item
        return None

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generators and discriminators.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration
        # forward generators
        outputs = dict()
        for target_domain in self._reachable_domains:
            # fetch data by domain
            source_domain = self.get_other_domains(target_domain)[0]
            img = data_batch[f'img_{source_domain}']
            # translation process
            results = self(img, test_mode=False, target_domain=target_domain)
            outputs[f'real_{source_domain}'] = results['source']
            outputs[f'fake_{target_domain}'] = results['target']
            # cycle process: map the fake image back to its source domain so
            # a cycle-consistency loss can compare it against the input
            results = self(
                results['target'],
                test_mode=False,
                target_domain=source_domain)
            outputs[f'cycle_{source_domain}'] = results['target']
        log_vars = dict()
        # discriminators
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # prepare for backward in ddp: without this the reducer cannot
        # dynamically find the parameters used in this computation
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()
        # generators, no updates to discriminator parameters. Generators are
        # updated only every ``disc_steps`` iterations and after the
        # ``disc_init_steps`` warm-up iterations.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()
        if hasattr(self, 'iteration'):
            self.iteration += 1
        image_results = dict()
        for domain in self._reachable_domains:
            image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu()
            image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()
        # NOTE(review): ``domain`` below is the variable leaked from the loop
        # above, so num_samples is taken from the last domain's batch.
        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{domain}']),
            results=image_results)
        return results
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.builder import MODELS
from ..common import set_requires_grad
from .static_translation_gan import StaticTranslationGAN
@MODELS.register_module()
class Pix2Pix(StaticTranslationGAN):
    """Pix2Pix model for paired image-to-image translation.

    Ref:
    Image-to-Image Translation with Conditional Adversarial Networks
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # no exponential moving average of the generator is used
        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results (``source`` and ``target`` on CPU).
        """
        # This is a trick for Pix2Pix: stay in train mode at test time,
        # matching the reference implementation.
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def _get_disc_loss(self, outputs):
        """Compute the conditional discriminator's loss.

        The discriminator receives the source image concatenated with
        either the generated or the real target image.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Discriminator loss and a log dict.
        """
        # GAN loss for the discriminator
        losses = dict()
        discriminators = self.get_module(self.discriminators)
        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        # conditional input: channel-wise concat of source and fake target
        fake_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'fake_{target_domain}']), 1)
        # detach: no gradient into the generator during the disc update
        fake_pred = discriminators[target_domain](fake_ab.detach())
        losses['loss_gan_d_fake'] = self.gan_loss(
            fake_pred, target_is_real=False, is_disc=True)
        real_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'real_{target_domain}']), 1)
        real_pred = discriminators[target_domain](real_ab)
        losses['loss_gan_d_real'] = self.gan_loss(
            real_pred, target_is_real=True, is_disc=True)
        loss_d, log_vars_d = self._parse_losses(losses)
        # 0.5 factor slows the discriminator relative to the generator
        loss_d *= 0.5
        return loss_d, log_vars_d

    def _get_gen_loss(self, outputs):
        """Compute the generator's loss.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Generator loss and a log dict.
        """
        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        losses = dict()
        discriminators = self.get_module(self.discriminators)
        # GAN loss for the generator (conditional fake pair, not detached)
        fake_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'fake_{target_domain}']), 1)
        fake_pred = discriminators[target_domain](fake_ab)
        losses['loss_gan_g'] = self.gan_loss(
            fake_pred, target_is_real=True, is_disc=False)
        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue
                # the `loss_name()` function return name as 'loss_xxx';
                # accumulate when several modules share the same loss name
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name(
                    )] = losses[loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_
        loss_g, log_vars_g = self._parse_losses(losses)
        return loss_g, log_vars_g

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generator and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # data
        target_domain = self._default_domain
        source_domain = self.get_other_domains(self._default_domain)[0]
        source_image = data_batch[f'img_{source_domain}']
        target_image = data_batch[f'img_{target_domain}']
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration
        # forward generator
        outputs = dict()
        results = self(
            source_image, target_domain=self._default_domain, test_mode=False)
        outputs[f'real_{source_domain}'] = results['source']
        outputs[f'fake_{target_domain}'] = results['target']
        outputs[f'real_{target_domain}'] = target_image
        log_vars = dict()
        # discriminator
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()
        # generator, no updates to discriminator parameters. The generator is
        # updated only every ``disc_steps`` iterations and after the
        # ``disc_init_steps`` warm-up iterations.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()
        if hasattr(self, 'iteration'):
            self.iteration += 1
        image_results = dict()
        image_results[f'real_{source_domain}'] = outputs[
            f'real_{source_domain}'].cpu()
        image_results[f'fake_{target_domain}'] = outputs[
            f'fake_{target_domain}'].cpu()
        image_results[f'real_{target_domain}'] = outputs[
            f'real_{target_domain}'].cpu()
        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{source_domain}']),
            results=image_results)
        return results
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel
from ..builder import MODELS, build_module
from ..gans import BaseGAN
from .base_translation_model import BaseTranslationModel
@MODELS.register_module()
class StaticTranslationGAN(BaseTranslationModel, BaseGAN):
    """Basic translation model based on static unconditional GAN.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator.
        gan_loss (dict): Config for the gan loss.
        pretrained (str | optional): Path for pretrained model.
            Defaults to None.
        disc_auxiliary_loss (dict | optional): Config for auxiliary loss to
            discriminator. Defaults to None.
        gen_auxiliary_loss (dict | optional): Config for auxiliary loss
            to generator. Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 *args,
                 pretrained=None,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 **kwargs):
        BaseGAN.__init__(self)
        BaseTranslationModel.__init__(self, *args, **kwargs)
        # Building generators and discriminators
        self._gen_cfg = deepcopy(generator)
        # build one generator per reachable domain
        self.generators = nn.ModuleDict()
        for domain in self._reachable_domains:
            self.generators[domain] = build_module(generator)
        self._disc_cfg = deepcopy(discriminator)
        # build domain discriminators
        if discriminator is not None:
            self.discriminators = nn.ModuleDict()
            for domain in self._reachable_domains:
                self.discriminators[domain] = build_module(discriminator)
        # support no discriminator in testing
        else:
            self.discriminators = None
        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # Fix: assign the plural attribute that
            # ``BaseGAN.with_disc_auxiliary_loss`` checks. The singular name
            # ``disc_auxiliary_loss`` was assigned here before, leaving
            # ``disc_auxiliary_losses`` unset in this branch (inconsistent
            # with the ``gen_auxiliary_losses`` branch below).
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
        """
        for domain in self._reachable_domains:
            self.generators[domain].init_weights(pretrained=pretrained)
            # Fix: testing-only configs may build the model without
            # discriminators (``discriminator=None``); indexing ``None``
            # here used to raise a TypeError.
            if self.discriminators is not None:
                self.discriminators[domain].init_weights(
                    pretrained=pretrained)

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        # ``self.train_cfg`` is guaranteed to be a dict at this point, so the
        # previous ``0 if self.train_cfg is None else ...`` guard was dead.
        self.disc_init_steps = self.train_cfg.get('disc_init_steps', 0)
        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()
        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

    def get_module(self, module):
        """Get `nn.ModuleDict` to fit the `MMDistributedDataParallel`
        interface.

        Args:
            module (MMDistributedDataParallel | nn.ModuleDict): The input
                module that needs processing.

        Returns:
            nn.ModuleDict: The ModuleDict of multiple networks.
        """
        if isinstance(module, MMDistributedDataParallel):
            return module.module
        return module

    def _get_target_generator(self, domain):
        """Return the generator network for ``domain``."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is\
            {self._reachable_domains}'
        return self.get_module(self.generators)[domain]

    def _get_target_discriminator(self, domain):
        """Return the discriminator network for ``domain``."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is\
            {self._reachable_domains}'
        return self.get_module(self.discriminators)[domain]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment