test_lsgan_archs.py 4.24 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmgen.models import LSGANDiscriminator, LSGANGenerator, build_module


class TestLSGANGenerator(object):
    """Output-shape checks for ``LSGANGenerator`` built via ``build_module``."""

    @classmethod
    def setup_class(cls):
        # Explicit noise tensor of shape (3, 128); the second dim matches
        # noise_size=128 in the default config below.
        cls.noise = torch.randn((3, 128))
        cls.default_config = dict(
            type='LSGANGenerator', noise_size=128, output_scale=128)

    def _check_generator(self, device):
        """Run every generator check with module and inputs on ``device``.

        Shared by the CPU and CUDA tests so both exercise exactly the same
        assertions and cannot drift apart.

        Args:
            device (str): Target device for the generator and the explicit
                noise tensor, e.g. ``'cpu'`` or ``'cuda'``.
        """
        # test default setting with builder
        g = build_module(self.default_config).to(device)
        assert isinstance(g, LSGANGenerator)
        x = g(None, num_batches=3)
        assert x.shape == (3, 3, 128, 128)
        # with return_noise=True the output is indexed by 'noise_batch'
        x = g(None, num_batches=3, return_noise=True)
        assert x['noise_batch'].shape == (3, 128)
        x = g(self.noise.to(device), return_noise=True)
        assert x['noise_batch'].shape == (3, 128)
        # a callable (here torch.randn) is accepted as the noise source
        x = g(torch.randn, num_batches=3, return_noise=True)
        assert x['noise_batch'].shape == (3, 128)

        # test different output_scale
        config = dict(type='LSGANGenerator', noise_size=128, output_scale=64)
        g = build_module(config).to(device)
        assert isinstance(g, LSGANGenerator)
        x = g(None, num_batches=3)
        assert x.shape == (3, 3, 64, 64)

    def test_lsgan_generator(self):
        self._check_generator('cpu')

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lsgan_generator_cuda(self):
        self._check_generator('cuda')


class TestLSGANDiscriminator(object):
    """Output-shape checks for ``LSGANDiscriminator`` built via ``build_module``."""

    @classmethod
    def setup_class(cls):
        # Input of shape (2, 3, 128, 128): dim 1 matches in_channels=3 and
        # the last two dims match input_scale=128 (assumed NCHW layout).
        cls.x = torch.randn((2, 3, 128, 128))
        cls.default_config = dict(
            type='LSGANDiscriminator', in_channels=3, input_scale=128)

    def _check_discriminator(self, device):
        """Run every discriminator check with module and inputs on ``device``.

        Shared by the CPU and CUDA tests so both exercise exactly the same
        assertions and cannot drift apart. Each case expects a score tensor
        of shape ``(2, 1)`` for a batch of 2 inputs.

        Args:
            device (str): Target device for the discriminator and its
                inputs, e.g. ``'cpu'`` or ``'cuda'``.
        """
        # test default setting with builder
        d = build_module(self.default_config).to(device)
        assert isinstance(d, LSGANDiscriminator)
        score = d(self.x.to(device))
        assert score.shape == (2, 1)

        # test different input_scale
        config = dict(type='LSGANDiscriminator', in_channels=3, input_scale=64)
        d = build_module(config).to(device)
        assert isinstance(d, LSGANDiscriminator)
        x = torch.randn((2, 3, 64, 64))
        score = d(x.to(device))
        assert score.shape == (2, 1)

        # test different config: an explicit output activation
        config = dict(
            type='LSGANDiscriminator',
            in_channels=3,
            input_scale=64,
            out_act_cfg=dict(type='Sigmoid'))
        d = build_module(config).to(device)
        assert isinstance(d, LSGANDiscriminator)
        x = torch.randn((2, 3, 64, 64))
        score = d(x.to(device))
        assert score.shape == (2, 1)

    def test_lsgan_discriminator(self):
        self._check_discriminator('cpu')

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lsgan_discriminator_cuda(self):
        self._check_discriminator('cuda')