# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.nn as nn


class BaseGAN(nn.Module, metaclass=ABCMeta):
    """BaseGAN Module."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False

    @property
    def with_disc(self):
        """bool: whether the GAN has a discriminator."""
        return hasattr(self,
                       'discriminator') and self.discriminator is not None

    @property
    def with_ema_gen(self):
        """bool: whether the GAN adopts exponential moving average."""
        return hasattr(self, 'gen_ema') and self.gen_ema is not None

    @property
    def with_gen_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the generator."""
        return hasattr(self, 'gen_auxiliary_losses') and (
            self.gen_auxiliary_losses is not None)

    @property
    def with_disc_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the
        discriminator."""
        return hasattr(
            self,
            'disc_auxiliary_losses') and self.disc_auxiliary_losses is not None

    def _get_disc_loss(self, outputs_dict):
        # Construct losses dict. If you want an item to be included in the
        # computational graph, you have to add 'loss' to its name. Otherwise,
        # items without 'loss' in their name will only be used to print
        # information.
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake'] = self.gan_loss(
            outputs_dict['disc_pred_fake'],
            target_is_real=False,
            is_disc=True)
        losses_dict['loss_disc_real'] = self.gan_loss(
            outputs_dict['disc_pred_real'],
            target_is_real=True,
            is_disc=True)

        # disc auxiliary loss
        if self.with_disc_auxiliary_loss:
            for loss_module in self.disc_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                # a loss module may return ``None`` to skip this iteration
                if loss_ is None:
                    continue

                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)

        return loss, log_var

    def _get_gen_loss(self, outputs_dict):
        # Construct losses dict. If you want an item to be included in the
        # computational graph, you have to add 'loss' to its name. Otherwise,
        # items without 'loss' in their name will only be used to print
        # information.
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake_g'] = self.gan_loss(
            outputs_dict['disc_pred_fake_g'],
            target_is_real=True,
            is_disc=False)

        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                # a loss module may return ``None`` to skip this iteration
                if loss_ is None:
                    continue

                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)

        return loss, log_var
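    # Illustrative sketch (an assumption, not part of the original API): an
    # auxiliary loss module consumed by ``_get_disc_loss``/``_get_gen_loss``
    # only needs to be callable on ``outputs_dict`` and to expose
    # ``loss_name()`` returning a name that contains 'loss', e.g.
    #
    #   class FakeImgPenalty(nn.Module):  # hypothetical module and dict key
    #       def forward(self, outputs_dict):
    #           # returning ``None`` skips this item for the iteration
    #           return 0.1 * outputs_dict['fake_imgs'].abs().mean()
    #
    #       def loss_name(self):
    #           return 'loss_fake_penalty'
    #
    # Subclasses are expected to store such modules in
    # ``self.disc_auxiliary_losses`` / ``self.gen_auxiliary_losses``.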
    @abstractmethod
    def train_step(self, data, optimizer, ddp_reducer=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repos in the **MM** series, we allow the back propagation and
        optimizer updating to directly follow the iterative training schedule
        of GANs. Of course, you can also move the back propagation outside of
        this method and optimize the parameters in the optimizer hook, but
        this causes extra GPU memory cost as a result of retaining the
        computational graph. Otherwise, the training schedule should be
        modified in the detailed implementation.

        TODO: Give an example of removing bp outside ``train_step``.
        TODO: Try the synchronized back propagation.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.
            ddp_reducer (:obj:`Reducer` | None, optional): This reducer is
                used to dynamically collect used parameters in the distributed
                training. If given an initialized ``Reducer``, we will call
                its ``prepare_for_backward()`` function just before calling
                ``.backward()``.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a \
                    weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the \
                    logger.
                - ``num_samples`` indicates the batch size (when the model \
                    is DDP, it means the batch size on each GPU), which is \
                    used for averaging the logs.
        """

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, ``None``
                indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size. Defaults
                to 0.
            sample_model (str, optional): Which generator to use: 'ema',
                'orig', or 'ema/orig' (use the EMA generator when it is
                enabled, together with the original generator; otherwise fall
                back to the original generator). Defaults to 'ema/orig'.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(noise, num_batches=num_batches, **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']

        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(noise, num_batches=num_batches, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs

    def forward_train(self, data, **kwargs):
        """Deprecated forward function in training."""
        raise NotImplementedError(
            'In MMGeneration, we do NOT recommend users to call this '
            'function, because the train_step function is designed for the '
            'training process.')

    def forward_test(self, data, **kwargs):
        """Testing function for GANs.

        Args:
            data (torch.Tensor | dict | None): Input data. This data will be
                passed to different methods.
        """
        if kwargs.pop('mode', 'sampling') == 'sampling':
            return self.sample_from_noise(data, **kwargs)

        raise NotImplementedError('Other specific testing functions should'
                                  ' be implemented by the sub-classes.')

    def forward(self, data, return_loss=False, **kwargs):
        """Forward function.

        Args:
            data (dict | torch.Tensor): Input data dictionary.
            return_loss (bool, optional): Whether in training or testing.
                Defaults to False.

        Returns:
            dict: Output dictionary.
        """
        if return_loss:
            return self.forward_train(data, **kwargs)

        return self.forward_test(data, **kwargs)
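    # Illustrative usage sketch (assumed calling convention, not part of the
    # original module): with ``return_loss=False`` a call is routed through
    # ``forward_test`` to ``sample_from_noise``, so sampling typically looks
    # like
    #
    #   out = model(None, num_batches=4, sample_model='ema/orig')
    #
    # where ``out`` is either a tensor of images or a dict containing
    # 'fake_img'. If ``self.use_ema`` is True, 'ema/orig' concatenates the
    # EMA samples and the original-generator samples along the batch
    # dimension (8 images for ``num_batches=4``).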
    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(
                    _loss.mean() for _loss in loss_value)
            # Allow setting None for some loss item.
            # This is to support dynamic loss modules, where the loss is
            # calculated with a fixed frequency.
            elif loss_value is None:
                continue
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Note that you have to add 'loss' to the name of an item that should
        # be included in back propagation.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
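

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal subclass
# showing how ``train_step`` can be implemented on top of the helpers above.
# ``self.generator``, ``self.discriminator`` and ``self.gan_loss`` follow the
# attribute names referenced in ``BaseGAN``; the ``data['real_img']`` key,
# the noise size, and the optimizer dict keys are assumptions made for this
# example only.
# ---------------------------------------------------------------------------
class ExampleGAN(BaseGAN):
    """Minimal GAN sketch with alternating D/G updates in ``train_step``."""

    def __init__(self, generator, discriminator, gan_loss, noise_size=128):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.gan_loss = gan_loss
        self.noise_size = noise_size
        self.use_ema = False

    def train_step(self, data, optimizer, ddp_reducer=None):
        real_imgs = data['real_img']  # assumed dataloader key
        batch_size = real_imgs.shape[0]
        noise = torch.randn(batch_size, self.noise_size).to(real_imgs)

        # 1. update the discriminator
        with torch.no_grad():
            fake_imgs = self.generator(noise)
        disc_outputs = dict(
            disc_pred_fake=self.discriminator(fake_imgs),
            disc_pred_real=self.discriminator(real_imgs))
        loss_disc, log_vars_disc = self._get_disc_loss(disc_outputs)
        if ddp_reducer is not None:
            # ``_find_tensors`` is a private PyTorch helper; the reducer is
            # prepared just before ``.backward()`` as the docstring describes.
            from torch.nn.parallel.distributed import _find_tensors
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        optimizer['discriminator'].zero_grad()
        loss_disc.backward()
        optimizer['discriminator'].step()

        # 2. update the generator
        fake_imgs = self.generator(noise)
        gen_outputs = dict(disc_pred_fake_g=self.discriminator(fake_imgs))
        loss_gen, log_vars_gen = self._get_gen_loss(gen_outputs)
        if ddp_reducer is not None:
            from torch.nn.parallel.distributed import _find_tensors
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))
        optimizer['generator'].zero_grad()
        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = OrderedDict()
        log_vars.update(log_vars_disc)
        log_vars.update(log_vars_gen)
        # back propagation already happened above, so the summed loss is
        # returned for logging only
        return dict(
            loss=(loss_disc + loss_gen).detach(),
            log_vars=log_vars,
            num_samples=batch_size)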