test_ddpm_denoising_archs.py 6.36 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import pytest
import torch

from mmgen.models import DenoisingUnet, build_module


class TestDDPM:
    """Tests for the ``DenoisingUnet`` architecture built via ``build_module``."""

    @classmethod
    def setup_class(cls):
        """Create the shared base config and inputs used by every test.

        Each test case deep-copies ``denoising_cfg`` before modifying it, so
        the base config stays unchanged between cases.
        """
        cls.denoising_cfg = dict(
            type='DenoisingUnet',
            image_size=32,
            in_channels=3,
            base_channels=128,
            resblocks_per_downsample=3,
            attention_res=[16, 8],
            use_scale_shift_norm=True,
            dropout=0,
            num_heads=4,
            use_rescale_timesteps=True,
            output_cfg=dict(mean='eps', var='learned_range'),
            num_timesteps=2000)
        # Batch of two 3x32x32 inputs, matching in_channels / image_size.
        cls.x_t = torch.randn(2, 3, 32, 32)
        # Two timesteps within the configured range of 2000 steps.
        cls.timesteps = torch.LongTensor([999, 1999])
        # Labels in [0, 10), used by the class-conditional case.
        cls.label = torch.randint(0, 10, (2, ))

    def test_denoising_cpu(self):
        """Check construction, config validation and forward outputs of
        ``DenoisingUnet`` on CPU for a range of config variants."""
        # test default config
        denoising = build_module(self.denoising_cfg)
        assert isinstance(denoising, DenoisingUnet)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert 'eps_t_pred' in output_dict
        assert 'factor' in output_dict
        assert 'x_t' in output_dict
        assert 't_rescaled' in output_dict
        assert (output_dict['x_t'] == self.x_t).all()
        # use_rescale_timesteps=True: rescaled timesteps must stay below 1000
        assert (output_dict['t_rescaled'] < 1000).all()
        assert (output_dict['factor'] < 1).all()
        assert (output_dict['factor'] > 0).all()

        # test image size --> list input
        config = deepcopy(self.denoising_cfg)
        config['image_size'] = [32, 32]
        # build from the modified config so the list form is actually used
        denoising = build_module(config)
        assert isinstance(denoising, DenoisingUnet)
        output_dict = denoising(self.x_t, self.timesteps)
        assert 'eps_t_pred' in output_dict
        assert 'factor' in output_dict
        assert output_dict['eps_t_pred'].shape == (2, 3, 32, 32)

        # test image size --> raise type error
        config = deepcopy(self.denoising_cfg)
        config['image_size'] = '32'
        with pytest.raises(TypeError):
            build_module(config)

        # test image size --> wrong list length
        config = deepcopy(self.denoising_cfg)
        config['image_size'] = [32, 32, 32]
        with pytest.raises(AssertionError):
            build_module(config)

        # test image size --> wrong list element (non-square)
        config = deepcopy(self.denoising_cfg)
        config['image_size'] = [32, 64]
        with pytest.raises(AssertionError):
            build_module(config)

        # test channels_cfg --> list
        config = deepcopy(self.denoising_cfg)
        config['channels_cfg'] = [1, 2, 2, 2]
        denoising = build_module(config)
        assert isinstance(denoising, DenoisingUnet)
        output_dict = denoising(self.x_t, self.timesteps)
        assert 'eps_t_pred' in output_dict
        assert output_dict['eps_t_pred'].shape == (2, 3, 32, 32)

        # test channels_cfg --> dict
        config = deepcopy(self.denoising_cfg)
        config['channels_cfg'] = {32: [1, 2, 2, 2, 2]}
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps)
        assert 'eps_t_pred' in output_dict
        assert 'factor' in output_dict
        assert (output_dict['factor'] < 1).all()
        assert (output_dict['factor'] > 0).all()

        # test channels_cfg --> no image size error
        # (no channels_cfg entry exists for an image size of 194)
        config = deepcopy(self.denoising_cfg)
        config['image_size'] = 194
        with pytest.raises(KeyError):
            build_module(config)

        # test channels_cfg --> wrong type error
        config = deepcopy(self.denoising_cfg)
        config['channels_cfg'] = '1, 2, 2, 2'
        with pytest.raises(ValueError):
            build_module(config)

        # test use rescale timesteps --> disabled, timesteps pass through
        config = deepcopy(self.denoising_cfg)
        config['use_rescale_timesteps'] = False
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert (output_dict['t_rescaled'] == self.timesteps).all()

        # test var_mode --> LEARNED
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['var'] = 'LEARNED'
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert 'logvar' in output_dict

        # test var_mode --> FIXED
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['var'] = 'FIXED_SMALL'
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert 'factor' not in output_dict and 'logvar' not in output_dict

        # test var_mode --> raise error (only surfaces at forward time)
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['var'] = 'ERROR'
        denoising = build_module(config)
        with pytest.raises(AttributeError):
            denoising(self.x_t, self.timesteps, return_noise=True)

        # test mean_mode --> START_X
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['mean'] = 'START_X'
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert 'x_0_pred' in output_dict

        # test mean_mode --> PREVIOUS_X
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['mean'] = 'PREVIOUS_X'
        denoising = build_module(config)
        output_dict = denoising(self.x_t, self.timesteps, return_noise=True)
        assert 'x_tm1_pred' in output_dict

        # test mean_mode --> raise error (only surfaces at forward time)
        config = deepcopy(self.denoising_cfg)
        config['output_cfg']['mean'] = 'ERROR'
        denoising = build_module(config)
        with pytest.raises(AttributeError):
            denoising(self.x_t, self.timesteps, return_noise=True)

        # test timestep embedding --> raise error
        config = deepcopy(self.denoising_cfg)
        config['time_embedding_mode'] = 'cos'
        with pytest.raises(ValueError):
            build_module(config)

        # test timestep embedding --> new config
        config = deepcopy(self.denoising_cfg)
        config['time_embedding_cfg'] = dict(max_period=1000)
        denoising = build_module(config)
        assert isinstance(denoising, DenoisingUnet)
        output_dict = denoising(self.x_t, self.timesteps)
        assert 'eps_t_pred' in output_dict

        # test class-conditional denoising
        config = deepcopy(self.denoising_cfg)
        config['num_classes'] = 10
        denoising = build_module(config)
        output_dict = denoising(
            self.x_t, self.timesteps, self.label, return_noise=True)
        assert 'label' in output_dict
        assert (output_dict['label'] == self.label).all()