test_mspie_styelgan2.py 2.59 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch

from mmgen.models.gans.mspie_stylegan2 import MSPIEStyleGAN2


class TestMSStyleGAN2:
    """Smoke tests for ``MSPIEStyleGAN2.train_step`` running on CPU.

    Each check builds a fresh model from the shared configs prepared in
    ``setup_class``, runs a single training step on random data, and
    verifies that the returned dict has the expected keys and batch size.
    """

    @classmethod
    def setup_class(cls):
        """Prepare the generator/discriminator/loss/train configs shared by
        every test in this class."""
        cls.generator_cfg = {
            'type': 'MSStyleGANv2Generator',
            'out_size': 32,
            'style_channels': 16,
        }
        cls.disc_cfg = {
            'type': 'MSStyleGAN2Discriminator',
            'in_size': 32,
            'with_adaptive_pool': True,
        }
        cls.gan_loss = {'type': 'GANLoss', 'gan_type': 'vanilla'}
        cls.disc_auxiliary_loss = {
            'type': 'R1GradientPenalty',
            'loss_weight': 10. / 2.,
            'interval': 1,
            'norm_mode': 'HWC',
            'data_info': {'real_data': 'real_imgs', 'discriminator': 'disc'},
        }
        cls.train_cfg = {
            'use_ema': True,
            'num_upblocks': 3,
            'multi_input_scales': [0, 2, 4],
            'multi_scale_probability': [0.5, 0.25, 0.25],
        }

    def _train_once(self, train_cfg):
        """Build a fresh ``MSPIEStyleGAN2`` with ``train_cfg``, run one
        ``train_step`` on a random batch of two 16x16 RGB images, and check
        the structure of the returned outputs.

        Args:
            train_cfg (dict): Training config passed to the model.
        """
        model = MSPIEStyleGAN2(
            self.generator_cfg,
            self.disc_cfg,
            self.gan_loss,
            self.disc_auxiliary_loss,
            None,  # no generator auxiliary loss
            train_cfg=train_cfg,
            test_cfg=None)

        # Generator optimizer is created before the discriminator one,
        # matching the keys ``train_step`` reads from the optimizer dict.
        optim_dict = {
            'generator': torch.optim.SGD(
                model.generator.parameters(), lr=0.01),
            'discriminator': torch.optim.SGD(
                model.discriminator.parameters(), lr=0.01),
        }

        batch = {'real_img': torch.randn((2, 3, 16, 16))}
        outputs = model.train_step(batch, optim_dict)

        for required_key in ('results', 'log_vars'):
            assert required_key in outputs
        assert outputs['num_samples'] == 2

    def test_msstylegan2_cpu(self):
        # Shared train config as-is (``disc_steps`` is not set here).
        self._train_once(self.train_cfg)

        # Same config with ``disc_steps`` overridden to 2; copy it so the
        # class-level config stays unmodified.
        multi_disc_cfg = deepcopy(self.train_cfg)
        multi_disc_cfg['disc_steps'] = 2
        self._train_once(multi_disc_cfg)