generator_discriminator.py 9.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_state_dict
from mmcv.utils import print_log

from mmgen.models.builder import MODULES
from mmgen.utils import get_root_logger
from .modules import DiscriminatorBlock, GeneratorBlock


@MODULES.register_module()
class SinGANMultiScaleGenerator(nn.Module):
    """Multi-Scale Generator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Notes:

    - In this version, we adopt the interpolation function from the official
      PyTorch APIs, which is different from the original implementation by the
      authors. However, in our experiments, this influence can be ignored.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 **kwargs):
        super().__init__()

        # Total spatial padding needed at the head so that stacking
        # ``num_layers`` unpadded convs keeps the output aligned with the
        # input resolution.
        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            base_ch = self._stage_channels(base_channels, scale)
            min_feat_ch = self._stage_channels(min_feat_channels, scale)

            self.blocks.append(
                GeneratorBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    out_act_cfg=out_act_cfg,
                    **kwargs))

        self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
        self.img_padding_layer = nn.ZeroPad2d(self.pad_head)

    @staticmethod
    def _stage_channels(base, scale):
        """Compute the channel number for one scale/stage.

        Channels double every 4 scales and are capped at 128, matching the
        official SinGAN implementation.

        Args:
            base (int): Base channel number at scale 0.
            scale (int): Index of the current scale (non-negative).

        Returns:
            int: Channel number for this stage.
        """
        return min(base * 2**(scale // 4), 128)

    def forward(self,
                input_sample,
                fixed_noises,
                noise_weights,
                rand_mode,
                curr_scale,
                num_batches=1,
                get_prev_res=False,
                return_noise=False):
        """Forward function.

        Args:
            input_sample (Tensor | None): The input for generator. In the
                original implementation, a tensor filled with zeros is adopted.
                If None is given, we will construct it from the first fixed
                noises.
            fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
            noise_weights (list[float]): List of the weights for random noises.
            rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
                it will sample from random noises. Otherwise, the
                reconstruction for the single image will be returned.
            curr_scale (int): The scale for the current inference or training.
            num_batches (int, optional): The number of batches. Defaults to 1.
            get_prev_res (bool, optional): Whether to return results from
                previous stages. Defaults to False.
            return_noise (bool, optional): Whether to return noises tensor.
                Defaults to False.

        Returns:
            Tensor | dict: Generated image tensor or dictionary containing \
                more data.
        """
        if get_prev_res or return_noise:
            prev_res_list = []
            noise_list = []

        if input_sample is None:
            # Zero image at the coarsest resolution, on the same
            # device/dtype as the fixed noises.
            input_sample = torch.zeros(
                (num_batches, 3, fixed_noises[0].shape[-2],
                 fixed_noises[0].shape[-1])).to(fixed_noises[0])

        g_res = input_sample

        for stage in range(curr_scale + 1):
            if rand_mode == 'recon':
                # Reconstruction mode: reuse the fixed noise for this stage.
                noise_ = fixed_noises[stage]
            else:
                noise_ = torch.randn(num_batches,
                                     *fixed_noises[stage].shape[1:]).to(g_res)
            if return_noise:
                noise_list.append(noise_)

            # add padding at head
            pad_ = (self.pad_head, ) * 4
            noise_ = F.pad(noise_, pad_)
            g_res_pad = F.pad(g_res, pad_)
            noise = noise_ * noise_weights[stage] + g_res_pad

            # Detach the noise branch so gradients only flow through the
            # residual image path.
            g_res = self.blocks[stage](noise.detach(), g_res)

            if get_prev_res and stage != curr_scale:
                prev_res_list.append(g_res)

            # upsample, here we use interpolation from PyTorch
            if stage != curr_scale:
                h_next, w_next = fixed_noises[stage + 1].shape[-2:]
                g_res = self.upsample(g_res, (h_next, w_next))

        if get_prev_res or return_noise:
            output_dict = dict(
                fake_img=g_res,
                prev_res_list=prev_res_list,
                noise_batch=noise_list)
            return output_dict

        return g_res

    def check_and_load_prev_weight(self, curr_scale):
        """Load weights from the previous scale when shapes match.

        SinGAN initializes each new stage from the previous one if their
        channel configurations are identical; otherwise training starts
        from scratch for this stage.

        Args:
            curr_scale (int): Index of the current scale; 0 is a no-op.
        """
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels

        prev_in_ch = self.blocks[curr_scale - 1].in_channels
        curr_in_ch = self.blocks[curr_scale].in_channels
        if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
            load_state_dict(
                self.blocks[curr_scale],
                self.blocks[curr_scale - 1].state_dict(),
                logger=get_root_logger())
            print_log('Successfully load pretrained model from last scale.')
        else:
            print_log(
                'Cannot load pretrained model from last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
                f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')


@MODULES.register_module()
class SinGANMultiScaleDiscriminator(nn.Module):
    """Multi-Scale Discriminator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Args:
        in_channels (int): Input channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
    """

    def __init__(self,
                 in_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList()
        for scale in range(num_scales + 1):
            base_ch = self._stage_channels(base_channels, scale)
            min_feat_ch = self._stage_channels(min_feat_channels, scale)
            self.blocks.append(
                DiscriminatorBlock(
                    in_channels=in_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    **kwargs))

    @staticmethod
    def _stage_channels(base, scale):
        """Compute the channel number for one scale/stage.

        Channels double every 4 scales and are capped at 128, matching the
        official SinGAN implementation.

        Args:
            base (int): Base channel number at scale 0.
            scale (int): Index of the current scale (non-negative).

        Returns:
            int: Channel number for this stage.
        """
        return min(base * 2**(scale // 4), 128)

    def forward(self, x, curr_scale):
        """Forward function.

        Args:
            x (Tensor): Input feature map.
            curr_scale (int): Current scale for discriminator. If in testing,
                you need to set it to the last scale.

        Returns:
            Tensor: Discriminative results.
        """
        out = self.blocks[curr_scale](x)
        return out

    def check_and_load_prev_weight(self, curr_scale):
        """Load weights from the previous scale when shapes match.

        The discriminator for a new stage is initialized from the previous
        stage if their channel configurations are identical; otherwise it
        starts from scratch.

        Args:
            curr_scale (int): Index of the current scale; 0 is a no-op.
        """
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels
        if prev_ch == curr_ch:
            self.blocks[curr_scale].load_state_dict(
                self.blocks[curr_scale - 1].state_dict())
            print_log('Successfully load pretrained model from last scale.')
        else:
            print_log('Cannot load pretrained model from last scale since'
                      f' prev_ch({prev_ch}) != curr_ch({curr_ch})')