"launch/dynamo-run/src/lib.rs" did not exist on "110f3f8caeff051b32f168c44fda9faa0d71ed18"
Commit b7536f78 authored by limm's avatar limm
Browse files

add another part of the mmgeneration code

parent 57e0e891
Pipeline #2777 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors
from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN
# _SUPPORT_METHODS_ = ['DCGAN', 'STYLEGANv2']
# @MODELS.register_module(_SUPPORT_METHODS_)
@MODELS.register_module()
class StaticUnconditionalGAN(BaseGAN):
    """Unconditional GANs with static architecture in training.

    This is the standard GAN model containing standard adversarial training
    schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss, and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` aim at setting up
    the training schedule.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # BUGFIX: this branch previously set ``self.disc_auxiliary_loss``
            # (singular), so ``self.disc_auxiliary_losses`` -- the attribute
            # actually read by loss-computation code -- was never defined when
            # no discriminator auxiliary loss was configured. Keep the name
            # consistent with the generator branch below.
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        # train cfg is always parsed (it falls back to defaults); test cfg is
        # only parsed when explicitly provided.
        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional). Whether to use apex.amp. Defaults to
                ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch[self.real_img_key]
        # If you adopt ddp, this batch size is local batch size for each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling

        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='disc'))

        # generator is frozen here; no gradient needed for fake sample
        # generation in the discriminator phase.
        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, **g_training_kwargs)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['discriminator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['discriminator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='gen'))

        fake_imgs = self.generator(
            None, num_batches=batch_size, **g_training_kwargs)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        if loss_scaler:
            loss_scaler.scale(loss_gen).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_gen, optimizer['generator'],
                    loss_id=1) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_gen.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['generator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['generator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['generator'].step()

        # update adaptive discriminator augmentation (ADA) probability based
        # on the sign of the real-image predictions from the disc phase.
        if hasattr(self.discriminator,
                   'with_ada') and self.discriminator.with_ada:
            self.discriminator.ada_aug.log_buffer[0] += batch_size
            self.discriminator.ada_aug.log_buffer[1] += disc_pred_real.sign(
            ).sum()
            self.discriminator.ada_aug.update(
                iteration=curr_iter, num_batches=batch_size)
            log_vars_disc['augment'] = (
                self.discriminator.ada_aug.aug_pipeline.p.data.cpu())

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from .ddpm_loss import DDPMVLBLoss
from .disc_auxiliary_loss import (DiscShiftLoss, GradientPenaltyLoss,
R1GradientPenalty, disc_shift_loss,
gradient_penalty_loss,
r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
GeneratorPathRegularizer, PerceptualLoss,
gen_path_regularizer)
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
GaussianKLDLoss, L1Loss, MSELoss,
discretized_gaussian_log_likelihood, gaussian_kld)
# Public API of the losses subpackage; keep in sync with the imports above.
__all__ = [
    'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
    'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
    'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
    'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
    'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
    'CLIPLoss', 'PerceptualLoss'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from copy import deepcopy
from functools import partial
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import digit_version
from mmgen.models.builder import MODULES
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
GaussianKLDLoss, _reduction_modes, mse_loss)
from .utils import reduce_loss
class DDPMLoss(nn.Module):
    """Base module for DDPM losses. We support loss weight rescale and log
    collection for DDPM models in this module.

    We support two kinds of loss rescale methods, which can be
    controlled by ``rescale_mode`` and ``rescale_cfg``:

    1. ``rescale_mode == 'constant'``: ``constant_rescale`` would be called,
       and ``rescale_cfg`` should be passed as ``dict(scale=SCALE)``,
       e.g., ``dict(scale=1.2)``. Then, all loss terms would be rescaled by
       multiply with ``SCALE``
    2. ``rescale_mode == timestep_weight``: ``timestep_weight_rescale`` would
       be called, and ``weight`` or ``sampler`` who contains attribute of
       weight must be passed. Then, loss at timestep `t` would be multiplied
       with `weight[t]`. We also support users further apply a constant
       rescale factor to all loss terms, e.g.
       ``rescale_cfg=dict(scale=SCALE)``. The overall rescale function for
       loss at timestep ``t`` can be formulated as
       `loss[t] := weight[t] * loss[t] * SCALE`. To be noted that, ``weight``
       or ``sampler.weight`` would be inplace modified in the outer code.
       e.g.,

       .. code-block:: python
           :linenos:

           # 1. define weight
           weight = torch.randn(10, )

           # 2. define loss function
           loss_fn = DDPMLoss(rescale_mode='timestep_weight', weight=weight)

           # 3 update weight
           # wrong usage: `weight` in `loss_fn` is not accessible from now
           # because you assign a new tensor to variable `weight`
           # weight = torch.randn(10, )

           # correct usage: update `weight` inplace
           weight[2] = 2

    If ``rescale_mode`` is not passed, ``rescale_cfg`` would be ignored, and
    all loss terms would not be rescaled.

    For loss log collection, we support users to pass a list of (or single)
    config by ``log_cfgs`` argument to define how they want to collect loss
    terms and show them in the log. Each log collection returns a dict which
    key and value are the name and value of collected loss terms. And the dict
    will be merged into ``log_vars`` after the loss used for parameter
    optimization is calculated. The log updating process for the class which
    uses ddpm_loss can be referred to the following pseudo-code:

    .. code-block:: python
        :linenos:

        # 1. loss dict for parameter optimization
        losses_dict = {}

        # 2. calculate losses
        for loss_fn in self.ddpm_loss:
            losses_dict[loss_fn.loss_name()] = loss_fn(outputs_dict)

        # 3. init log_vars
        log_vars = OrderedDict()

        # 4. update log_vars with loss terms used for parameter optimization
        for loss_name, loss_value in losses_dict.items():
            log_vars[loss_name] = loss_value.mean()

        # 5. sum all loss terms used for backward
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        # 6. update log_var with log collection functions
        for loss_fn in self.ddpm_loss:
            if hasattr(loss_fn, 'log_vars'):
                log_vars.update(loss_fn.log_vars)

    Each log configs must contain ``type`` keyword, and may contain ``prefix``
    and ``reduction`` keywords.

    ``type``: Use to get the corresponding collection function. Functions would
        be named as ``f'{type}_log_collect'``. In `DDPMLoss`, we only support
        ``type=='quartile'``, but users may define their log collection
        functions and use them in this way.
    ``prefix``: This keyword is set for avoiding the name of displayed loss
        terms being too long. The name of each loss term will set as
        ``'{prefix}_{log_coll_fn_spec_name}'``, where
        ``{log_coll_fn_spec_name}`` is name specific to the log collection
        function. If passed, it must start with ``'loss_'``. If not passed,
        ``'loss_'`` would be used.
    ``reduction``: Control the reduction method of the collected loss terms.

    We implement ``quartile_log_collection`` in this module. In detail, we
    divide total timesteps into four parts and collect the loss in the
    corresponding timestep intervals.

    To use those collection methods, users may pass ``log_cfgs`` as the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='quartile', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to None.
    """

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 log_cfgs=None,
                 weight=None,
                 sampler=None,
                 reduction='mean',
                 loss_name=None):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.reduction = reduction
        self._loss_name = loss_name

        # build the list of log-collection callables from log_cfgs
        self.log_fn_list = []
        log_cfgs_ = deepcopy(log_cfgs)
        if log_cfgs_ is not None:
            if not isinstance(log_cfgs_, list):
                log_cfgs_ = [log_cfgs_]
            assert mmcv.is_list_of(log_cfgs_, dict)
            for log_cfg_ in log_cfgs_:
                log_type = log_cfg_.pop('type')
                log_collect_fn = f'{log_type}_log_collect'
                assert hasattr(self, log_collect_fn)
                log_collect_fn = getattr(self, log_collect_fn)

                log_cfg_.setdefault('prefix_name', 'loss')
                assert log_cfg_['prefix_name'].startswith('loss')
                log_cfg_.setdefault('reduction', reduction)

                self.log_fn_list.append(partial(log_collect_fn, **log_cfg_))
        self.log_vars = dict()

        # handle rescale mode
        if not rescale_mode:
            # no rescaling: identity function over (loss, timesteps)
            self.rescale_fn = lambda loss, t: loss
        else:
            rescale_fn_name = f'{rescale_mode}_rescale'
            assert hasattr(self, rescale_fn_name)
            if rescale_mode == 'timestep_weight':
                # prefer the sampler's weight tensor when available so that
                # external in-place updates to it are visible here
                if sampler is not None and hasattr(sampler, 'weight'):
                    weight = sampler.weight
                else:
                    assert weight is not None and isinstance(
                        weight, torch.Tensor), (
                            '\'weight\' or a \'sampler\' with a weight '
                            'attribute must be \'torch.Tensor\' for '
                            '\'timestep_weight\' rescale_mode.')
                mmcv.print_log(
                    'Apply \'timestep_weight\' rescale_mode for '
                    f'{self._loss_name}. Please make sure the passed weight '
                    'can be updated by external functions.', 'mmgen')

                rescale_cfg = dict(weight=weight)
            self.rescale_fn = partial(
                getattr(self, rescale_fn_name), **rescale_cfg)

    @staticmethod
    def constant_rescale(loss, timesteps, scale):
        """Rescale losses at all timesteps with a constant factor.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss items.
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * scale

    @staticmethod
    def timestep_weight_rescale(loss, timesteps, weight, scale=1):
        """Rescale losses corresponding to timestep.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss items.
            weight (torch.Tensor): Weight corresponding to each timestep.
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * weight[timesteps] * scale

    @torch.no_grad()
    def collect_log(self, loss, timesteps):
        """Collect logs.

        Args:
            loss (torch.Tensor): Losses to collect.
            timesteps (torch.Tensor): Timesteps of each loss items.
        """
        if not self.log_fn_list:
            return

        # in distributed training, gather per-sample losses/timesteps from
        # all ranks so rank 0 can log over the global batch
        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder_l = [torch.zeros_like(loss) for _ in range(ws)]
            placeholder_t = [torch.zeros_like(timesteps) for _ in range(ws)]
            dist.all_gather(placeholder_l, loss)
            dist.all_gather(placeholder_t, timesteps)
            loss = torch.cat(placeholder_l, dim=0)
            timesteps = torch.cat(placeholder_t, dim=0)
        log_vars = dict()

        # only rank 0 (or the single process) computes the log values
        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            for log_fn in self.log_fn_list:
                log_vars.update(log_fn(loss, timesteps))
        self.log_vars = log_vars

    @torch.no_grad()
    def quartile_log_collect(self,
                             loss,
                             timesteps,
                             total_timesteps,
                             prefix_name,
                             reduction='mean'):
        """Collect loss logs by quartile timesteps.

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be shape as [bz, ]
            timesteps (torch.Tensor): Timesteps corresponding to each loss.
                Each loss tensor should be shape as [bz, ].
            total_timesteps (int): Total timesteps of diffusion process.
            prefix_name (str): Prefix want to show in logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        if digit_version(torch.__version__) <= digit_version('1.6.0'):
            # use true_divide in older torch version
            quartile = torch.true_divide(timesteps, total_timesteps) * 4
        else:
            quartile = (timesteps / total_timesteps * 4)
        quartile = quartile.type(torch.LongTensor)

        log_vars = dict()

        for idx in range(4):
            if not (quartile == idx).any():
                # no sample falls in this quartile; log a zero placeholder
                loss_quartile = torch.zeros((1, ))
            else:
                loss_quartile = reduce_loss(loss[quartile == idx], reduction)
            log_vars[f'{prefix_name}_quartile_{idx}'] = loss_quartile.item()

        return log_vars

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules should
        be passed into this function. If this dictionary is given as a
        non-keyword argument, it should be offered as the first argument. If
        you are using keyword argument, please name it as `outputs_dict`.
        The parsed dictionary is forwarded to ``self._forward_loss`` for the
        actual loss calculation.
        """
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            output_dict = args[0]
        elif 'outputs_dict' in kwargs:
            # BUGFIX: this branch previously tested for the key
            # 'output_dict' but popped 'outputs_dict', which raised KeyError
            # on the documented keyword usage. The keyword is 'outputs_dict',
            # consistent with the other loss modules in this package.
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            output_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # check keys in output_dict
        assert 'timesteps' in output_dict, (
            '\'timesteps\' is required for DDPM-based losses, but found'
            f'{output_dict.keys()} in \'output_dict\'')

        timesteps = output_dict['timesteps']
        loss = self._forward_loss(output_dict)

        # update log_vars of this class
        self.collect_log(loss, timesteps=timesteps)

        loss_rescaled = self.rescale_fn(loss, timesteps)
        return reduce_loss(loss_rescaled, self.reduction)

    @abstractmethod
    def _forward_loss(self, output_dict):
        """Forward function for loss calculation. This method should be
        implemented by each subclasses.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        raise NotImplementedError(
            '\'self._forward_loss\' must be implemented.')

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class DDPMVLBLoss(DDPMLoss):
    """Variational lower-bound loss for DDPM-based models.

    In this loss, we calculate VLB of different timesteps with different
    method. In detail, ``DiscretizedGaussianLogLikelihoodLoss`` is used at
    timesteps = 0 and ``GaussianKLDLoss`` at other timesteps.

    To control the data flow for loss calculation, users should define
    ``data_info`` and ``data_info_t_0`` for ``GaussianKLDLoss`` and
    ``DiscretizedGaussianLogLikelihoodLoss`` respectively. If not passed,
    ``_default_data_info`` and ``_default_data_info_t_0`` would be used.
    To be noted that, we only penalize 'variance' in this loss term, and
    tensors in output dict corresponding to 'mean' would be detached.

    Additionally, we support another log collection function called
    ``name_log_collection``. In this collection method, we would directly
    collect loss terms calculated by different methods.

    To use this collection method, users may pass ``log_cfgs`` as the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='name', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary for ``timesteps != 0``.
            Defaults to None.
        data_info_t_0 (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary for ``timesteps == 0``.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_vlb'.
    """

    # default arg-name -> output-dict-key mapping for the KLD term (t != 0)
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')
    # default mapping for the discretized log-likelihood term (t == 0)
    _default_data_info_t_0 = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 data_info=None,
                 data_info_t_0=None,
                 log_cfgs=None,
                 reduction='mean',
                 loss_name='loss_ddpm_vlb'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)
        self.data_info = self._default_data_info \
            if data_info is None else data_info
        self.data_info_t_0 = self._default_data_info_t_0 \
            if data_info_t_0 is None else data_info_t_0

        # ``loss_list[i]`` is applied to exactly the timesteps selected by
        # ``loss_select_fn_list[i]`` below; the two lists are paired by index
        # and must stay in the same order.
        # NOTE: loss_weight=-1 turns the log-likelihood into a (negative
        # log-likelihood) loss; base='2' computes both terms in bits.
        self.loss_list = [
            DiscretizedGaussianLogLikelihoodLoss(
                reduction='flatmean',
                data_info=self.data_info_t_0,
                base='2',
                loss_weight=-1,
                only_update_var=True),
            GaussianKLDLoss(
                reduction='flatmean',
                data_info=self.data_info,
                base='2',
                only_update_var=True)
        ]
        # timestep selectors: first entry picks t == 0, second picks t != 0
        self.loss_select_fn_list = [lambda t: t == 0, lambda t: t != 0]

    @torch.no_grad()
    def name_log_collect(self, loss, timesteps, prefix_name, reduction='mean'):
        """Collect loss logs by name (GaussianKLD and
        DiscGaussianLogLikelihood).

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be in the shape of [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each losses.
                Each loss tensor should be in the shape of [bz, ].
            prefix_name (str): Prefix want to show in logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        log_vars = dict()
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            if not mask.any():
                # no sample handled by this loss term; log a zero placeholder
                loss_reduced = torch.zeros((1, ))
            else:
                loss_reduced = reduce_loss(loss[mask], reduction)
            # remove original prefix in loss names
            loss_term_name = loss_fn.loss_name().replace('loss_', '')
            log_vars[f'{prefix_name}_{loss_term_name}'] = loss_reduced.item()

        return log_vars

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Dispatches each sample to the loss function matching its timestep
        (log-likelihood at t == 0, KLD otherwise) and scatters the per-sample
        results back into a single [bz, ] tensor.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        # use `zeros` instead of `zeros_like` to avoid get int tensor
        timesteps = outputs_dict['timesteps']
        loss = torch.zeros_like(timesteps).float()
        # loss = torch.zeros(*timesteps.shape).to(timesteps.device)
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            # build a sub-dictionary containing only the masked samples;
            # tensors are indexed by mask, lists filtered element-wise, and
            # everything else passed through unchanged
            outputs_dict_ = {}
            for k, v in outputs_dict.items():
                if v is None or not isinstance(v, (torch.Tensor, list)):
                    outputs_dict_[k] = v
                elif isinstance(v, list):
                    outputs_dict_[k] = [
                        v[idx] for idx, m in enumerate(mask) if m
                    ]
                else:
                    outputs_dict_[k] = v[mask]
            # scatter the per-term losses back into the masked positions
            loss[mask] = loss_fn(outputs_dict_)
        return loss
@MODULES.register_module()
class DDPMMSELoss(DDPMLoss):
    """Mean square error loss for DDPM-based models.

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used for rescale losses.
            Defaults to None.
        log_cfgs (list[dict] | dict | optional): Configs to collect logs.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. Defaults to None.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_mse'.
    """

    # default arg-name -> output-dict-key mapping for the MSE term
    _default_data_info = dict(pred='eps_t_pred', target='noise')

    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 log_cfgs=None,
                 reduction='mean',
                 data_info=None,
                 loss_name='loss_ddpm_mse'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)
        if data_info is None:
            data_info = self._default_data_info
        self.data_info = data_info
        # per-sample (flattened-mean) MSE; reduction to a scalar happens in
        # the base class after rescaling
        self.loss_fn = partial(mse_loss, reduction='flatmean')

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        # map loss-function argument names to entries of the outputs dict
        loss_kwargs = {
            arg_name: outputs_dict[data_key]
            for arg_name, data_key in self.data_info.items()
        }
        return self.loss_fn(**loss_kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.nn as nn
from mmgen.models.builder import MODULES
from .utils import weighted_loss
@weighted_loss
def disc_shift_loss(pred):
    """Compute the disc-shift penalty (squared prediction).

    Proposed in PGGAN as an auxiliary loss keeping the discriminator
    output close to zero.

    Args:
        pred (Tensor): Discriminator prediction tensor.

    Returns:
        torch.Tensor: Element-wise squared predictions.
    """
    return pred * pred
@MODULES.register_module()
class DiscShiftLoss(nn.Module):
    """Disc Shift Loss module.

    Proposed in PGGAN as an auxiliary loss for the discriminator.

    **Note for the design of ``data_info``:**

    As with most loss modules in ``MMGeneration``, the ``data_info`` argument
    defines a mapping between the arguments needed by the loss function and
    the keys of the data dictionary collected by the model. During GAN
    training, the model gathers items such as:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    This loss needs a ``pred`` input, so a typical mapping is:

    .. code-block:: python
        :linenos:

        data_info = dict(pred='disc_pred_fake')

    The module then extracts ``pred`` from the data dictionary automatically.
    To apply the shift loss to both real and fake predictions, register this
    module twice with different ``data_info`` mappings; the model sums the
    resulting items.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_disc_shift'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 loss_name='loss_disc_shift'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function, ``disc_shift_loss``.
        """
        if self.data_info is None:
            # No mapping configured: forward the caller's inputs untouched,
            # injecting only the configured loss weight.
            return disc_shift_loss(*args, weight=self.loss_weight, **kwargs)

        # A mapping is configured: locate the outputs dictionary first.
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # Translate outputs-dict entries into loss-function arguments
        # according to self.data_info, then attach the loss weight.
        kwargs.update({
            arg_name: outputs_dict[data_key]
            for arg_name, data_key in self.data_info.items()
        })
        kwargs.update(dict(weight=self.loss_weight))
        return disc_shift_loss(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@weighted_loss
def gradient_penalty_loss(discriminator,
                          real_data,
                          fake_data,
                          mask=None,
                          norm_mode='pixel'):
    """Calculate gradient penalty for wgan-gp.

    In the detailed implementation, there are two streams where one uses the
    pixel-wise gradient norm, but the other adopts normalization along
    instance (HWC) dimensions. Thus, ``norm_mode`` are offered to define
    which mode you want.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # per-sample interpolation coefficient broadcast over the remaining
    # dims; NOTE(review): assumes 4D (N, C, H, W) inputs — confirm callers
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)
    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # ``autograd.Variable`` has been deprecated since PyTorch 0.4;
    # detaching and re-enabling grad reproduces its behavior: a fresh
    # graph leaf that requires grad.
    interpolates = interpolates.detach().requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]
    if mask is not None:
        gradients = gradients * mask
    if norm_mode == 'pixel':
        # norm over the channel dim, per spatial location
        gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    elif norm_mode == 'HWC':
        # norm over all non-batch dims, per sample
        gradients_penalty = ((
            gradients.reshape(batch_size, -1).norm(2, dim=1) - 1)**2).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')
    # normalize the penalty by the valid (unmasked) ratio
    if mask is not None:
        gradients_penalty /= torch.mean(mask)

    return gradients_penalty
@MODULES.register_module()
class GradientPenaltyLoss(nn.Module):
    """Gradient Penalty for WGAN-GP.

    Two normalization streams are supported when computing the gradient
    norm: a pixel-wise norm and a norm over the whole instance (HWC)
    dimensions. Use ``norm_mode`` to select the desired one.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of GAN model, we will
    collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``discriminator``,
    ``real_data``, and ``fake_data`` as input. Thus, an example of the
    ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs',
            fake_data='fake_imgs')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        norm_mode (str): This argument decides along which dimension the
            norm of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 data_info=None,
                 loss_name='loss_gp'):
        super().__init__()
        self._loss_name = loss_name
        self.data_info = data_info
        self.norm_mode = norm_mode
        self.loss_weight = loss_weight

    @staticmethod
    def _parse_outputs_dict(args, kwargs):
        """Extract the outputs dictionary from ``args``/``kwargs``."""
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            return args[0]
        if 'outputs_dict' in kwargs:
            assert not args, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            return kwargs.pop('outputs_dict')
        raise NotImplementedError(
            'Cannot parsing your arguments passed to this loss module.'
            ' Please check the usage of this module')

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed into this
        function. If this dictionary is given as a non-keyword argument, it
        should be offered as the first argument. If you are using keyword
        argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``gradient_penalty_loss``.
        """
        if self.data_info is None:
            # no mapping configured: forward the inputs untouched
            return gradient_penalty_loss(
                *args, weight=self.loss_weight, **kwargs)

        # build the computational path according to ``data_info``
        outputs_dict = self._parse_outputs_dict(args, kwargs)
        kwargs.update(
            {arg: outputs_dict[src] for arg, src in self.data_info.items()})
        kwargs['weight'] = self.loss_weight
        kwargs['norm_mode'] = self.norm_mode
        return gradient_penalty_loss(**kwargs)

    def loss_name(self):
        """Return the name of this loss item.

        The returned name is used when combining different loss items by
        simple sum operation. Note that only names carrying the `loss_`
        prefix are included into the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@weighted_loss
def r1_gradient_penalty_loss(discriminator,
                             real_data,
                             mask=None,
                             norm_mode='pixel',
                             loss_scaler=None,
                             use_apex_amp=False):
    """Calculate R1 gradient penalty for WGAN-GP.

    R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from original gradient penalty, this regularizer only penalized
    gradient w.r.t. real data.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support ["pixel"
            , "HWC"]. Defaults to "pixel".
        loss_scaler (object | None, optional): Loss scaler for mixed-precision
            training; presumably a ``torch.cuda.amp.GradScaler`` (it must
            provide ``scale()`` and ``get_scale()``) — TODO confirm. The
            discriminator prediction is scaled before the backward pass and
            the resulting gradients are unscaled afterwards. Defaults to None.
        use_apex_amp (bool, optional): Whether ``apex.amp`` mixed precision is
            in use; only consulted when ``loss_scaler`` is falsy.
            Defaults to False.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.shape[0]
    # clone so the caller's tensor stays untouched; ``requires_grad_`` makes
    # the clone a graph leaf we can differentiate w.r.t.
    real_data = real_data.clone().requires_grad_()
    disc_pred = discriminator(real_data)
    if loss_scaler:
        # scale the prediction so the computed gradients are scaled as well
        disc_pred = loss_scaler.scale(disc_pred)
    elif use_apex_amp:
        # loss_scalers[0] is used here; presumably reserved for the
        # discriminator by convention — TODO confirm against trainer setup
        from apex.amp._amp_state import _amp_state
        _loss_scaler = _amp_state.loss_scalers[0]
        disc_pred = _loss_scaler.loss_scale() * disc_pred.float()
    # d(disc_pred)/d(real_data), kept in the graph so the penalty itself
    # can be backpropagated through
    gradients = autograd.grad(
        outputs=disc_pred,
        inputs=real_data,
        grad_outputs=torch.ones_like(disc_pred),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]
    if loss_scaler:
        # unscale the gradient
        inv_scale = 1. / loss_scaler.get_scale()
        gradients = gradients * inv_scale
    elif use_apex_amp:
        # undo the apex loss scaling applied above
        inv_scale = 1. / _loss_scaler.loss_scale()
        gradients = gradients * inv_scale
    if mask is not None:
        # restrict the penalty to the valid (unmasked) region
        gradients = gradients * mask
    if norm_mode == 'pixel':
        # squared norm over the channel dim, per spatial location
        gradients_penalty = ((gradients.norm(2, dim=1))**2).mean()
    elif norm_mode == 'HWC':
        # squared norm over all non-batch dims, per sample
        gradients_penalty = gradients.pow(2).reshape(batch_size,
                                                     -1).sum(1).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')
    # normalize the penalty by the valid (unmasked) ratio
    if mask is not None:
        gradients_penalty /= torch.mean(mask)
    return gradients_penalty
@MODULES.register_module()
class R1GradientPenalty(nn.Module):
    """R1 gradient penalty for WGAN-GP.

    R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from original gradient penalty, this regularizer only
    penalized gradient w.r.t. real data.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of GAN model, we will
    collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``discriminator`` and
    ``real_data`` as input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        norm_mode (str): This argument decides along which dimension the
            norm of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        interval (int, optional): The interval of calculating this loss
            (lazy regularization). Defaults to 1.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        use_apex_amp (bool, optional): Whether ``apex.amp`` mixed precision
            is used. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_r1_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_r1_gp'):
        super().__init__()
        self._loss_name = loss_name
        self.use_apex_amp = use_apex_amp
        self.data_info = data_info
        self.interval = interval
        self.norm_mode = norm_mode
        self.loss_weight = loss_weight

    @staticmethod
    def _parse_outputs_dict(args, kwargs):
        """Extract the outputs dictionary from ``args``/``kwargs``."""
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            return args[0]
        if 'outputs_dict' in kwargs:
            assert not args, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            return kwargs.pop('outputs_dict')
        raise NotImplementedError(
            'Cannot parsing your arguments passed to this loss module.'
            ' Please check the usage of this module')

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed into this
        function. If this dictionary is given as a non-keyword argument, it
        should be offered as the first argument. If you are using keyword
        argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``r1_gradient_penalty_loss``.
        """
        if self.interval > 1:
            # lazy regularization needs the iteration counter, which only
            # arrives through the data dictionary
            assert self.data_info is not None
        if self.data_info is None:
            return r1_gradient_penalty_loss(
                *args,
                weight=self.loss_weight,
                norm_mode=self.norm_mode,
                **kwargs)

        outputs_dict = self._parse_outputs_dict(args, kwargs)
        if self.interval > 1 and outputs_dict[
                'iteration'] % self.interval != 0:
            # skip iterations that are off the lazy-regularization grid
            return None
        kwargs.update(
            {arg: outputs_dict[src] for arg, src in self.data_info.items()})
        kwargs.update(
            dict(
                weight=self.loss_weight,
                norm_mode=self.norm_mode,
                use_apex_amp=self.use_apex_amp))
        return r1_gradient_penalty_loss(**kwargs)

    def loss_name(self):
        """Return the name of this loss item.

        The returned name is used when combining different loss items by
        simple sum operation. Note that only names carrying the `loss_`
        prefix are included into the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from ..builder import MODULES
@MODULES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss_weight = loss_weight
        # gan types whose criterion is a standard nn module
        module_criteria = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'hinge': nn.ReLU,
        }
        if gan_type in module_criteria:
            self.loss = module_criteria[gan_type]()
        elif gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        if target:
            return -input.mean()
        return input.mean()

    def _wgan_logistic_ns_loss(self, input, target):
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        if target:
            return F.softplus(-input).mean()
        return F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, \
                otherwise, return Tensor.
        """
        if self.gan_type in ('wgan', 'wgan-logistic-ns'):
            # the wgan criteria consume the raw boolean label
            return target_is_real
        value = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_full(input.size(), value)

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the targe is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type != 'hinge':
            loss = self.loss(input, target_label)
        elif is_disc:
            # for discriminators in hinge-gan
            signed = -input if target_is_real else input
            loss = self.loss(1 + signed).mean()
        else:
            # for generators in hinge-gan
            loss = -input.mean()
        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.autograd as autograd
import torch.distributed as dist
import torch.nn as nn
import torchvision.models.vgg as vgg
from mmcv.runner import load_checkpoint
from mmgen.models.builder import MODULES, build_module
from mmgen.utils import get_root_logger
from .pixelwise_loss import l1_loss, mse_loss
def gen_path_regularizer(generator,
                         num_batches,
                         mean_path_length,
                         pl_batch_shrink=1,
                         decay=0.01,
                         weight=1.,
                         pl_batch_size=None,
                         sync_mean_buffer=False,
                         loss_scaler=None,
                         use_apex_amp=False):
    """Generator Path Regularization.

    Path regularization is proposed in StyleGAN2, which can help improve
    the continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Args:
        generator (nn.Module): The generator module. Note that this loss
            requires that the generator contains ``return_latents`` interface,
            with which we can get the latent code of the current sample.
        num_batches (int): The number of samples used in calculating this loss.
        mean_path_length (Tensor): The mean path length, calculated by moving
            average.
        pl_batch_shrink (int, optional): The factor of shrinking the batch size
            for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path length.
            Defaults to 0.01.
        weight (float, optional): Weight of this loss item. Defaults to ``1.``.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, the ``num_batches`` will
            be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all of GPUs. Defaults to False.
        loss_scaler (object | None, optional): Loss scaler for mixed-precision
            training; presumably a ``torch.cuda.amp.GradScaler`` (it must
            provide ``scale()`` and ``get_scale()``) — TODO confirm. The
            intermediate loss is scaled before ``autograd.grad`` and the
            resulting gradients are unscaled afterwards. Defaults to None.
        use_apex_amp (bool, optional): Whether ``apex.amp`` mixed precision is
            in use; only consulted when ``loss_scaler`` is falsy.
            Defaults to False.

    Returns:
        tuple[Tensor]: The penalty loss, detached mean path tensor, and \
            current path length.
    """
    # reduce batch size for conserving GPU memory
    if pl_batch_shrink > 1:
        num_batches = max(1, num_batches // pl_batch_shrink)
    # reset the batch size if pl_batch_size is not None
    if pl_batch_size is not None:
        num_batches = pl_batch_size
    # get output from different generators
    output_dict = generator(None, num_batches=num_batches, return_latents=True)
    fake_img, latents = output_dict['fake_img'], output_dict['latent']
    # random direction scaled by 1/sqrt(H*W) so the expected gradient
    # magnitude does not depend on the image resolution
    noise = torch.randn_like(fake_img) / np.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    if loss_scaler:
        loss = loss_scaler.scale((fake_img * noise).sum())[0]
        # d(loss)/d(latents), kept in the graph so the penalty itself can
        # be backpropagated through
        grad = autograd.grad(
            outputs=loss,
            inputs=latents,
            grad_outputs=torch.ones(()).to(loss),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # unscale the grad
        inv_scale = 1. / loss_scaler.get_scale()
        grad = grad * inv_scale
    elif use_apex_amp:
        from apex.amp._amp_state import _amp_state
        # by default, we use loss_scalers[0] for discriminator and
        # loss_scalers[1] for generator
        _loss_scaler = _amp_state.loss_scalers[1]
        loss = _loss_scaler.loss_scale() * ((fake_img * noise).sum()).float()
        grad = autograd.grad(
            outputs=loss,
            inputs=latents,
            grad_outputs=torch.ones(()).to(loss),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        # unscale the grad
        inv_scale = 1. / _loss_scaler.loss_scale()
        grad = grad * inv_scale
    else:
        grad = autograd.grad(
            outputs=(fake_img * noise).sum(),
            inputs=latents,
            grad_outputs=torch.ones(()).to(fake_img),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
    # per-sample path length: L2 norm of the Jacobian-vector product,
    # averaged over the latent index (dim 1)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # update mean path
    path_mean = mean_path_length + decay * (
        path_lengths.mean() - mean_path_length)
    if sync_mean_buffer and dist.is_initialized():
        # average the moving mean across all ranks
        dist.all_reduce(path_mean)
        path_mean = path_mean / float(dist.get_world_size())
    path_penalty = (path_lengths - path_mean).pow(2).mean() * weight
    return path_penalty, path_mean.detach(), path_lengths
@MODULES.register_module()
class GeneratorPathRegularizer(nn.Module):
    """Generator Path Regularizer.

    Path regularization is proposed in StyleGAN2, which can help improve
    the continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Users can achieve lazy regularization by setting ``interval`` arguments
    here.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the
    generative model. For example, in the training of GAN model, we will
    collect all of important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``generator`` and
    ``num_batches`` as input. Thus an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            generator='gen',
            num_batches='batch_size')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        pl_batch_shrink (int, optional): The factor of shrinking the batch
            size for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path
            length. Defaults to 0.01.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, the ``num_batches``
            will be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all of GPUs. Defaults to False.
        interval (int, optional): The interval of calculating this loss.
            This argument is used to support lazy regularization.
            Defaults to 1.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        use_apex_amp (bool, optional): Whether ``apex.amp`` mixed precision
            is used. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_path_regular'.
    """

    def __init__(self,
                 loss_weight=1.,
                 pl_batch_shrink=1,
                 decay=0.01,
                 pl_batch_size=None,
                 sync_mean_buffer=False,
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_path_regular'):
        super().__init__()
        self._loss_name = loss_name
        self.use_apex_amp = use_apex_amp
        self.data_info = data_info
        self.interval = interval
        self.sync_mean_buffer = sync_mean_buffer
        self.pl_batch_size = pl_batch_size
        self.decay = decay
        self.pl_batch_shrink = pl_batch_shrink
        self.loss_weight = loss_weight
        # moving average of the path length, persisted with the module
        self.register_buffer('mean_path_length', torch.tensor(0.))

    @staticmethod
    def _parse_outputs_dict(args, kwargs):
        """Extract the outputs dictionary from ``args``/``kwargs``."""
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            return args[0]
        if 'outputs_dict' in kwargs:
            assert not args, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            return kwargs.pop('outputs_dict')
        raise NotImplementedError(
            'Cannot parsing your arguments passed to this loss module.'
            ' Please check the usage of this module')

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all
        of the data and necessary modules should be passed into this
        function. If this dictionary is given as a non-keyword argument, it
        should be offered as the first argument. If you are using keyword
        argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``gen_path_regularizer``.
        """
        if self.interval > 1:
            # lazy regularization needs the iteration counter, which only
            # arrives through the data dictionary
            assert self.data_info is not None
        if self.data_info is None:
            # note: this path returns the full (penalty, mean, lengths)
            # tuple from ``gen_path_regularizer``
            return gen_path_regularizer(
                *args, weight=self.loss_weight, **kwargs)

        outputs_dict = self._parse_outputs_dict(args, kwargs)
        if self.interval > 1 and outputs_dict[
                'iteration'] % self.interval != 0:
            # skip iterations that are off the lazy-regularization grid
            return None
        kwargs.update(
            {arg: outputs_dict[src] for arg, src in self.data_info.items()})
        kwargs.update(
            dict(
                weight=self.loss_weight,
                mean_path_length=self.mean_path_length,
                pl_batch_shrink=self.pl_batch_shrink,
                decay=self.decay,
                use_apex_amp=self.use_apex_amp,
                pl_batch_size=self.pl_batch_size,
                sync_mean_buffer=self.sync_mean_buffer))
        # the detached mean is written back into the registered buffer
        path_penalty, self.mean_path_length, _ = gen_path_regularizer(
            **kwargs)
        return path_penalty

    def loss_name(self):
        """Return the name of this loss item.

        The returned name is used when combining different loss items by
        simple sum operation. Note that only names carrying the `loss_`
        prefix are included into the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
def third_party_net_loss(net, weight=1.0, *args, **kwargs):
    """Wrap a third-party network (or any callable) as a weighted loss.

    Generalized to forward positional arguments as well: previously any
    extra positional argument collided with ``weight`` and raised a
    ``TypeError``. Existing keyword-only callers are unaffected.

    Args:
        net (callable): Network or callable computing the raw loss value.
        weight (float, optional): Multiplier applied to the output of
            ``net``. Defaults to 1.0.
        *args: Positional arguments forwarded to ``net``.
        **kwargs: Keyword arguments forwarded to ``net``.

    Returns:
        The output of ``net`` scaled by ``weight``.
    """
    return net(*args, **kwargs) * weight
@MODULES.register_module()
class FaceIdLoss(nn.Module):
    """Face similarity loss. Generally this loss is used to keep the id
    consistency of the input face image and output face image.

    In this loss, we may need to provide ``gt``, ``pred`` and ``x``. Thus,
    an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            gt='real_imgs',
            pred='fake_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        facenet (dict, optional): Config dict for facenet. Defaults to
            dict(type='ArcFace', ir_se50_weights=None, device='cuda').
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_id'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 facenet=dict(
                     type='ArcFace', ir_se50_weights=None, device='cuda'),
                 loss_name='loss_id'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        # the face-recognition backbone used to compare identities
        self.net = build_module(facenet)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            # NOTE: do not forward ``*args`` here — when the outputs dict was
            # given positionally it is still in ``args`` and would be bound to
            # the ``weight`` parameter of ``third_party_net_loss``, raising a
            # TypeError. All loss inputs are already in ``kwargs``.
            return third_party_net_loss(self.net, **kwargs)
        else:
            # if you have not define how to build computational graph, this
            # module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
class CLIPLossModel(torch.nn.Module):
    """Wrapped clip model to calculate clip loss.

    Ref: https://github.com/orpatashnik/StyleCLIP/blob/main/criteria/clip_loss.py # noqa

    Args:
        in_size (int, optional): Input image size. Defaults to 1024.
        scale_factor (int, optional): Unsampling factor. Defaults to 7.
        pool_size (int, optional): Pooling output size. Defaults to 224.
        clip_type (str, optional): A model name listed by
            `clip.available_models()`, or the path to a model checkpoint
            containing the state_dict. For more details, you can refer to
            https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/clip.py#L91 # noqa
            Defaults to 'ViT-B/32'.
        device (str, optional): Model device. Defaults to 'cuda'.
    """

    def __init__(self,
                 in_size=1024,
                 scale_factor=7,
                 pool_size=224,
                 clip_type='ViT-B/32',
                 device='cuda'):
        super().__init__()
        try:
            import clip
        except ImportError as err:
            # raising a plain string (as the old code did) is itself a
            # TypeError in Python 3; raise a proper chained exception
            raise ImportError(
                'To use clip loss, openai clip need to be installed '
                'first') from err
        self.model, self.preprocess = clip.load(clip_type, device=device)
        # upsample then average-pool the input down to the CLIP input size
        self.upsample = torch.nn.Upsample(scale_factor=scale_factor)
        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=(scale_factor * in_size // pool_size))

    def forward(self, image=None, text=None):
        """Forward function.

        Args:
            image (Tensor): Image tensor fed to the CLIP image encoder.
            text (Tensor): Tokenized text fed to the CLIP text encoder.

        Returns:
            Tensor: ``1 - similarity/100``, i.e. lower is more similar.
        """
        assert image is not None
        assert text is not None
        image = self.avg_pool(self.upsample(image))
        # CLIP returns logits scaled by 100; map them to a [~0, 1] loss
        loss = 1 - self.model(image, text)[0] / 100
        return loss
@MODULES.register_module()
class CLIPLoss(nn.Module):
    """Clip loss. In styleclip, this loss is used to optimize the latent code
    to generate image that match the text.

    In this loss, we may need to provide ``image``, ``text``. Thus,
    an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            image='fake_imgs',
            text='descriptions')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary contains the mapping between
            loss input args and data dictionary. If ``None``, this module will
            directly pass the input data to the loss function.
            Defaults to None.
        clip_model (dict, optional): Kwargs for clip loss model. Defaults to
            dict().
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_clip'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 clip_model=dict(),
                 loss_name='loss_clip'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self.net = CLIPLossModel(**clip_model)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            # NOTE: do not forward ``*args`` here — when the outputs dict was
            # given positionally it is still in ``args`` and would be bound to
            # the ``weight`` parameter of ``third_party_net_loss``, raising a
            # TypeError. All loss inputs are already in ``kwargs``.
            return third_party_net_loss(self.net, **kwargs)
        else:
            # if you have not define how to build computational graph, this
            # module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        # Return the configured name (default 'loss_clip') instead of the
        # previous hard-coded 'clip_loss', which ignored ``self._loss_name``
        # and lacked the 'loss_' prefix required for the backward graph.
        return self._loss_name
class PerceptualVGG(nn.Module):
    """VGG feature extractor used in calculating perceptual loss.

    Runs the first layers of a (pretrained) torchvision VGG network and
    returns the intermediate activations requested by name. Optionally the
    input is normalized first. Note that the pretrained path must match the
    requested vgg type.

    Args:
        layer_name_list (list[str]): Names of the layers (string indices
            into ``vgg.features``, e.g. ['4', '10']) whose activations are
            returned by :meth:`forward`.
        vgg_type (str): Which vgg variant to build. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image.
            Importantly, the input feature must then be in range [0, 1].
            Default: True.
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 pretrained='torchvision://vgg19'):
        super().__init__()
        if pretrained.startswith('torchvision://'):
            assert vgg_type in pretrained
        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm

        # Build the full backbone only as a local variable so that layers we
        # do not keep are never registered as attributes (avoids the
        # `find_unused_parameters` problem).
        _vgg = getattr(vgg, vgg_type)()
        self.init_weights(_vgg, pretrained)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(_vgg.features) >= num_layers
        # Keep only the prefix of layers that is actually used.
        self.vgg_layers = _vgg.features[:num_layers]

        if self.use_input_norm:
            # Normalization constants for images in range [0, 1].
            self.register_buffer(
                'mean',
                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer(
                'std',
                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # The extractor is frozen; gradients only flow back to its inputs.
        self.vgg_layers.requires_grad_(False)

    def forward(self, x):
        """Extract the requested VGG features.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Mapping from layer name to its activation.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        feats = {}
        for layer_name, layer in self.vgg_layers.named_children():
            x = layer(x)
            if layer_name in self.layer_name_list:
                feats[layer_name] = x.clone()
        return feats

    def init_weights(self, model, pretrained):
        """Load pretrained weights into ``model``.

        Args:
            model (nn.Module): Model to be initialized.
            pretrained (str): Path for pretrained weights.
        """
        load_checkpoint(model, pretrained, logger=get_root_logger())
@MODULES.register_module()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Both terms compare VGG features of a prediction and of a target image:
    the perceptual term compares the features directly, while the style term
    compares their Gram matrices.

    **Note for the design of ``data_info``:**
    During training, the model collects all of the important data/modules
    into a dictionary, e.g.:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as
    input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_perceptual'.
        layer_weights (dict): The weight for each layer of vgg feature used
            in the perceptual loss. Here is an example:
            {'4': 1., '9': 1., '18': 1.}, which means the 5th, 10th and 19th
            feature layers will be extracted with weight 1.0 in calculating
            losses. Defaults to ``{'4': 1., '9': 1., '18': 1.}``.
        layer_weights_style (dict): The weight for each layer of vgg feature
            used in the style loss. If set to ``None``, the weights for the
            perceptual loss are reused. Default: None.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and multiplied by this weight. Default: 1.0.
        norm_img (bool): If True, the image will be normed from [-1, 1] to
            [0, 1] before feature extraction. Note that this is different
            from `use_input_norm`, which normalizes the input inside vgg
            according to the statistics of the dataset. Importantly, the
            input image must be in range [-1, 1]. Default: True.
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'.
        criterion (str): Criterion type. Options are 'l1' and 'mse'.
            Default: 'l1'.
        split_style_loss (bool): Whether to return a separate style loss
            item. Options are True and False. Default: False.
    """

    def __init__(self,
                 data_info=None,
                 loss_name='loss_perceptual',
                 layer_weights={
                     '4': 1.,
                     '9': 1.,
                     '18': 1.
                 },
                 layer_weights_style=None,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 perceptual_weight=1.0,
                 style_weight=1.0,
                 norm_img=True,
                 pretrained='torchvision://vgg19',
                 criterion='l1',
                 split_style_loss=False):
        super().__init__()
        self.data_info = data_info
        self._loss_name = loss_name
        self.norm_img = norm_img
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.layer_weights_style = layer_weights_style
        self.split_style_loss = split_style_loss
        self.vgg = PerceptualVGG(
            layer_name_list=list(self.layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            pretrained=pretrained)
        # A second extractor is only needed when the style loss uses a
        # different set of layers than the perceptual loss.
        if self.layer_weights_style is not None and \
                self.layer_weights_style != self.layer_weights:
            self.vgg_style = PerceptualVGG(
                layer_name_list=list(self.layer_weights_style.keys()),
                vgg_type=vgg_type,
                use_input_norm=use_input_norm,
                pretrained=pretrained)
        else:
            self.layer_weights_style = self.layer_weights
            self.vgg_style = None
        criterion = criterion.lower()
        if criterion == 'l1':
            self.criterion = l1_loss
        elif criterion == 'mse':
            self.criterion = mse_loss
        else:
            raise NotImplementedError(
                f'{criterion} criterion has not been supported in'
                ' this version.')

    def forward(self, *args, **kwargs):
        """Forward function. If ``self.data_info`` is not ``None``, a
        dictionary containing all of the data and necessary modules should be
        passed into this function. If this dictionary is given as a
        non-keyword argument, it should be offered as the first argument. If
        you are using keyword argument, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to the loss function,
        ``perceptual_loss``.

        Args:
            pred (Tensor): Input tensor with shape (n, c, h, w).
            target (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            return self.perceptual_loss(**kwargs)
        else:
            # if you have not define how to build computational graph, this
            # module will just directly return the loss as usual.
            return self.perceptual_loss(*args, **kwargs)

    def perceptual_loss(self, pred, target):
        """Compute the perceptual (and style) loss between two images.

        Args:
            pred (Tensor): Predicted image with shape (n, c, h, w). Expected
                to be in range [-1, 1] when ``self.norm_img`` is True.
            target (Tensor): Ground-truth image with the same shape. It is
                detached, so gradients only flow through ``pred``.

        Returns:
            Tensor | tuple[Tensor, Tensor]: The summed loss, or the pair
                ``(percep_loss, style_loss)`` when ``split_style_loss`` is
                True. A weight of 0 disables the corresponding term.
        """
        if self.norm_img:
            # map images from [-1, 1] to [0, 1]
            pred = (pred + 1.) * 0.5
            target = (target + 1.) * 0.5
        # extract vgg features
        pred_features = self.vgg(pred)
        target_features = self.vgg(target.detach())
        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in pred_features.keys():
                percep_loss += self.criterion(
                    pred_features[k],
                    target_features[k],
                    weight=self.layer_weights[k])
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = 0.
        # calculate style loss
        if self.style_weight > 0:
            # re-extract features with the style-specific layers if a second
            # extractor was built in __init__
            if self.vgg_style is not None:
                pred_features = self.vgg_style(pred)
                target_features = self.vgg_style(target.detach())
            style_loss = 0
            for k in pred_features.keys():
                style_loss += self.criterion(
                    self._gram_mat(pred_features[k]),
                    self._gram_mat(
                        target_features[k])) * self.layer_weights_style[k]
            style_loss *= self.style_weight
        else:
            style_loss = 0.
        if self.split_style_loss:
            return percep_loss, style_loss
        else:
            return percep_loss + style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c), normalized by
                the number of elements per channel map.
        """
        (n, c, h, w) = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmgen.models.builder import MODULES
from .utils import weighted_loss
# Reduction modes accepted by the loss modules below. Besides PyTorch's
# built-in 'none'/'mean'/'sum', two custom modes exist: 'batchmean' (total
# sum divided by batch size) and 'flatmean' (per-sample mean, shape [bz, ]).
_reduction_modes = ['none', 'mean', 'sum', 'batchmean', 'flatmean']
@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 loss.

    Weighting and reduction are handled by the ``@weighted_loss`` decorator,
    so only the raw element-wise loss is computed here.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Unreduced, element-wise L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
    """Element-wise MSE loss.

    Weighting and reduction are handled by the ``@weighted_loss`` decorator,
    so only the raw element-wise loss is computed here.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Unreduced, element-wise MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def gaussian_kld(mean_target, mean_pred, logvar_target, logvar_pred, base='e'):
    r"""KL divergence between two diagonal Gaussian distributions.

    .. math::
        :nowrap:

        \begin{align}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
        &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
        \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
        \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
        &= \log{\frac{\sigma_2}{\sigma_1}} +
        \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{align}

    Args:
        mean_target (torch.Tensor): Mean of the target (or the first)
            distribution.
        mean_pred (torch.Tensor): Mean of the predicted (or the second)
            distribution.
        logvar_target (torch.Tensor): Log variance of the target (or the
            first) distribution.
        logvar_pred (torch.Tensor): Log variance of the predicted (or the
            second) distribution.
        base (str, optional): The log base of calculated KLD. We support
            ``'e'`` (for ln) and ``'2'`` (for log_2). Defaults to ``'e'``.

    Returns:
        torch.Tensor: KLD between two given distribution.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    # Closed-form KLD of two diagonal Gaussians, evaluated entirely in
    # log-variance space.
    log_ratio = logvar_pred - logvar_target
    mean_diff_sq = (mean_target - mean_pred)**2
    kld = 0.5 * (log_ratio - 1.0 + torch.exp(-log_ratio) +
                 mean_diff_sq * torch.exp(-logvar_pred))
    if base == '2':
        # convert from nats to bits
        return kld / np.log(2.0)
    return kld
def approx_gaussian_cdf(x):
    r"""Approximate the cumulative distribution function of the gaussian
    distribution.

    Refers to:
    Approximations to the Cumulative Normal Function and its Inverse for Use on a Pocket Calculator # noqa

    https://www.jstor.org/stable/2346872?origin=crossref

    .. math::
        :nowrap:

        \begin{eqnarray}
            \Phi(x) &\approx \frac{1}{2} \left ( 1 + \tanh(y) \right ) \\
            y &= \sqrt{\frac{2}{\pi}}(x+0.044715 x^3)
        \end{eqnarray}

    Args:
        x (torch.Tensor): Input data.

    Returns:
        torch.Tensor: Calculated cumulative distribution.
    """
    # tanh-based approximation; the cubic term sharpens the fit in the tails
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * (1.0 + torch.tanh(inner))
@weighted_loss
def discretized_gaussian_log_likelihood(x, mean, logvar, base='e'):
    r"""Calculate gaussian log-likelihood for a discretized input. We assume
    that the input `x` are ranged in [-1, 1], the likelihood term can be
    calculated as the following equation:

    .. math::
        :nowrap:

        \begin{eqnarray}
            p_{\theta}(\mathbf{x}_0 | \mathbf{x}_1) =
                \prod_{i=1}^{D} \int_{\delta_{-}(x_0^i)}^{\delta_{+}(x_0^i)}
                {\mathcal{N}(x; \mu_{\theta}^i(\mathbf{x}_1, 1),
                \sigma_{1}^2)}dx\\
            \delta_{+}(x)= \begin{cases}
                \infty & \text{if } x = 1 \\
                x + \frac{1}{255} & \text{if } x < 1
            \end{cases}
            \quad
            \delta_{-}(x)= \begin{cases}
                -\infty & \text{if } x = -1 \\
                x - \frac{1}{255} & \text{if } x > -1
            \end{cases}
        \end{eqnarray}

    When calculating this loss term, we first normalize `x` to normal
    distribution and calculate the above integral by the cumulative
    distribution function of normal distribution. Then rescale results to the
    target ones.

    Args:
        x (torch.Tensor): Target `x_0` to be modeled. Range in [-1, 1].
        mean (torch.Tensor): Predicted mean of `x_0`.
        logvar (torch.Tensor): Predicted log variance of `x_0`.
        base (str, optional): The log base of calculated KLD. Support ``'e'``
            and ``'2'``. Defaults to ``'e'``.

    Returns:
        torch.Tensor: Calculated log likelihood.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    inv_std = torch.exp(-logvar * 0.5)
    centered = x - mean
    # Standardized CDF values at the two edges of the 1/255-wide bin
    # centered on ``x``.
    cdf_lower = approx_gaussian_cdf((centered - 1.0 / 255.0) * inv_std)
    cdf_upper = approx_gaussian_cdf((centered + 1.0 / 255.0) * inv_std)

    # Clamp before log to keep the result finite in the far tails.
    log_cdf_upper = torch.log(cdf_upper.clamp(min=1e-12))
    log_one_minus_cdf_lower = torch.log((1.0 - cdf_lower).clamp(min=1e-12))
    log_cdf_delta = torch.log((cdf_upper - cdf_lower).clamp(min=1e-12))

    # The first and last bins integrate out to -/+ infinity respectively.
    log_probs = torch.where(
        x < -0.999, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_cdf_delta))
    if base == '2':
        # convert from nats to bits
        return log_probs / np.log(2.0)
    return log_probs
@MODULES.register_module()
class MSELoss(nn.Module):
    """MSE loss module.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of the loss modules accept a
    ``data_info`` argument, which maps the arguments of the loss function to
    entries of the data dictionary collected by the generative model. For
    example, during GAN training the model gathers:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    Since this loss needs ``pred`` and ``target``, a suitable ``data_info``
    is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    With such a mapping, the module fetches its inputs from the data
    dictionary automatically.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, the inputs
            are passed to the loss function directly. Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_mse'.
    """

    def __init__(self, loss_weight=1.0, data_info=None, loss_name='loss_mse'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        When ``self.data_info`` is set, a dictionary with all of the data and
        necessary modules must be passed in — either as the single non-keyword
        argument, or as the keyword argument ``outputs_dict``. When it is
        ``None``, the arguments are handed straight to ``mse_loss``.
        """
        if self.data_info is None:
            # No computational path configured: behave like a plain loss.
            return mse_loss(*args, weight=self.loss_weight, **kwargs)

        # Parse the single supported calling convention.
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert not args, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # Map the collected outputs onto the loss arguments declared in
        # ``self.data_info``.
        kwargs.update(
            {arg: outputs_dict[src] for arg, src in self.data_info.items()})
        kwargs['weight'] = self.loss_weight
        return mse_loss(**kwargs)

    def loss_name(self):
        """Return the name of this loss item.

        The name is used when combining different loss items by summation;
        a `loss_` prefix marks the item for inclusion in the backward graph.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of the loss modules accept a
    ``data_info`` argument, which maps the arguments of the loss function to
    entries of the data dictionary collected by the generative model. For
    example, during GAN training the model gathers:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    Since this loss needs ``pred`` and ``target``, a suitable ``data_info``
    is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    With such a mapping, the module fetches its inputs from the data
    dictionary automatically.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_l1'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 loss_name='loss_l1'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or key-word
        argument will be directly passed to loss function, ``l1_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parsing your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            # Fix: forward ``avg_factor`` in this branch as well, so both
            # call paths apply identical reduction settings (previously it
            # was only passed in the non-``data_info`` branch below).
            kwargs.update(
                dict(
                    weight=self.loss_weight,
                    reduction=self.reduction,
                    avg_factor=self.avg_factor))
            return l1_loss(**kwargs)
        else:
            # if no computational graph mapping is defined, this module will
            # just directly return the loss as usual.
            return l1_loss(
                *args,
                weight=self.loss_weight,
                reduction=self.reduction,
                avg_factor=self.avg_factor,
                **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class GaussianKLDLoss(nn.Module):
    """GaussianKLD loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of the loss modules accept a
    ``data_info`` argument, which maps the arguments of the loss function to
    entries of the data dictionary collected by the generative model. For
    example, a diffusion model gathers:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean_pred``, ``mean_target``,
    ``logvar_pred`` and ``logvar_target`` as input. Thus, an example of the
    ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            mean_pred='mean_pred',
            mean_target='mean_target',
            logvar_pred='logvar_pred',
            logvar_target='logvar_target')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch. Note
            that 'batchmean' mode gives the correct KL divergence where
            losses are averaged over the batch dimension only.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If True, the entry mapped to
            ``mean_pred`` is detached from the graph before the loss is
            computed, so only the variance prediction receives gradients.
            Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_GaussianKLD'.
    """

    # Fallback mapping used when ``data_info`` is not given at construction.
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_GaussianKLD'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        # ``data_info`` always resolves to a mapping here, so ``forward``
        # can rely on the dictionary calling convention.
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules must
        be passed into this function — either as the single non-keyword
        argument, or as the keyword argument ``outputs_dict``. The entries
        selected by ``self.data_info`` are forwarded to ``gaussian_kld``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            # when only the variance should be optimized, stop gradients
            # through the predicted mean
            if 'mean_pred' == k and self.only_update_var:
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]
        kwargs.update(loss_input_dict)
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                base=self.base))
        return gaussian_kld(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# TODO: shorten this class name in a future (breaking) release.
@MODULES.register_module()
class DiscretizedGaussianLogLikelihoodLoss(nn.Module):
    r"""Discretized-Gaussian-Log-Likelihood Loss.

    **Note for the design of ``data_info``:**
    In ``MMGeneration``, almost all of the loss modules accept a
    ``data_info`` argument, which maps the arguments of the loss function to
    entries of the data dictionary collected by the generative model. For
    example, a diffusion model gathers:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean``, ``logvar`` and ``x``.
    Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            x='real_imgs',
            mean='mean_pred',
            logvar='logvar_pred')

    Then, the module will automatically construct this mapping from the
    input data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of the calculated log likelihood.
            Support ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If True, the entry mapped to
            ``mean`` is detached from the graph before the loss is computed,
            so only the variance prediction receives gradients.
            Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name.
            Defaults to 'loss_DiscGaussianLogLikelihood'.
    """

    # Fallback mapping used when ``data_info`` is not given at construction.
    _default_data_info = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_DiscGaussianLogLikelihood'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        # ``data_info`` always resolves to a mapping here, so ``forward``
        # can rely on the dictionary calling convention.
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        A dictionary containing all of the data and necessary modules must
        be passed into this function — either as the single non-keyword
        argument, or as the keyword argument ``outputs_dict``. The entries
        selected by ``self.data_info`` are forwarded to
        ``discretized_gaussian_log_likelihood``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parsing your arguments passed to this loss module.'
                ' Please check the usage of this module')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            # when only the variance should be optimized, stop gradients
            # through the predicted mean
            if k == 'mean' and self.only_update_var:
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]
        kwargs.update(loss_input_dict)
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                base=self.base))
        return discretized_gaussian_log_likelihood(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import torch.nn.functional as F
def reduce_loss(loss, reduction):
    """Reduce an elementwise loss tensor as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean", "sum", "flatmean" and
            "batchmean". 'none': no reduction will be applied. 'mean': the
            output will be divided by the number of elements in the output.
            'sum': the output will be summed. 'batchmean': the sum of the
            output will be divided by batchsize. 'flatmean': each sample
            will be averaged over its own elements, giving an output of
            shape ``[bz, ]``.

    Return:
        Tensor: Reduced loss tensor.
    """
    # The two batch-aware modes are handled here because PyTorch's own
    # reduction machinery does not know them.
    if reduction == 'batchmean':
        return loss.sum() / loss.shape[0]
    if reduction == 'flatmean':
        return loss.mean(dim=list(range(1, loss.ndim)))
    # Delegate validation of the remaining names to PyTorch:
    # none -> 0, elementwise_mean -> 1, sum -> 2.
    handlers = {
        0: lambda t: t,
        1: lambda t: t.mean(),
        2: lambda t: t.sum(),
    }
    mode = F._Reduction.get_enum(reduction)
    if mode in handlers:
        return handlers[mode](loss)
    raise ValueError(f'reduction type {reduction} not supported')
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply an element-wise weight to a loss and reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # Element-wise weighting comes first, if requested.
    if weight is not None:
        loss = loss * weight
    # Without an averaging factor, fall back to the standard reduction.
    if avg_factor is None:
        return reduce_loss(loss, reduction)
    # With an explicit factor, 'mean' averages by that factor instead of
    # by the number of elements.
    if reduction == 'mean':
        return loss.sum() / avg_factor
    # 'none' leaves the (weighted) loss untouched.
    if reduction == 'none':
        return loss
    # Any other mode cannot be combined with avg_factor.
    raise ValueError('avg_factor can not be used with reduction="sum"')
def weighted_loss(loss_func):
    """Wrap a loss function with weighting and reduction support.

    The wrapped function must have the signature
    ``loss_func(pred, target, **kwargs)`` and return the element-wise loss
    without any reduction. The decorator adds ``weight``, ``reduction`` and
    ``avg_factor`` keyword arguments, so the decorated function behaves
    like ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(*args,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # Element-wise loss first, then weighting/reduction on top of it.
        raw = loss_func(*args, **kwargs)
        return weight_reduce_loss(raw, weight, reduction, avg_factor)

    return wrapper
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torchvision.utils import make_grid
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensor(s) into image numpy array(s).

    Values are clamped to ``min_max`` and normalized to [0, 1]. A 4D
    (N x 3/1 x H x W) tensor is stitched into a single grid image via
    ``make_grid``; 3D (3/1 x H x W) and 2D (H x W) tensors are converted
    directly. Input channels are expected in RGB order; the output follows
    the cv2 convention, i.e. (H x W x C) with BGR order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output type. If ``np.uint8``, outputs are
            rounded and scaled to [0, 255]; otherwise float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): Min and max values for clamping.

    Returns:
        (ndarray | list[ndarray]): 3D array of shape (H x W x C) or 2D
            array of shape (H x W); a list when a list of tensors was
            given.
    """
    is_single = torch.is_tensor(tensor)
    if not (is_single or (isinstance(tensor, list)
                          and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    tensors = [tensor] if is_single else tensor
    images = []
    for item in tensors:
        # Drop up to two leading singleton dims so that
        # (1, 1, h, w) -> (h, w) and (1, 3, h, w) -> (3, h, w);
        # batched inputs (n>1, 3/1, h, w) stay 4D.
        item = item.squeeze(0).squeeze(0)
        item = item.float().detach().cpu().clamp_(*min_max)
        item = (item - min_max[0]) / (min_max[1] - min_max[0])
        ndim = item.dim()
        if ndim == 4:
            grid = make_grid(
                item, nrow=int(np.sqrt(item.size(0))), normalize=False)
            # RGB -> BGR, then CHW -> HWC.
            img = np.transpose(grid.numpy()[[2, 1, 0], :, :], (1, 2, 0))
        elif ndim == 3:
            img = np.transpose(item.numpy()[[2, 1, 0], :, :], (1, 2, 0))
        elif ndim == 2:
            img = item.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {ndim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy's astype does not round, so round first.
            img = (img * 255.0).round()
        images.append(img.astype(out_type))
    return images[0] if len(images) == 1 else images
# Copyright (c) OpenMMLab. All rights reserved.
from .base_translation_model import BaseTranslationModel
from .cyclegan import CycleGAN
from .pix2pix import Pix2Pix
from .static_translation_gan import StaticTranslationGAN
# Public API of the translation-model subpackage.
__all__ = [
    'Pix2Pix', 'CycleGAN', 'BaseTranslationModel', 'StaticTranslationGAN'
]
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from copy import deepcopy
import torch.nn as nn
from ..builder import MODELS
@MODELS.register_module()
class BaseTranslationModel(nn.Module, metaclass=ABCMeta):
    """Base Translation Model.

    Translation models transfer images from one domain to another. Domain
    information like ``default_domain`` and ``reachable_domains`` is needed
    to initialize the class, and query helpers such as
    ``is_domain_reachable`` and ``get_other_domains`` are provided.

    You can get a specific generator based on the domain, and by passing
    ``target_domain`` to the forward function you decide the domain of the
    generated images.

    Considering the difference among image translation models, only these
    external interfaces are provided here. A concrete method should inherit
    both ``BaseTranslationModel`` and a GAN base class (e.g. ``BaseGAN``)
    and implement the abstract methods.

    Args:
        default_domain (str): Default output domain.
        reachable_domains (list[str]): Domains that can be generated by
            the model.
        related_domains (list[str]): Domains involved in training and
            testing. ``reachable_domains`` must be contained in
            ``related_domains``. However, ``related_domains`` may contain
            source domains that are used to retrieve source images from
            data_batch but are not in ``reachable_domains``.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
    """

    def __init__(self,
                 default_domain,
                 reachable_domains,
                 related_domains,
                 train_cfg=None,
                 test_cfg=None):
        # NOTE(review): ``nn.Module.__init__`` is deliberately not called
        # here; concrete subclasses (e.g. ``StaticTranslationGAN``) set up
        # the module machinery through their GAN base class before calling
        # this constructor — confirm when subclassing directly.
        self._default_domain = default_domain
        self._reachable_domains = reachable_domains
        self._related_domains = related_domains
        assert self._default_domain in self._reachable_domains
        assert set(self._reachable_domains) <= set(self._related_domains)

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    @abstractmethod
    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""

    @abstractmethod
    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""

    def forward(self, img, test_mode=False, **kwargs):
        """Forward function.

        Args:
            img (tensor): Input image tensor.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.
        """
        if not test_mode:
            return self.forward_train(img, **kwargs)

        return self.forward_test(img, **kwargs)

    def forward_train(self, img, target_domain, **kwargs):
        """Forward function for training.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img, target=target)
        return results

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results (tensors moved to CPU).
        """
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def is_domain_reachable(self, domain):
        """Whether image of this domain can be generated."""
        return domain in self._reachable_domains

    def get_other_domains(self, domain):
        """Get all related domains except ``domain``.

        The result preserves the order of ``related_domains`` so that
        callers relying on ``get_other_domains(d)[0]`` get a deterministic
        domain. (The previous implementation used a set difference, whose
        iteration order is arbitrary.)
        """
        return [d for d in self._related_domains if d != domain]

    @abstractmethod
    def _get_target_generator(self, domain):
        """get target generator."""

    def translation(self, image, target_domain=None, **kwargs):
        """Translate an image to the target style.

        Args:
            image (tensor): Image tensor with a shape of (N, C, H, W).
            target_domain (str, optional): Target domain of output image.
                Falls back to ``default_domain`` when None. Default: None.

        Returns:
            dict: Image tensor of target style.
        """
        if target_domain is None:
            target_domain = self._default_domain
        _model = self._get_target_generator(target_domain)
        outputs = _model(image, **kwargs)
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.builder import MODELS
from ..common import GANImageBuffer, set_requires_grad
from .static_translation_gan import StaticTranslationGAN
@MODELS.register_module()
class CycleGAN(StaticTranslationGAN):
    """CycleGAN model for unpaired image-to-image translation.

    Ref:
    Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
    Networks
    """

    def __init__(self, *args, **kwargs):
        # Networks are built by ``StaticTranslationGAN``; only the
        # CycleGAN-specific image buffers are added here.
        super().__init__(*args, **kwargs)
        # GAN image buffers: one history buffer of generated images per
        # reachable domain, queried when updating the discriminators.
        self.image_buffers = dict()
        # Buffer size defaults to 50; may be overridden via
        # ``train_cfg['buffer_size']``.
        self.buffer_size = (50 if self.train_cfg is None else
                            self.train_cfg.get('buffer_size', 50))
        for domain in self._reachable_domains:
            self.image_buffers[domain] = GANImageBuffer(self.buffer_size)

        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        # This is a trick for CycleGAN
        # (keep the networks in train mode during testing)
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def _get_disc_loss(self, outputs):
        """Compute the discriminators' loss.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple(Tensor, dict): Total discriminator loss and its log dict.
        """
        discriminators = self.get_module(self.discriminators)

        log_vars_d = dict()
        loss_d = 0

        # GAN loss for discriminators['a']
        for domain in self._reachable_domains:
            losses = dict()
            # Query (possibly historical) fakes from the per-domain image
            # buffer; detach so the generators receive no gradient here.
            fake_img = self.image_buffers[domain].query(
                outputs[f'fake_{domain}'])
            fake_pred = discriminators[domain](fake_img.detach())
            losses[f'loss_gan_d_{domain}_fake'] = self.gan_loss(
                fake_pred, target_is_real=False, is_disc=True)
            real_pred = discriminators[domain](outputs[f'real_{domain}'])
            losses[f'loss_gan_d_{domain}_real'] = self.gan_loss(
                real_pred, target_is_real=True, is_disc=True)

            _loss_d, _log_vars_d = self._parse_losses(losses)
            # NOTE(review): the 0.5 factor mirrors the reference CycleGAN
            # implementation, which halves the discriminator loss.
            _loss_d *= 0.5
            loss_d += _loss_d
            log_vars_d[f'loss_gan_d_{domain}'] = _log_vars_d['loss'] * 0.5

        return loss_d, log_vars_d

    def _get_gen_loss(self, outputs):
        """Compute the generators' loss.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple(Tensor, dict): Total generator loss and its log dict.
        """
        generators = self.get_module(self.generators)
        discriminators = self.get_module(self.discriminators)

        losses = dict()
        for domain in self._reachable_domains:
            # Identity reconstruction for generators
            outputs[f'identity_{domain}'] = generators[domain](
                outputs[f'real_{domain}'])
            # GAN loss for generators
            fake_pred = discriminators[domain](outputs[f'fake_{domain}'])
            losses[f'loss_gan_g_{domain}'] = self.gan_loss(
                fake_pred, target_is_real=True, is_disc=False)
        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue

                # the `loss_name()` function return name as 'loss_xxx'
                # Items sharing a name are accumulated by summation.
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name(
                    )] = losses[loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_
        loss_g, log_vars_g = self._parse_losses(losses)

        return loss_g, log_vars_g

    def _get_opposite_domain(self, domain):
        """Return the reachable domain that differs from ``domain``, or
        None when no other reachable domain exists."""
        for item in self._reachable_domains:
            if item != domain:
                return item

        return None

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Runs one optimization step: forward both translation directions
        (including the cycle reconstruction), update the discriminators,
        then — gated by ``disc_steps``/``disc_init_steps`` — update the
        generators.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generators and discriminators.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generators
        outputs = dict()
        for target_domain in self._reachable_domains:
            # fetch data by domain
            source_domain = self.get_other_domains(target_domain)[0]
            img = data_batch[f'img_{source_domain}']
            # translation process
            results = self(img, test_mode=False, target_domain=target_domain)
            outputs[f'real_{source_domain}'] = results['source']
            outputs[f'fake_{target_domain}'] = results['target']

            # cycle process: translate the fake back to its source domain
            # for the cycle-consistency objective.
            results = self(
                results['target'],
                test_mode=False,
                target_domain=source_domain)
            outputs[f'cycle_{source_domain}'] = results['target']

        log_vars = dict()

        # discriminators
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # Prepare for backward in ddp: without this call the reducer would
        # not dynamically find the parameters used in this computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generators, no updates to discriminator parameters.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        for domain in self._reachable_domains:
            image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu()
            image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()

        # NOTE(review): ``domain`` is the loop variable leaking from the
        # loop above, so num_samples comes from the last domain's batch.
        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{domain}']),
            results=image_results)

        return results
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import _find_tensors
from mmgen.models.builder import MODELS
from ..common import set_requires_grad
from .static_translation_gan import StaticTranslationGAN
@MODELS.register_module()
class Pix2Pix(StaticTranslationGAN):
    """Pix2Pix model for paired image-to-image translation.

    Ref:
    Image-to-Image Translation with Conditional Adversarial Networks
    """

    def __init__(self, *args, **kwargs):
        # All networks are built by ``StaticTranslationGAN``.
        super().__init__(*args, **kwargs)
        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        # This is a trick for Pix2Pix
        # (keep the networks in train mode during testing)
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def _get_disc_loss(self, outputs):
        """Compute the discriminator's loss.

        The discriminator is conditional: it scores the channel-wise
        concatenation of the source image with either the fake or the real
        target image.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple(Tensor, dict): Discriminator loss and its log dict.
        """
        # GAN loss for the discriminator
        losses = dict()
        discriminators = self.get_module(self.discriminators)

        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        # Conditional input: (source, fake target) pair; detached so the
        # generator receives no gradient from this pass.
        fake_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'fake_{target_domain}']), 1)
        fake_pred = discriminators[target_domain](fake_ab.detach())
        losses['loss_gan_d_fake'] = self.gan_loss(
            fake_pred, target_is_real=False, is_disc=True)
        # Conditional input: (source, real target) pair.
        real_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'real_{target_domain}']), 1)
        real_pred = discriminators[target_domain](real_ab)
        losses['loss_gan_d_real'] = self.gan_loss(
            real_pred, target_is_real=True, is_disc=True)

        loss_d, log_vars_d = self._parse_losses(losses)
        # NOTE(review): halving mirrors the reference pix2pix implementation.
        loss_d *= 0.5

        return loss_d, log_vars_d

    def _get_gen_loss(self, outputs):
        """Compute the generator's loss.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple(Tensor, dict): Generator loss and its log dict.
        """
        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        losses = dict()

        discriminators = self.get_module(self.discriminators)
        # GAN loss for the generator (no detach: gradients must flow back).
        fake_ab = torch.cat((outputs[f'real_{source_domain}'],
                             outputs[f'fake_{target_domain}']), 1)
        fake_pred = discriminators[target_domain](fake_ab)
        losses['loss_gan_g'] = self.gan_loss(
            fake_pred, target_is_real=True, is_disc=False)

        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue

                # the `loss_name()` function return name as 'loss_xxx'
                # Items sharing a name are accumulated by summation.
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name(
                    )] = losses[loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_
        loss_g, log_vars_g = self._parse_losses(losses)

        return loss_g, log_vars_g

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Runs one optimization step: forward the generator, update the
        discriminator, then — gated by ``disc_steps``/``disc_init_steps`` —
        update the generator.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generator and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # data
        target_domain = self._default_domain
        source_domain = self.get_other_domains(self._default_domain)[0]
        source_image = data_batch[f'img_{source_domain}']
        target_image = data_batch[f'img_{target_domain}']

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generator
        outputs = dict()
        results = self(
            source_image, target_domain=self._default_domain, test_mode=False)
        outputs[f'real_{source_domain}'] = results['source']
        outputs[f'fake_{target_domain}'] = results['target']
        outputs[f'real_{target_domain}'] = target_image

        log_vars = dict()

        # discriminator
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generator, no updates to discriminator parameters.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        image_results[f'real_{source_domain}'] = outputs[
            f'real_{source_domain}'].cpu()
        image_results[f'fake_{target_domain}'] = outputs[
            f'fake_{target_domain}'].cpu()
        image_results[f'real_{target_domain}'] = outputs[
            f'real_{target_domain}'].cpu()

        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{source_domain}']),
            results=image_results)

        return results
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel
from ..builder import MODELS, build_module
from ..gans import BaseGAN
from .base_translation_model import BaseTranslationModel
@MODELS.register_module()
class StaticTranslationGAN(BaseTranslationModel, BaseGAN):
    """Basic translation model based on static unconditional GAN.

    Builds one generator — and, when configured, one discriminator — per
    reachable domain from the shared configs.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator.
        gan_loss (dict): Config for the gan loss.
        pretrained (str | optional): Path for pretrained model.
            Defaults to None.
        disc_auxiliary_loss (dict | optional): Config for auxiliary loss to
            discriminator. Defaults to None.
        gen_auxiliary_loss (dict | optional): Config for auxiliary loss
            to generator. Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 *args,
                 pretrained=None,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 **kwargs):
        # ``BaseGAN.__init__`` sets up the nn.Module machinery; the
        # translation-specific attributes (domains, train/test cfg) come
        # from ``BaseTranslationModel.__init__``.
        BaseGAN.__init__(self)
        BaseTranslationModel.__init__(self, *args, **kwargs)

        # Build one generator per reachable domain from a shared config.
        self._gen_cfg = deepcopy(generator)
        self.generators = nn.ModuleDict()
        for domain in self._reachable_domains:
            self.generators[domain] = build_module(generator)

        self._disc_cfg = deepcopy(discriminator)
        # support no discriminator in testing
        if discriminator is not None:
            self.discriminators = nn.ModuleDict()
            for domain in self._reachable_domains:
                self.discriminators[domain] = build_module(discriminator)
        else:
            self.discriminators = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # Fix: assign ``disc_auxiliary_losses`` (plural), matching the
            # branch above and the generator counterpart; the original set
            # only a misspelled singular attribute here, leaving
            # ``disc_auxiliary_losses`` undefined when no config was given.
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If
                given None, pretrained weights will not be loaded.
                Default: None.
        """
        for domain in self._reachable_domains:
            self.generators[domain].init_weights(pretrained=pretrained)
            # Discriminators may be absent in test-only configurations.
            if self.discriminators is not None:
                self.discriminators[domain].init_weights(
                    pretrained=pretrained)

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()

        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        # ``train_cfg`` is guaranteed to be a dict at this point, so the
        # default can be read from it directly.
        self.disc_init_steps = self.train_cfg.get('disc_init_steps', 0)
        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

    def get_module(self, module):
        """Get `nn.ModuleDict` to fit the `MMDistributedDataParallel`
        interface.

        Args:
            module (MMDistributedDataParallel | nn.ModuleDict): The input
                module that needs processing.

        Returns:
            nn.ModuleDict: The ModuleDict of multiple networks.
        """
        if isinstance(module, MMDistributedDataParallel):
            return module.module

        return module

    def _get_target_generator(self, domain):
        """get target generator."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is\
            {self._reachable_domains}'

        return self.get_module(self.generators)[domain]

    def _get_target_discriminator(self, domain):
        """get target discriminator."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is\
            {self._reachable_domains}'

        return self.get_module(self.discriminators)[domain]
# Copyright (c) OpenMMLab. All rights reserved.
from .conv2d_gradfix import conv2d, conv_transpose2d
from .stylegan3.ops import bias_act, filtered_lrelu
# Public re-exports of the conv2d gradfix and StyleGAN3 custom ops.
__all__ = ['conv2d', 'conv_transpose2d', 'filtered_lrelu', 'bias_act']
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""
import contextlib
import torch
# Global switch for the custom gradfix ops; when False the plain PyTorch
# convolutions are used instead.
enabled = True
# When True, ``Conv2d.backward`` skips computing weight gradients; managed
# by the ``no_weight_gradients`` context manager below.
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients(disable=True):
    """Temporarily disable weight-gradient computation in the custom ops.

    Args:
        disable (bool): Whether to actually disable the gradients; when
            False the context manager is a no-op. Defaults to True.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore the previous flag even when the managed block raises;
        # the original code leaked the disabled state on exceptions.
        weight_gradients_disabled = old
def conv2d(input,
           weight,
           bias=None,
           stride=1,
           padding=0,
           dilation=1,
           groups=1):
    """Drop-in replacement for ``torch.nn.functional.conv2d`` that routes
    eligible inputs through the custom gradfix op (which supports
    arbitrarily high order gradients)."""
    if not _should_use_custom_op(input):
        # Fallback: the standard PyTorch convolution.
        return torch.nn.functional.conv2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups)
    op = _conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups)
    return op.apply(input, weight, bias)
def conv_transpose2d(input,
                     weight,
                     bias=None,
                     stride=1,
                     padding=0,
                     output_padding=0,
                     groups=1,
                     dilation=1):
    """Drop-in replacement for ``torch.nn.functional.conv_transpose2d``
    that routes eligible inputs through the custom gradfix op (which
    supports arbitrarily high order gradients)."""
    if not _should_use_custom_op(input):
        # Fallback: the standard PyTorch transposed convolution.
        return torch.nn.functional.conv_transpose2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation)
    op = _conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation)
    return op.apply(input, weight, bias)
def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
return False
return True
def _tuple_of_ints(xs, ndim):
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim
assert len(xs) == ndim
assert all(isinstance(x, int) for x in xs)
return xs
# Cache of generated ``Conv2d`` autograd Function classes, keyed by the full
# convolution configuration so each variant is built only once.
_conv2d_gradfix_cache = dict()
# Placeholder saved in place of a real tensor when a gradient is not needed.
_null_tensor = torch.empty([0])
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
dilation, groups):
# Parse arguments.
ndim = 2
weight_shape = tuple(weight_shape)
stride = _tuple_of_ints(stride, ndim)
padding = _tuple_of_ints(padding, ndim)
output_padding = _tuple_of_ints(output_padding, ndim)
dilation = _tuple_of_ints(dilation, ndim)
# Lookup from cache.
key = (transpose, weight_shape, stride, padding, output_padding, dilation,
groups)
if key in _conv2d_gradfix_cache:
return _conv2d_gradfix_cache[key]
# Validate arguments.
assert groups >= 1
assert len(weight_shape) == ndim + 2
assert all(stride[i] >= 1 for i in range(ndim))
assert all(padding[i] >= 0 for i in range(ndim))
assert all(dilation[i] >= 0 for i in range(ndim))
if not transpose:
assert all(output_padding[i] == 0 for i in range(ndim))
else: # transpose
assert all(0 <= output_padding[i] < max(stride[i], dilation[i])
for i in range(ndim))
# Helpers.
common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups)
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] -
(1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
for i in range(ndim)
]
# Forward & backward.
class Conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
assert weight.shape == weight_shape
ctx.save_for_backward(
input if weight.requires_grad else _null_tensor,
weight if input.requires_grad else _null_tensor,
)
ctx.input_shape = input.shape
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (
0, 0) and torch.cuda.get_device_capability(
input.device) < (8, 0):
a = weight.reshape(groups, weight_shape[0] // groups,
weight_shape[1])
b = input.reshape(input.shape[0], groups,
input.shape[1] // groups, -1)
c = (a.transpose(1, 2) if transpose else a) @ b.permute(
1, 2, 0, 3).flatten(2)
c = c.reshape(-1, input.shape[0],
*input.shape[2:]).transpose(0, 1)
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(
2).unsqueeze(3)
return c.contiguous(
memory_format=(torch.channels_last if input.stride(1) ==
1 else torch.contiguous_format))
# General case => cuDNN.
if transpose:
return torch.nn.functional.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs)
return torch.nn.functional.conv2d(
input=input, weight=weight, bias=bias, **common_kwargs)
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input_shape = ctx.input_shape
grad_input = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input_shape, output_shape=grad_output.shape)
op = _conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs)
grad_input = op.apply(grad_output, weight, None)
assert grad_input.shape == input_shape
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
assert grad_weight.shape == weight_shape
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum([0, 2, 3])
return grad_input, grad_weight, grad_bias
# Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        """Gradient of the convolution with respect to its weight.

        Specialized via the closure variables of ``_conv2d_gradfix``
        (``weight_shape``, ``stride``, ``dilation``, ``padding``, ``groups``,
        ``transpose``, ``common_kwargs``).
        """

        @staticmethod
        def forward(ctx, grad_output, input):
            # Save only what double-backward will actually read.
            ctx.save_for_backward(
                grad_output if input.requires_grad else _null_tensor,
                input if grad_output.requires_grad else _null_tensor,
            )
            ctx.grad_output_shape = grad_output.shape
            ctx.input_shape = input.shape
            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
            if weight_shape[2:] == stride == dilation == (
                    1, 1) and padding == (0, 0):
                # Weight gradient of a grouped 1x1 conv is a batched matmul
                # between grad_output and input.
                a = grad_output.reshape(grad_output.shape[0], groups,
                                        grad_output.shape[1] // groups,
                                        -1).permute(1, 2, 0, 3).flatten(2)
                b = input.reshape(input.shape[0], groups,
                                  input.shape[1] // groups,
                                  -1).permute(1, 2, 0, 3).flatten(2)
                c = (b @ a.transpose(1, 2) if transpose else
                     a @ b.transpose(1, 2)).reshape(weight_shape)
                return c.contiguous(
                    memory_format=(torch.channels_last if input.stride(1) ==
                                   1 else torch.contiguous_format))
            # General case => cuDNN.
            # NOTE(review): these internal ATen ops were removed in newer
            # PyTorch releases — confirm against the project's supported
            # torch versions.
            name = ('aten::cudnn_convolution_transpose_backward_weight' if
                    transpose else 'aten::cudnn_convolution_backward_weight')
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32
            ]
            return torch._C._jit_get_operation(name)(weight_shape, grad_output,
                                                     input, padding, stride,
                                                     dilation, groups, *flags)

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            # Double-backward: gradients of the weight-gradient computation
            # w.r.t. grad_output and input.
            grad_output, input = ctx.saved_tensors
            grad_output_shape = ctx.grad_output_shape
            input_shape = ctx.input_shape
            grad2_grad_output = None
            grad2_input = None
            if ctx.needs_input_grad[0]:
                # d/d(grad_output) is a forward convolution of input with
                # the incoming weight perturbation.
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight,
                                                 None)
                assert grad2_grad_output.shape == grad_output_shape
            if ctx.needs_input_grad[1]:
                # d/d(input) is the transposed convolution of grad_output
                # with the incoming weight perturbation.
                p = calc_output_padding(
                    input_shape=input_shape, output_shape=grad_output_shape)
                op = _conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs)
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input_shape
            return grad2_grad_output, grad2_input
    # Memoize the specialized Function so later calls with the same
    # (transpose, weight_shape, stride, ...) key reuse it.
    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
from .ops import filtered_lrelu
__all__ = ['filtered_lrelu']
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid
import torch
import torch.utils.cpp_extension
# Global options.
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
# Internal helper funcs.
def _find_compiler_bindir():
patterns = [
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', # noqa
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', # noqa
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', # noqa
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
]
for pattern in patterns:
matches = sorted(glob.glob(pattern))
if len(matches):
return matches[-1]
return None
def _get_mangled_gpu_name():
name = torch.cuda.get_device_name().lower()
out = []
for c in name:
if re.match('[a-z0-9_-]+', c):
out.append(c)
else:
out.append('-')
return ''.join(out)
# Main entry point for compiling and loading C++/CUDA plugins.
_cached_plugins = dict()
def get_plugin(module_name,
               sources,
               headers=None,
               source_dir=None,
               **build_kwargs):
    """Compile and load a C++/CUDA torch extension, with in-process caching.

    Args:
        module_name (str): Name the extension is built under and imported as.
        sources (list[str]): Paths of the C++/CUDA source files.
        headers (list[str] | None, optional): Header files the sources depend
            on. Defaults to None.
        source_dir (str | None, optional): If given, ``sources`` and
            ``headers`` are resolved relative to this directory.
            Defaults to None.
        **build_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.cpp_extension.load``.

    Returns:
        module: The imported extension module.

    Raises:
        RuntimeError: On Windows when no MSVC installation can be located.
    """
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]
    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]
    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(
            f'Setting up PyTorch plugin "{module_name}"... ',
            end='',
            flush=True)
    verbose_build = (verbosity == 'full')
    # Compile and load.
    try:  # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system('where cl.exe >nul 2>nul') != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    'Could not find MSVC/GCC/CLANG installation on this '
                    f'computer. Check _find_compiler_bindir() in "{__file__}".'
                )
            os.environ['PATH'] += ';' + compiler_bindir
        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''
        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the
        # input source files. Copying is done only if the combined digest has
        # changed.
        # This keeps input file timestamps and filenames the same as in
        # previous extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the
        # TORCH_EXTENSIONS_DIR environment variable is set (we take this as a
        # signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to
        # work around the *.cu dependency bug in ninja config.
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(
            os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs
               ) == 1:  # and ('TORCH_EXTENSIONS_DIR' in os.environ):
            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(
                module_name, verbose=verbose_build)
            cached_build_dir = os.path.join(
                build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(
                        src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir)  # atomic
                except OSError:
                    # source directory already exists
                    # delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir):
                        raise
            # Compile.
            cached_sources = [
                os.path.join(cached_build_dir, os.path.basename(fname))
                for fname in sources
            ]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=cached_build_dir,
                verbose=verbose_build,
                sources=cached_sources,
                **build_kwargs)
        else:
            torch.utils.cpp_extension.load(
                name=module_name,
                verbose=verbose_build,
                sources=sources,
                **build_kwargs)
        # Load.
        module = importlib.import_module(module_name)
    except Exception as err:
        if verbosity == 'brief':
            print('Failed!')
        raise err
    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "bias_act.h"
//------------------------------------------------------------------------
// Returns true when x and y agree in rank, per-dimension size, and
// (for dimensions of size >= 2, where the stride is observable) stride.
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    const int64_t rank = x.dim();
    for (int64_t d = 0; d < rank; d++)
    {
        bool size_ok = (x.size(d) == y.size(d));
        bool stride_ok = (x.size(d) < 2) || (x.stride(d) == y.stride(d));
        if (!size_ok || !stride_ok)
            return false;
    }
    return true;
}
//------------------------------------------------------------------------
//------------------------------------------------------------------------
// Launch the fused bias + activation CUDA kernel.
//   grad:  0 = forward pass, >0 = gradient order (see bias_act.h kernels).
//   act:   activation function selector; alpha/gain/clamp parameterize it.
// Tensors with numel() == 0 stand in for "argument not supplied".
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    // Fix: the message now mentions "shape" to match the sizes() comparison.
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");
    // Validate layout. The kernel walks all tensors with one flat index, so
    // every optional tensor must share x's exact strides.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
    // Initialize CUDA kernel parameters. Optional tensors map to NULL.
    bias_act_kernel_params p;
    p.x = x.data_ptr();
    p.b = (b.numel()) ? b.data_ptr() : NULL;
    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y = y.data_ptr();
    p.grad = grad;
    p.act = act;
    p.alpha = alpha;
    p.gain = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
    // Choose CUDA kernel.
    // Fix: dispatch label previously read "upfirdn2d_cuda" (copy-paste from
    // upfirdn2d.cpp); use the correct name so unsupported-dtype errors
    // mention bias_act.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
    // Launch CUDA kernel: each thread block handles loopX * blockSize
    // elements of the flattened tensor.
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}
//------------------------------------------------------------------------
// Python bindings: expose the fused bias+activation launcher to Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}
//------------------------------------------------------------------------
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment