basic_conditional_gan.py 14.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors

from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN


@MODELS.register_module('BasiccGAN')
@MODELS.register_module()
class BasicConditionalGAN(BaseGAN):
    """Basic conditional GANs.

    This is the conditional GAN model containing standard adversarial training
    schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss, and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` are used to set up
    the training and testing schedules.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict | None, optional): Config for auxiliary
            loss to discriminator. Defaults to None.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule. Defaults
            to None.
        num_classes (int | None, optional): The number of conditional classes.
            Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 num_classes=None):
        """Build the sub-modules and parse the train/test configs.

        Args:
            generator (dict): Config for generator. ``num_classes`` is passed
                to ``build_module`` as a default argument.
            discriminator (dict | None): Config for discriminator. ``None``
                leaves ``self.discriminator`` as ``None`` (e.g. testing).
            gan_loss (dict | None): Config for the adversarial loss. ``None``
                leaves ``self.gan_loss`` as ``None`` (e.g. testing).
            disc_auxiliary_loss (dict | None, optional): Config for auxiliary
                loss to discriminator. Defaults to None.
            gen_auxiliary_loss (dict | None, optional): Config for auxiliary
                loss to generator. Defaults to None.
            train_cfg (dict | None, optional): Config for training schedule.
                Defaults to None.
            test_cfg (dict | None, optional): Config for testing schedule.
                Defaults to None.
            num_classes (int | None, optional): The number of conditional
                classes. Defaults to None.
        """
        super().__init__()
        self.num_classes = num_classes
        # snapshot of the generator config, taken before it is handed to
        # ``build_module``
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(
            generator, default_args=dict(num_classes=num_classes))

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(
                discriminator, default_args=dict(num_classes=num_classes))
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        # auxiliary losses are always stored as a ``nn.ModuleList`` (or None)
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # same (plural) attribute name as the branch above, so the
            # attribute is defined whether or not an auxiliary loss is given
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        # empty/None configs are normalised to None here; the parse methods
        # replace None with an empty dict
        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        # the train config is always parsed; the test config only when the
        # caller supplied one
        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        self.gen_steps = self.train_cfg.get('gen_steps', 1)

        # add support for accumulating gradients within multiple steps. This
        # feature aims to simulate large `batch_sizes` (but may have some
        # detailed differences in BN). Note that `self.disc_steps` should be
        # set according to the batch accumulation strategy.
        # In addition, in the detailed implementation, there is a difference
        # between the batch accumulation in the generator and discriminator.
        self.batch_accumulation_steps = self.train_cfg.get(
            'batch_accumulation_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update discriminator and then compute loss for generator with the newly
        updated discriminator.

        The discriminator loss is back-propagated on every call, but its
        optimizer only steps when ``(iteration + 1)`` is a multiple of
        ``self.batch_accumulation_steps``. The generator is only trained when
        ``(iteration + 1)`` is a multiple of ``self.disc_steps``; it then runs
        ``self.gen_steps`` updates, each accumulating gradients over
        ``self.batch_accumulation_steps`` freshly sampled fake batches.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in current computational graph.

        Args:
            data_batch (dict): Input data from dataloader. The keys ``'img'``
                and ``'gt_label'`` are read.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator, under the keys ``'generator'`` and
                ``'discriminator'``.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use apex.amp. Only
                consulted when ``loss_scaler`` is falsy. Defaults to
                ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number (key
                ``'iteration'``). If None, an internal ``self.iteration``
                counter is used instead. Defaults to None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['img']
        # get the ground-truth label, torch.Tensor (N, )
        gt_label = data_batch['gt_label']
        # If you adopt ddp, this batch size is local batch size for each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status: keep our own
            # iteration counter on the model
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)

        # do not `zero_grad` during batch accumulation
        if curr_iter % self.batch_accumulation_steps == 0:
            optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        # the generator is run without gradient tracking here, because only
        # the discriminator is updated in this phase
        with torch.no_grad():
            fake_data = self.generator(
                None, num_batches=batch_size, label=None, return_noise=True)
            # fake_label should be in the same data type with the gt_label
            fake_imgs, fake_label = fake_data['fake_img'], fake_data['label']

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs, label=fake_label)
        disc_pred_real = self.discriminator(real_imgs, label=gt_label)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gt_label=gt_label,
            fake_label=fake_label,
            loss_scaler=loss_scaler)

        # NOTE(review): `_get_disc_loss` is not defined in this class;
        # presumably inherited from `BaseGAN` -- confirm.
        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
        # average the loss over the accumulation window
        loss_disc = loss_disc / float(self.batch_accumulation_steps)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        # only step the discriminator at the end of an accumulation window
        if (curr_iter + 1) % self.batch_accumulation_steps == 0:
            if loss_scaler:
                loss_scaler.unscale_(optimizer['discriminator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['discriminator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            # the internal counter only exists if it was created on this model
            # (see the `running_status is None` branch above)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        # allow for training the generator with multiple steps
        for _ in range(self.gen_steps):
            optimizer['generator'].zero_grad()
            # unlike the discriminator, the generator accumulates gradients
            # over several fake batches within this single call
            for _ in range(self.batch_accumulation_steps):
                # TODO: add noise sampler to customize noise sampling
                fake_data = self.generator(
                    None, num_batches=batch_size, return_noise=True)
                # fake_label should be in the same data type with the gt_label
                fake_imgs, fake_label = fake_data['fake_img'], fake_data[
                    'label']
                disc_pred_fake_g = self.discriminator(
                    fake_imgs, label=fake_label)

                data_dict_ = dict(
                    gen=self.generator,
                    disc=self.discriminator,
                    fake_imgs=fake_imgs,
                    disc_pred_fake_g=disc_pred_fake_g,
                    iteration=curr_iter,
                    batch_size=batch_size,
                    fake_label=fake_label,
                    loss_scaler=loss_scaler)

                # NOTE(review): `_get_gen_loss` is not defined in this class;
                # presumably inherited from `BaseGAN` -- confirm.
                loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
                loss_gen = loss_gen / float(self.batch_accumulation_steps)

                # prepare for backward in ddp. If you do not call this function
                # before back propagation, the ddp will not dynamically find
                # the used params in current computation.
                if ddp_reducer is not None:
                    ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

                if loss_scaler:
                    loss_scaler.scale(loss_gen).backward()
                elif use_apex_amp:
                    from apex import amp
                    # despite its name, `scaled_loss_disc` holds the scaled
                    # generator loss here
                    with amp.scale_loss(
                            loss_gen, optimizer['generator'],
                            loss_id=1) as scaled_loss_disc:
                        scaled_loss_disc.backward()
                else:
                    loss_gen.backward()

            if loss_scaler:
                loss_scaler.unscale_(optimizer['generator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['generator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['generator'].step()

        # NOTE(review): `log_vars_g` is only bound inside the loops above, so
        # `gen_steps` < 1 would raise a NameError here -- confirm that the
        # config never allows it.
        # discriminator entries win if both dicts share a key
        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        # NOTE(review): `fake_imgs` here comes from the last generator forward
        # pass, which ran with gradients enabled, and is not detached before
        # `.cpu()` -- confirm this is intended.
        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          label=None,
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sampel_model (str, optional): Use which model to sample fake
                images. Defaults to `'ema/orig'`.
            label (torch.Tensor | None , optional): The conditional label.
                Defaults to None.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(noise, num_batches=num_batches, label=label, **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']
            label = outputs['label']

        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise, num_batches=num_batches, label=label, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs