# Copyright (c) OpenMMLab. All rights reserved. from copy import deepcopy import torch import torch.nn as nn from torch.nn.parallel.distributed import _find_tensors from ..builder import MODELS, build_module from ..common import set_requires_grad from .base_gan import BaseGAN # _SUPPORT_METHODS_ = ['DCGAN', 'STYLEGANv2'] # @MODELS.register_module(_SUPPORT_METHODS_) @MODELS.register_module() class StaticUnconditionalGAN(BaseGAN): """Unconditional GANs with static architecture in training. This is the standard GAN model containing standard adversarial training schedule. To fulfill the requirements of various GAN algorithms, ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to customize auxiliary losses, e.g., gradient penalty loss, and discriminator shift loss. In addition, ``train_cfg`` and ``test_cfg`` aims at setuping training schedule. Args: generator (dict): Config for generator. discriminator (dict): Config for discriminator. gan_loss (dict): Config for generative adversarial loss. disc_auxiliary_loss (dict): Config for auxiliary loss to discriminator. gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss to generator. Defaults to None. train_cfg (dict | None, optional): Config for training schedule. Defaults to None. test_cfg (dict | None, optional): Config for testing schedule. Defaults to None. """ def __init__(self, generator, discriminator, gan_loss, disc_auxiliary_loss=None, gen_auxiliary_loss=None, train_cfg=None, test_cfg=None): super().__init__() self._gen_cfg = deepcopy(generator) self.generator = build_module(generator) # support no discriminator in testing if discriminator is not None: self.discriminator = build_module(discriminator) else: self.discriminator = None # support no gan_loss in testing if gan_loss is not None: self.gan_loss = build_module(gan_loss) else: self.gan_loss = None if disc_auxiliary_loss: self.disc_auxiliary_losses = build_module(disc_auxiliary_loss) if not isinstance(self.disc_auxiliary_losses, nn.ModuleList): self.disc_auxiliary_losses = nn.ModuleList( [self.disc_auxiliary_losses]) else: self.disc_auxiliary_loss = None if gen_auxiliary_loss: self.gen_auxiliary_losses = build_module(gen_auxiliary_loss) if not isinstance(self.gen_auxiliary_losses, nn.ModuleList): self.gen_auxiliary_losses = nn.ModuleList( [self.gen_auxiliary_losses]) else: self.gen_auxiliary_losses = None self.train_cfg = deepcopy(train_cfg) if train_cfg else None self.test_cfg = deepcopy(test_cfg) if test_cfg else None self._parse_train_cfg() if test_cfg is not None: self._parse_test_cfg() def _parse_train_cfg(self): """Parsing train config and set some attributes for training.""" if self.train_cfg is None: self.train_cfg = dict() # control the work flow in train step self.disc_steps = self.train_cfg.get('disc_steps', 1) # whether to use exponential moving average for training self.use_ema = self.train_cfg.get('use_ema', False) if self.use_ema: # use deepcopy to guarantee the consistency self.generator_ema = deepcopy(self.generator) self.real_img_key = self.train_cfg.get('real_img_key', 'real_img') def _parse_test_cfg(self): """Parsing test config and set some attributes for testing.""" if self.test_cfg is None: self.test_cfg = dict() # basic testing information self.batch_size = self.test_cfg.get('batch_size', 1) # whether to use exponential moving average for testing self.use_ema = self.test_cfg.get('use_ema', False) # TODO: finish ema part def train_step(self, data_batch, optimizer, ddp_reducer=None, loss_scaler=None, use_apex_amp=False, running_status=None): """Train step function. This function implements the standard training iteration for asynchronous adversarial training. Namely, in each iteration, we first update discriminator and then compute loss for generator with the newly updated discriminator. As for distributed training, we use the ``reducer`` from ddp to synchronize the necessary params in current computational graph. Args: data_batch (dict): Input data from dataloader. optimizer (dict): Dict contains optimizer for generator and discriminator. ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp. It is used to prepare for ``backward()`` in ddp. Defaults to None. loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional): The loss/gradient scaler used for auto mixed-precision training. Defaults to ``None``. use_apex_amp (bool, optional). Whether to use apex.amp. Defaults to ``False``. running_status (dict | None, optional): Contains necessary basic information for training, e.g., iteration number. Defaults to None. Returns: dict: Contains 'log_vars', 'num_samples', and 'results'. """ # get data from data_batch real_imgs = data_batch[self.real_img_key] # If you adopt ddp, this batch size is local batch size for each GPU. # If you adopt dp, this batch size is the global batch size as usual. batch_size = real_imgs.shape[0] # get running status if running_status is not None: curr_iter = running_status['iteration'] else: # dirty walkround for not providing running status if not hasattr(self, 'iteration'): self.iteration = 0 curr_iter = self.iteration # disc training set_requires_grad(self.discriminator, True) optimizer['discriminator'].zero_grad() # TODO: add noise sampler to customize noise sampling # pass model specific training kwargs g_training_kwargs = {} if hasattr(self.generator, 'get_training_kwargs'): g_training_kwargs.update( self.generator.get_training_kwargs(phase='disc')) with torch.no_grad(): fake_imgs = self.generator( None, num_batches=batch_size, **g_training_kwargs) # disc pred for fake imgs and real_imgs disc_pred_fake = self.discriminator(fake_imgs) disc_pred_real = self.discriminator(real_imgs) # get data dict to compute losses for disc data_dict_ = dict( gen=self.generator, disc=self.discriminator, disc_pred_fake=disc_pred_fake, disc_pred_real=disc_pred_real, fake_imgs=fake_imgs, real_imgs=real_imgs, iteration=curr_iter, batch_size=batch_size, loss_scaler=loss_scaler) loss_disc, log_vars_disc = self._get_disc_loss(data_dict_) # prepare for backward in ddp. If you do not call this function before # back propagation, the ddp will not dynamically find the used params # in current computation. if ddp_reducer is not None: ddp_reducer.prepare_for_backward(_find_tensors(loss_disc)) if loss_scaler: # add support for fp16 loss_scaler.scale(loss_disc).backward() elif use_apex_amp: from apex import amp with amp.scale_loss( loss_disc, optimizer['discriminator'], loss_id=0) as scaled_loss_disc: scaled_loss_disc.backward() else: loss_disc.backward() if loss_scaler: loss_scaler.unscale_(optimizer['discriminator']) # note that we do not contain clip_grad procedure loss_scaler.step(optimizer['discriminator']) # loss_scaler.update will be called in runner.train() else: optimizer['discriminator'].step() # skip generator training if only train discriminator for current # iteration if (curr_iter + 1) % self.disc_steps != 0: results = dict( fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu()) outputs = dict( log_vars=log_vars_disc, num_samples=batch_size, results=results) if hasattr(self, 'iteration'): self.iteration += 1 return outputs # generator training set_requires_grad(self.discriminator, False) optimizer['generator'].zero_grad() # TODO: add noise sampler to customize noise sampling # pass model specific training kwargs g_training_kwargs = {} if hasattr(self.generator, 'get_training_kwargs'): g_training_kwargs.update( self.generator.get_training_kwargs(phase='gen')) fake_imgs = self.generator( None, num_batches=batch_size, **g_training_kwargs) disc_pred_fake_g = self.discriminator(fake_imgs) data_dict_ = dict( gen=self.generator, disc=self.discriminator, fake_imgs=fake_imgs, disc_pred_fake_g=disc_pred_fake_g, iteration=curr_iter, batch_size=batch_size, loss_scaler=loss_scaler) loss_gen, log_vars_g = self._get_gen_loss(data_dict_) # prepare for backward in ddp. If you do not call this function before # back propagation, the ddp will not dynamically find the used params # in current computation. if ddp_reducer is not None: ddp_reducer.prepare_for_backward(_find_tensors(loss_gen)) if loss_scaler: loss_scaler.scale(loss_gen).backward() elif use_apex_amp: from apex import amp with amp.scale_loss( loss_gen, optimizer['generator'], loss_id=1) as scaled_loss_disc: scaled_loss_disc.backward() else: loss_gen.backward() if loss_scaler: loss_scaler.unscale_(optimizer['generator']) # note that we do not contain clip_grad procedure loss_scaler.step(optimizer['generator']) # loss_scaler.update will be called in runner.train() else: optimizer['generator'].step() # update ada p if hasattr(self.discriminator, 'with_ada') and self.discriminator.with_ada: self.discriminator.ada_aug.log_buffer[0] += batch_size self.discriminator.ada_aug.log_buffer[1] += disc_pred_real.sign( ).sum() self.discriminator.ada_aug.update( iteration=curr_iter, num_batches=batch_size) log_vars_disc['augment'] = ( self.discriminator.ada_aug.aug_pipeline.p.data.cpu()) log_vars = {} log_vars.update(log_vars_g) log_vars.update(log_vars_disc) results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu()) outputs = dict( log_vars=log_vars, num_samples=batch_size, results=results) if hasattr(self, 'iteration'): self.iteration += 1 return outputs