# Copyright (c) OpenMMLab. All rights reserved.
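# Tests for the PGGAN architecture modules: equalized learning-rate layers,
# PixelNorm, MiniBatchStddevLayer, noise-to-feature mapping, and the PGGAN
# generator and discriminator.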
from copy import deepcopy

import pytest
import torch
import torch.nn as nn

from mmgen.models import (EqualizedLR, EqualizedLRConvDownModule,
                          EqualizedLRConvModule, EqualizedLRConvUpModule,
                          EqualizedLRLinearModule, MiniBatchStddevLayer,
                          PGGANNoiseTo2DFeat, PixelNorm, equalized_lr)
from mmgen.models.architectures.pggan import (PGGANDiscriminator,
                                              PGGANGenerator)


class TestEqualizedLR:

    @classmethod
    def setup_class(cls):
        cls.default_conv_cfg = dict(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN'))
        cls.conv_input = torch.randn((2, 1, 5, 5))
        cls.linear_input = torch.randn((2, 2))

    def test_equalized_conv_module(self):
        conv = EqualizedLRConvModule(**self.default_conv_cfg)
        res = conv(self.conv_input)
        assert res.shape == (2, 1, 5, 5)
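        # equalized_lr is registered as a forward pre-hook, so check the
        # conv layer's hooks for an EqualizedLR instance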
        has_equalized_lr = False
        for _, v in conv.conv._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert has_equalized_lr

        conv = EqualizedLRConvModule(
            equalized_lr_cfg=None, **self.default_conv_cfg)
        res = conv(self.conv_input)
        assert res.shape == (2, 1, 5, 5)
        has_equalized_lr = False
        for _, v in conv.conv._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert not has_equalized_lr

        conv = EqualizedLRConvModule(
            equalized_lr_cfg=dict(gain=1), **self.default_conv_cfg)
        res = conv(self.conv_input)
        assert res.shape == (2, 1, 5, 5)
        has_equalized_lr = False
        for _, v in conv.conv._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                assert v.gain == 1
                has_equalized_lr = True
        assert has_equalized_lr

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_equalized_conv_module_cuda(self):
        conv = EqualizedLRConvModule(**self.default_conv_cfg).cuda()
        res = conv(self.conv_input.cuda())
        assert res.shape == (2, 1, 5, 5)
        has_equalized_lr = False
        for _, v in conv.conv._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert has_equalized_lr

    def test_equalized_linear_module(self):
        linear = EqualizedLRLinearModule(2, 2)
        res = linear(self.linear_input)
        assert res.shape == (2, 2)
        has_equalized_lr = False
        for _, v in linear._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert has_equalized_lr

        linear = EqualizedLRLinearModule(2, 2, equalized_lr_cfg=None)
        res = linear(self.linear_input)
        assert res.shape == (2, 2)
        has_equalized_lr = False
        for _, v in linear._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert not has_equalized_lr

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_equalized_linear_module_cuda(self):
        linear = EqualizedLRLinearModule(2, 2).cuda()
        res = linear(self.linear_input.cuda())
        assert res.shape == (2, 2)
        has_equalized_lr = False
        for _, v in linear._forward_pre_hooks.items():
            if isinstance(v, EqualizedLR):
                has_equalized_lr = True
        assert has_equalized_lr

    def test_equalized_lr(self):
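        # applying equalized_lr twice to the same module should raise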
        with pytest.raises(RuntimeError):
            conv = nn.Conv2d(1, 1, 3, 1, 1)
            conv = equalized_lr(conv)
            conv = equalized_lr(conv)


class TestEqualizedLRConvUpModule:

    @classmethod
    def setup_class(cls):
        cls.default_cfg = dict(
            in_channels=3,
            out_channels=1,
            kernel_size=3,
            padding=1,
            stride=2,
            conv_cfg=dict(type='deconv'),
            upsample=dict(type='fused_nn'),
            norm_cfg=dict(type='PixelNorm'))
        cls.default_input = torch.randn((2, 3, 5, 5))

    def test_equalized_lr_convup_module(self):
        convup = EqualizedLRConvUpModule(**self.default_cfg)

        res = convup(self.default_input)
        assert res.shape == (2, 1, 10, 10)
        # test bp
        res = convup(torch.randn((2, 3, 5, 5), requires_grad=True))
        assert res.shape == (2, 1, 10, 10)
        res.mean().backward()

        # test nearest
        cfg_ = deepcopy(self.default_cfg)
        cfg_['upsample'] = dict(type='nearest', scale_factor=2)
        cfg_['kernel_size'] = 4
        convup = EqualizedLRConvUpModule(**cfg_)

        res = convup(self.default_input)
        assert res.shape == (2, 1, 20, 20)

        # test upsample is None
        cfg_ = deepcopy(self.default_cfg)
        cfg_['upsample'] = None
        cfg_['kernel_size'] = 4
        convup = EqualizedLRConvUpModule(**cfg_)

        res = convup(self.default_input)
        assert res.shape == (2, 1, 10, 10)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_equalized_lr_convup_module_cuda(self):
        convup = EqualizedLRConvUpModule(**self.default_cfg).cuda()

        res = convup(self.default_input.cuda())
        assert res.shape == (2, 1, 10, 10)
        # test bp
        res = convup(torch.randn((2, 3, 5, 5), requires_grad=True).cuda())
        assert res.shape == (2, 1, 10, 10)
        res.mean().backward()


class TestEqualizedLRConvDownModule:

    @classmethod
    def setup_class(cls):
        cls.default_cfg = dict(
            in_channels=3,
            out_channels=1,
            kernel_size=3,
            padding=1,
            stride=2,
            downsample=dict(type='fused_pool'))
        cls.default_input = torch.randn((2, 3, 8, 8))

    def test_equalized_lr_conv_down(self):
        convdown = EqualizedLRConvDownModule(**self.default_cfg)
        res = convdown(self.default_input)
        assert res.shape == (2, 1, 4, 4)
        # test bp
        res = convdown(torch.randn((2, 3, 8, 8), requires_grad=True))
        assert res.shape == (2, 1, 4, 4)
        res.mean().backward()

        # test avg pool
        cfg_ = deepcopy(self.default_cfg)
        cfg_['downsample'] = dict(type='avgpool', kernel_size=2, stride=2)
        convdown = EqualizedLRConvDownModule(**cfg_)
        res = convdown(self.default_input)
        assert res.shape == (2, 1, 2, 2)

        # test downsample is None
        cfg_ = deepcopy(self.default_cfg)
        cfg_['downsample'] = None
        convdown = EqualizedLRConvDownModule(**cfg_)
        res = convdown(self.default_input)
        assert res.shape == (2, 1, 4, 4)

        with pytest.raises(NotImplementedError):
            cfg_ = deepcopy(self.default_cfg)
            cfg_['downsample'] = dict(type='xxx', kernel_size=2, stride=2)
            _ = EqualizedLRConvDownModule(**cfg_)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_equalized_lr_conv_down_cuda(self):
        convdown = EqualizedLRConvDownModule(**self.default_cfg).cuda()
        res = convdown(self.default_input.cuda())
        assert res.shape == (2, 1, 4, 4)
        # test bp
        res = convdown(torch.randn((2, 3, 8, 8), requires_grad=True).cuda())
        assert res.shape == (2, 1, 4, 4)
        res.mean().backward()


class TestPixelNorm:

    @classmethod
    def setup_class(cls):
        cls.input_tensor = torch.randn((2, 3, 4, 4))

    def test_pixel_norm(self):
        pn = PixelNorm()
        res = pn(self.input_tensor)
        assert res.shape == (2, 3, 4, 4)

        # test zero input case (output should stay zero, without NaNs)
        res = pn(self.input_tensor * 0)
        assert res.shape == (2, 3, 4, 4)
        assert (res == 0).all()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_pixel_norm_cuda(self):
        pn = PixelNorm().cuda()
        res = pn(self.input_tensor.cuda())
        assert res.shape == (2, 3, 4, 4)

        # test zero input case (output should stay zero, without NaNs)
        res = pn(self.input_tensor.cuda() * 0)
        assert res.shape == (2, 3, 4, 4)
        assert (res == 0).all()


class TestMiniBatchStddevLayer:

    @classmethod
    def setup_class(cls):
        cls.default_input = torch.randn((2, 3, 4, 4))

    def test_minibatch_stddev_layer(self):
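        # the layer appends a stddev statistic as one extra feature map,
        # so the channel dimension grows by one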
        ministd_layer = MiniBatchStddevLayer()
        res = ministd_layer(self.default_input)
        assert res.shape == (2, 4, 4, 4)

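        # a batch size incompatible with the default group size should raise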
        with pytest.raises(AssertionError):
            _ = ministd_layer(torch.randn((5, 4, 3, 3)))

        ministd_layer = MiniBatchStddevLayer(group_size=3)
        res = ministd_layer(torch.randn((2, 6, 4, 4)))
        assert res.shape == (2, 7, 4, 4)

        # test bp
        ministd_layer = MiniBatchStddevLayer()
        res = ministd_layer(self.default_input.requires_grad_())
        assert res.shape == (2, 4, 4, 4)
        res.mean().backward()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_minibatch_stddev_layer_cuda(self):
        ministd_layer = MiniBatchStddevLayer().cuda()
        res = ministd_layer(self.default_input.cuda())
        assert res.shape == (2, 4, 4, 4)

        ministd_layer = MiniBatchStddevLayer(group_size=3).cuda()
        res = ministd_layer(torch.randn((2, 6, 4, 4)).cuda())
        assert res.shape == (2, 7, 4, 4)

        # test bp
        ministd_layer = MiniBatchStddevLayer().cuda()
        res = ministd_layer(self.default_input.requires_grad_().cuda())
        assert res.shape == (2, 4, 4, 4)
        res.mean().backward()


class TestPGGANNoiseTo2DFeat:

    @classmethod
    def setup_class(cls):
        cls.default_input = torch.randn((2, 10))
        cls.default_cfg = dict(noise_size=10, out_channels=1)

    def test_pggan_noise2feat(self):
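        # PGGANNoiseTo2DFeat maps a (N, noise_size) vector to an
        # (N, out_channels, 4, 4) feature map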
        module = PGGANNoiseTo2DFeat(**self.default_cfg)
        res = module(self.default_input)
        assert res.shape == (2, 1, 4, 4)
        assert isinstance(module.linear, EqualizedLRLinearModule)
        assert not module.linear.bias
        assert module.with_norm
        assert isinstance(module.norm, PixelNorm)
        assert isinstance(module.activation, nn.LeakyReLU)

        module = PGGANNoiseTo2DFeat(**self.default_cfg, act_cfg=None)
        res = module(self.default_input)
        assert res.shape == (2, 1, 4, 4)
        assert isinstance(module.linear, EqualizedLRLinearModule)
        assert not module.linear.bias
        assert module.with_norm
        assert not module.with_activation

        module = PGGANNoiseTo2DFeat(
            **self.default_cfg, norm_cfg=None, normalize_latent=False)
        res = module(self.default_input)
        assert res.shape == (2, 1, 4, 4)
        assert isinstance(module.linear, EqualizedLRLinearModule)
        assert not module.linear.bias
        assert not module.with_norm
        assert isinstance(module.activation, nn.LeakyReLU)

        with pytest.raises(AssertionError):
            _ = module(torch.randn((2, 1, 2)))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_pggan_noise2feat_cuda(self):
        module = PGGANNoiseTo2DFeat(**self.default_cfg).cuda()
        res = module(self.default_input.cuda())
        assert res.shape == (2, 1, 4, 4)
        assert isinstance(module.linear, EqualizedLRLinearModule)
        assert not module.linear.bias
        assert module.with_norm
        assert isinstance(module.activation, nn.LeakyReLU)


class TestPGGANGenerator:

    @classmethod
    def setup_class(cls):
        cls.default_noise = torch.randn((2, 8))
        cls.default_cfg = dict(
            noise_size=8, out_scale=16, base_channels=32, max_channels=32)

    def test_pggan_generator(self):
        # test with default cfg
        gen = PGGANGenerator(**self.default_cfg)
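        # transition_weight blends the outputs of adjacent resolution stages
        # during progressive growing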
        res = gen(None, num_batches=2, transition_weight=0.1)
        assert res.shape == (2, 3, 16, 16)

        res = gen(self.default_noise, transition_weight=0.2)
        assert res.shape == (2, 3, 16, 16)
        with pytest.raises(AssertionError):
            _ = gen(self.default_noise[:, :, None], transition_weight=0.2)

        with pytest.raises(AssertionError):
            _ = gen(torch.randn((2, 1)), transition_weight=0.2)

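        # noise can also be passed as a callable that samples noise tensors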
        res = gen(torch.randn, num_batches=2, transition_weight=0.2)
        assert res.shape == (2, 3, 16, 16)

        # test with input scale
        res = gen(None, num_batches=2, curr_scale=4)
        assert res.shape == (2, 3, 4, 4)
        res = gen(None, num_batches=2, curr_scale=8)
        assert res.shape == (2, 3, 8, 8)

        # test return noise
        res = gen(None, num_batches=2, curr_scale=8, return_noise=True)
        assert res['fake_img'].shape == (2, 3, 8, 8)
        assert res['label'] is None
        assert isinstance(res['noise_batch'], torch.Tensor)

        # test different out_scale settings
        cfg = deepcopy(self.default_cfg)
        cfg['out_scale'] = 32
        gen = PGGANGenerator(**cfg)
        res = gen(None, num_batches=2, transition_weight=0.1)
        assert res.shape == (2, 3, 32, 32)

        cfg = deepcopy(self.default_cfg)
        cfg['out_scale'] = 4
        gen = PGGANGenerator(**cfg)
        res = gen(None, num_batches=2, transition_weight=0.1)
        assert res.shape == (2, 3, 4, 4)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_pggan_generator_cuda(self):
        # test with default cfg
        gen = PGGANGenerator(**self.default_cfg).cuda()
        res = gen(None, num_batches=2, transition_weight=0.1)
        assert res.shape == (2, 3, 16, 16)

        # test different out_scale settings
        cfg = deepcopy(self.default_cfg)
        cfg['out_scale'] = 32
        gen = PGGANGenerator(**cfg).cuda()
        res = gen(None, num_batches=2, transition_weight=0.1)
        assert res.shape == (2, 3, 32, 32)


class TestPGGANDiscriminator:

    @classmethod
    def setup_class(cls):
        cls.default_cfg = dict(in_scale=16, label_size=2)
        cls.default_inputx16 = torch.randn((2, 3, 16, 16))
        cls.default_inputx4 = torch.randn((2, 3, 4, 4))
        cls.default_inputx8 = torch.randn((2, 3, 8, 8))

    def test_pggan_discriminator(self):
        # test with default cfg
        disc = PGGANDiscriminator(**self.default_cfg)

        score, label = disc(self.default_inputx16, transition_weight=0.1)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)
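        # curr_scale runs the branch matching the given input resolution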
        score, label = disc(
            self.default_inputx8, transition_weight=0.1, curr_scale=8)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)
        score, label = disc(
            self.default_inputx4, transition_weight=0.1, curr_scale=4)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)

        disc = PGGANDiscriminator(
            in_scale=16,
            mbstd_cfg=None,
            downsample_cfg=dict(type='nearest', scale_factor=0.5))

        # without label_size, the discriminator returns only the score
        score = disc(self.default_inputx16, transition_weight=0.1)
        assert score.shape == (2, 1)
        score = disc(self.default_inputx8, transition_weight=0.1, curr_scale=8)
        assert score.shape == (2, 1)
        score = disc(self.default_inputx4, transition_weight=0.1, curr_scale=4)
        assert score.shape == (2, 1)
        assert not disc.with_mbstd

        with pytest.raises(NotImplementedError):
            _ = PGGANDiscriminator(
                in_scale=16, mbstd_cfg=None, downsample_cfg=dict(type='xx'))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_pggan_discriminator_cuda(self):
        # test with default cfg
        disc = PGGANDiscriminator(**self.default_cfg).cuda()

        score, label = disc(
            self.default_inputx16.cuda(), transition_weight=0.1)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)
        score, label = disc(
            self.default_inputx8.cuda(), transition_weight=0.1, curr_scale=8)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)
        score, label = disc(
            self.default_inputx4.cuda(), transition_weight=0.1, curr_scale=4)
        assert score.shape == (2, 1)
        assert label.shape == (2, 2)