cyclegan.py 7.73 KB
Newer Older
dongchy920's avatar
dongchy920 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel.distributed import _find_tensors

from mmgen.models.builder import MODELS
from ..common import GANImageBuffer, set_requires_grad
from .static_translation_gan import StaticTranslationGAN


@MODELS.register_module()
class CycleGAN(StaticTranslationGAN):
    """CycleGAN model for unpaired image-to-image translation.

    Ref:
    Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
    Networks

    One ``GANImageBuffer`` is kept per reachable domain. The discriminator
    loss queries that buffer with the freshly generated fake images instead
    of feeding them to the discriminator directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Size of each GAN image buffer: taken from
        # ``train_cfg['buffer_size']`` when a train config is given,
        # otherwise 50.
        self.buffer_size = (50 if self.train_cfg is None else
                            self.train_cfg.get('buffer_size', 50))
        # One image buffer per reachable domain.
        self.image_buffers = {
            domain: GANImageBuffer(self.buffer_size)
            for domain in self._reachable_domains
        }

        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments, passed on to
                ``self.translation``.

        Returns:
            dict: Forward results with keys ``source`` (the input image) and
                ``target`` (the translated image), both moved to CPU.
        """
        # This is a trick for CycleGAN: the model is deliberately switched
        # to training mode before translating, and is left in that mode.
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        return dict(source=img.cpu(), target=target.cpu())

    def _get_disc_loss(self, outputs):
        """Compute the loss of the discriminators.

        Args:
            outputs (dict): Dict of forward results. For every reachable
                domain it must contain ``fake_{domain}`` and
                ``real_{domain}``.

        Returns:
            tuple: ``(loss_d, log_vars_d)`` where ``loss_d`` is the summed
                discriminator loss over all reachable domains and
                ``log_vars_d`` maps ``loss_gan_d_{domain}`` to the halved
                per-domain loss value for logging.
        """
        discriminators = self.get_module(self.discriminators)

        log_vars_d = dict()
        loss_d = 0

        # GAN loss for the discriminator of every reachable domain
        for domain in self._reachable_domains:
            losses = dict()
            # The fake batch goes through the image buffer first
            # (presumably mixing in previously generated images -- confirm
            # against ``GANImageBuffer.query``) and is detached so that no
            # gradient reaches the generators.
            fake_img = self.image_buffers[domain].query(
                outputs[f'fake_{domain}'])
            fake_pred = discriminators[domain](fake_img.detach())
            losses[f'loss_gan_d_{domain}_fake'] = self.gan_loss(
                fake_pred, target_is_real=False, is_disc=True)
            real_pred = discriminators[domain](outputs[f'real_{domain}'])
            losses[f'loss_gan_d_{domain}_real'] = self.gan_loss(
                real_pred, target_is_real=True, is_disc=True)

            _loss_d, _log_vars_d = self._parse_losses(losses)
            # Halve the sum of the fake and real terms; the logged value is
            # scaled the same way.
            _loss_d *= 0.5
            loss_d += _loss_d
            log_vars_d[f'loss_gan_d_{domain}'] = _log_vars_d['loss'] * 0.5

        return loss_d, log_vars_d

    def _get_gen_loss(self, outputs):
        """Compute the loss of the generators.

        Note that ``outputs`` is modified in place: an
        ``identity_{domain}`` entry is added for every reachable domain
        before the auxiliary losses are evaluated.

        Args:
            outputs (dict): Dict of forward results. For every reachable
                domain it must contain ``fake_{domain}`` and
                ``real_{domain}``.

        Returns:
            tuple: ``(loss_g, log_vars_g)`` as produced by
                ``self._parse_losses`` from the adversarial and auxiliary
                generator losses.
        """
        generators = self.get_module(self.generators)
        discriminators = self.get_module(self.discriminators)

        losses = dict()
        for domain in self._reachable_domains:
            # Identity reconstruction for generators: feed the real image of
            # a domain to the generator that targets this very domain.
            outputs[f'identity_{domain}'] = generators[domain](
                outputs[f'real_{domain}'])
            # GAN loss for generators
            fake_pred = discriminators[domain](outputs[f'fake_{domain}'])
            losses[f'loss_gan_g_{domain}'] = self.gan_loss(
                fake_pred, target_is_real=True, is_disc=False)

        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue
                # the `loss_name()` function return name as 'loss_xxx'
                loss_name = loss_module.loss_name()
                # Modules that share a name are accumulated into one entry.
                if loss_name in losses:
                    losses[loss_name] = losses[loss_name] + loss_
                else:
                    losses[loss_name] = loss_

        loss_g, log_vars_g = self._parse_losses(losses)

        return loss_g, log_vars_g

    def _get_opposite_domain(self, domain):
        """Return the first reachable domain that differs from ``domain``.

        Args:
            domain (str): Domain name to exclude.

        Returns:
            str | None: The first item of ``self._reachable_domains`` that is
                not equal to ``domain``, or None if there is no such item.
        """
        return next(
            (item for item in self._reachable_domains if item != domain),
            None)

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch. It is read
                through the keys ``img_{domain}``.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generators and discriminators, stored under the keys
                ``'generators'`` and ``'discriminators'``.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict with ``log_vars`` (information for the logger),
                ``num_samples`` (the number of samples) and ``results``
                (real and fake images on CPU for visualization).
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status: keep an
            # iteration counter on the model itself
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generators
        outputs = dict()
        for target_domain in self._reachable_domains:
            # fetch data by domain
            source_domain = self.get_other_domains(target_domain)[0]
            img = data_batch[f'img_{source_domain}']
            # translation process: source -> target
            results = self(img, test_mode=False, target_domain=target_domain)
            outputs[f'real_{source_domain}'] = results['source']
            outputs[f'fake_{target_domain}'] = results['target']
            # cycle process: translate the fake image back to the source
            results = self(
                results['target'],
                test_mode=False,
                target_domain=source_domain)
            outputs[f'cycle_{source_domain}'] = results['target']

        log_vars = dict()

        # discriminators
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generators, no updates to discriminator parameters. The generator
        # step only runs every `disc_steps` iterations and not before
        # `disc_init_steps` iterations have passed.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        # The self-maintained counter is advanced whenever it exists.
        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        # `num_samples` ends up as the batch length of the real images of
        # the last reachable domain.
        num_samples = 0
        for domain in self._reachable_domains:
            real_img = outputs[f'real_{domain}']
            image_results[f'real_{domain}'] = real_img.cpu()
            image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()
            num_samples = len(real_img)

        return dict(
            log_vars=log_vars,
            num_samples=num_samples,
            results=image_results)