test_scheduler.py 11.4 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


Patrick von Platen's avatar
Patrick von Platen committed
17
import tempfile
Patrick von Platen's avatar
Patrick von Platen committed
18
import unittest
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
21
22
import numpy as np
import torch

Patrick von Platen's avatar
Patrick von Platen committed
23
from diffusers import DDIMScheduler, DDPMScheduler
Patrick von Platen's avatar
Patrick von Platen committed
24
25
26
27
28
29


# Turn off TF32 for CUDA matrix multiplies so any GPU matmul in these tests runs at
# full float32 precision; the output comparisons below use tight tolerances (e.g. 1e-5).
torch.backends.cuda.matmul.allow_tf32 = False


class SchedulerCommonTest(unittest.TestCase):
    """Scheduler-agnostic checks shared by every concrete scheduler test case.

    Subclasses configure three things:

    - ``scheduler_classes``: the scheduler types under test.
    - ``forward_default_kwargs``: extra keyword arguments for every ``step`` call,
      given as ``(name, value)`` pairs.
    - ``get_scheduler_config``: must be overridden.

    With the empty defaults below, every loop in this base class runs zero times.
    """

    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_image(self):
        """Random float array of shape (4, 3, 8, 8); new values on every access."""
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        image = np.random.rand(batch_size, num_channels, height, width)

        return image

    @property
    def dummy_image_deter(self):
        """Deterministic array of shape (4, 3, 8, 8) with values in [0, 1).

        Built from ``np.arange``, so the full-loop reference sums are reproducible.
        """
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        image = np.arange(num_elems)
        image = image.reshape(num_channels, height, width, batch_size)
        image = image / num_elems
        image = image.transpose(3, 0, 1, 2)

        return image

    def get_scheduler_config(self, **kwargs):
        """Return scheduler constructor kwargs, with ``kwargs`` as overrides.

        Subclasses must override this; the base implementation always raises.
        """
        raise NotImplementedError

    def dummy_model(self):
        """Return a toy callable ``model(image, t)`` computing ``image * t / (t + 1)``."""

        def model(image, t, *args):
            return image * t / (t + 1)

        return model

    def _check_step_after_config_roundtrip(self, time_step, config_kwargs, step_kwargs):
        """Assert ``step`` output survives a ``save_config`` / ``from_config`` round trip.

        Runs for EVERY class in ``scheduler_classes``. Each scheduler is built from
        ``get_scheduler_config(**config_kwargs)`` and then reloaded from a temporary
        directory. Both the original and the reloaded scheduler are stepped with
        ``step_kwargs`` at ``time_step``, and their outputs are compared.
        """
        for scheduler_class in self.scheduler_classes:
            image = self.dummy_image
            residual = 0.1 * image

            scheduler_config = self.get_scheduler_config(**config_kwargs)
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            output = scheduler.step(residual, image, time_step, **step_kwargs)
            new_output = new_scheduler.step(residual, image, time_step, **step_kwargs)

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_configs(self, time_step=0, **config):
        """Round-trip check with ``config`` overriding the default scheduler config."""
        self._check_step_after_config_roundtrip(time_step, config, dict(self.forward_default_kwargs))

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Round-trip check with ``forward_kwargs`` overriding the default ``step`` kwargs."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        self._check_step_after_config_roundtrip(time_step, {}, kwargs)

    def test_from_pretrained_save_pretrained(self):
        """A reloaded scheduler must produce the same step output at timestep 1."""
        self.check_over_forward(time_step=1)

    def test_step_shape(self):
        """``step`` keeps the input shape at different timesteps."""
        kwargs = dict(self.forward_default_kwargs)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            image = self.dummy_image
            residual = 0.1 * image

            output_0 = scheduler.step(residual, image, 0, **kwargs)
            output_1 = scheduler.step(residual, image, 1, **kwargs)

            self.assertEqual(output_0.shape, image.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_pytorch_equal_numpy(self):
        """The numpy and ``tensor_format="pt"`` schedulers give matching step outputs."""
        kwargs = dict(self.forward_default_kwargs)

        for scheduler_class in self.scheduler_classes:
            image = self.dummy_image
            residual = 0.1 * image

            image_pt = torch.tensor(image)
            residual_pt = 0.1 * image_pt

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)

            output = scheduler.step(residual, image, 1, **kwargs)
            output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)

            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"

Patrick von Platen's avatar
Patrick von Platen committed
164
165

class DDPMSchedulerTest(SchedulerCommonTest):
    """Tests for ``DDPMScheduler`` on top of the shared scheduler checks."""

    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default DDPM config, updated with ``kwargs``."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_predicted_sample": True,
        }

        config.update(kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_image(self):
        for clip_predicted_sample in [True, False]:
            self.check_over_configs(clip_predicted_sample=clip_predicted_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        """Pin ``get_variance`` at the first, a middle, and the last timestep."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert np.sum(np.abs(scheduler.get_variance(0) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(487) - 0.00979)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        """Run the full reverse loop on deterministic inputs and compare to reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        image = self.dummy_image_deter

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(image, t)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = scheduler.step(residual, image, t)

            if t > 0:
                noise = self.dummy_image_deter
                variance = scheduler.get_variance(t) ** (0.5) * noise

            # NOTE(review): at t == 0 the branch above is skipped, so `variance` still
            # holds the t == 1 value and is added once more. The reference constants
            # below were presumably recorded with this behavior; confirm before
            # resetting `variance` to 0 for the final step.
            image = pred_prev_image + variance

        result_sum = np.sum(np.abs(image))
        result_mean = np.mean(np.abs(image))

        # abs() makes these two-sided checks; without it, any result smaller than the
        # reference (e.g. an all-zero image) would pass.
        assert abs(result_sum.item() - 732.9947) < 1e-3
        assert abs(result_mean.item() - 0.9544) < 1e-3
Patrick von Platen's avatar
Patrick von Platen committed
242

Patrick von Platen's avatar
update  
Patrick von Platen committed
243

Patrick von Platen's avatar
Patrick von Platen committed
244
245
246
class DDIMSchedulerTest(SchedulerCommonTest):
    """Tests for ``DDIMScheduler`` on top of the shared scheduler checks.

    Every ``step`` call made by the shared checks receives
    ``num_inference_steps=50`` and ``eta=0.0`` unless a test overrides them.
    """

    scheduler_classes = (DDIMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0))

    def get_scheduler_config(self, **kwargs):
        """Return the default DDIM config, updated with ``kwargs``."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "clip_predicted_sample": True,
        }

        config.update(kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_clip_image(self):
        for clip_predicted_sample in [True, False]:
            self.check_over_configs(clip_predicted_sample=clip_predicted_sample)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_eta(self):
        for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
            self.check_over_forward(time_step=t, eta=eta)

    def test_variance(self):
        """Pin ``get_variance(t, num_inference_steps)`` for 50- and 1000-step schedules."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        """Run a 10-step DDIM loop on deterministic inputs and compare to reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.1
        num_trained_timesteps = len(scheduler)

        # Trained-timestep value fed to the model for each inference step index.
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        model = self.dummy_model()
        image = self.dummy_image_deter

        for t in reversed(range(num_inference_steps)):
            residual = model(image, inference_step_times[t])

            # `t` here is the inference step index, not the trained timestep.
            pred_prev_image = scheduler.step(residual, image, t, num_inference_steps, eta)

            variance = 0
            if eta > 0:
                noise = self.dummy_image_deter
                variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise

            image = pred_prev_image + variance

        result_sum = np.sum(np.abs(image))
        result_mean = np.mean(np.abs(image))

        # abs() makes these two-sided checks; without it, any result smaller than the
        # reference (e.g. an all-zero image) would pass.
        assert abs(result_sum.item() - 270.6214) < 1e-3
        assert abs(result_mean.item() - 0.3524) < 1e-3