static_translation_gan.py 5.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel

from ..builder import MODELS, build_module
from ..gans import BaseGAN
from .base_translation_model import BaseTranslationModel


@MODELS.register_module()
class StaticTranslationGAN(BaseTranslationModel, BaseGAN):
    """Basic translation model based on static unconditional GAN.

    One generator (and, when configured, one discriminator) is built per
    reachable domain and stored in an ``nn.ModuleDict`` keyed by the domain
    name.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict | None): Config for the discriminator. If None,
            no discriminator is built (e.g. for testing).
        gan_loss (dict | None): Config for the gan loss. If None, no gan
            loss is built (e.g. for testing).
        *args: Positional arguments forwarded to
            ``BaseTranslationModel.__init__``.
        pretrained (str | optional): Path for pretrained model.
            Defaults to None.
        disc_auxiliary_loss (dict | optional): Config for auxiliary loss to
            discriminator. Defaults to None.
        gen_auxiliary_loss (dict | optional): Config for auxiliary loss
            to generator. Defaults to None.
        **kwargs: Keyword arguments forwarded to
            ``BaseTranslationModel.__init__``.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 *args,
                 pretrained=None,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 **kwargs):
        BaseGAN.__init__(self)
        BaseTranslationModel.__init__(self, *args, **kwargs)
        # Building generators and discriminators
        self._gen_cfg = deepcopy(generator)
        # build domain generators: one independent network per domain
        self.generators = nn.ModuleDict()
        for domain in self._reachable_domains:
            self.generators[domain] = build_module(generator)

        self._disc_cfg = deepcopy(discriminator)
        # build domain discriminators
        if discriminator is not None:
            self.discriminators = nn.ModuleDict()
            for domain in self._reachable_domains:
                self.discriminators[domain] = build_module(discriminator)
        # support no discriminator in testing
        else:
            self.discriminators = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        # Auxiliary losses are always stored as an ``nn.ModuleList`` (or
        # None), so a single built loss gets wrapped into a list.
        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            # Use the same (plural) attribute name as the branch above so the
            # attribute is always defined.
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Discriminators are skipped when the model was built without them.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
        """
        for domain in self._reachable_domains:
            self.generators[domain].init_weights(pretrained=pretrained)
            # ``discriminators`` is None when no discriminator config was
            # given (testing), so there is nothing to initialize.
            if self.discriminators is not None:
                self.discriminators[domain].init_weights(
                    pretrained=pretrained)

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # ``train_cfg`` is guaranteed to be a dict at this point.
        self.disc_init_steps = self.train_cfg.get('disc_init_steps', 0)

        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

    def get_module(self, module):
        """Get `nn.ModuleDict` to fit the `MMDistributedDataParallel`
        interface.

        Args:
            module (MMDistributedDataParallel | nn.ModuleDict): The input
                module that needs processing.

        Returns:
            nn.ModuleDict: The ModuleDict of multiple networks. If ``module``
                is wrapped by ``MMDistributedDataParallel``, the wrapped
                ``module.module`` is returned; otherwise ``module`` itself.
        """
        if isinstance(module, MMDistributedDataParallel):
            return module.module

        return module

    def _get_target_generator(self, domain):
        """Get the generator of the target domain.

        Args:
            domain (str): Name of the target domain.

        Returns:
            nn.Module: The generator registered for ``domain``.
        """
        assert self.is_domain_reachable(domain), (
            f'{domain} domain is not reachable, available domain list is '
            f'{self._reachable_domains}')

        return self.get_module(self.generators)[domain]

    def _get_target_discriminator(self, domain):
        """Get the discriminator of the target domain.

        Args:
            domain (str): Name of the target domain.

        Returns:
            nn.Module: The discriminator registered for ``domain``.
        """
        assert self.is_domain_reachable(domain), (
            f'{domain} domain is not reachable, available domain list is '
            f'{self._reachable_domains}')

        return self.get_module(self.discriminators)[domain]