# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from typing import Dict, List, Tuple

import numpy as np
import torch

from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler


# Disable TF32 matmuls: on Ampere GPUs TF32 would perturb float32 numerics and
# break the exact-value assertions below.
torch.backends.cuda.matmul.allow_tf32 = False


class SchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        sample = torch.rand((batch_size, num_channels, height, width))

        return sample

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = torch.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        sample = sample.permute(3, 0, 1, 2)

        return sample

    def get_scheduler_config(self):
        raise NotImplementedError

    def dummy_model(self):
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def check_over_configs(self, time_step=0, **config):
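        # A shared check: build the scheduler from a modified config, round-trip it
        # through save_config/from_config, and assert that the original and reloaded
        # schedulers produce numerically identical step() outputs.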
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
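        # Same save/load round-trip as check_over_configs, but varying the forward
        # (step) kwargs instead of the scheduler config.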
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
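        # Saving and reloading a scheduler config must not change step() results.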
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            torch.manual_seed(0)
            output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
            torch.manual_seed(0)
            new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
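        # step() must preserve the sample shape, regardless of the timestep.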
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
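        # step(..., return_dict=False) must return a tuple whose entries match the
        # fields of the dataclass output one-to-one (NaNs are zeroed before comparing).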
        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", 50)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple, outputs_dict)

    def test_scheduler_public_api(self):
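        # Every scheduler must expose init_noise_sigma, scale_model_input, and step,
        # and scale_model_input must preserve the sample shape.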
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            self.assertTrue(
                hasattr(scheduler, "init_noise_sigma"),
                f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
            )
            self.assertTrue(
                hasattr(scheduler, "scale_model_input"),
                f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
            )
            self.assertTrue(
                hasattr(scheduler, "step"),
                f"{scheduler_class} does not implement a required class method `step(...)`",
            )

            sample = self.dummy_sample
            scaled_sample = scheduler.scale_model_input(sample, 0.0)
            self.assertEqual(sample.shape, scaled_sample.shape)


class DDPMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

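        # Expected values follow the "fixed_small" posterior variance
        # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.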
        assert torch.sum(torch.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
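        # Run the full reverse process over all trained timesteps (t = T-1 .. 0) with
        # a deterministic dummy model and compare against known-good statistics.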
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.9070) < 1e-2
        assert abs(result_mean.item() - 0.3374) < 1e-3


class DDIMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDIMScheduler,)
    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def full_loop(self, **config):
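        # Helper: run a complete 10-step sampling loop with eta=0 (deterministic DDIM)
        # using the dummy model, and return the final sample.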
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, eta).prev_sample

        return sample

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(5)
        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_eta(self):
        for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
            self.check_over_forward(time_step=t, eta=eta)

    def test_variance(self):
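        # DDIM variance depends on both the current and the previous timestep:
        # sigma_t^2 = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev).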
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 172.0067) < 1e-2
        assert abs(result_mean.item() - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We use a different beta_start so that the first alpha is 0.99.
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 149.8295) < 1e-2
        assert abs(result_mean.item() - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We use a different beta_start so that the first alpha is 0.99.
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 149.0784) < 1e-2
        assert abs(result_mean.item() - 0.1941) < 1e-3


class PNDMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
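        # PNDM overrides the common check: stepping is split into step_prk/step_plms,
        # and both rely on a history of past residuals (ets) that must be seeded
        # after set_timesteps() before any step can run.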
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.ets = dummy_past_residuals[:]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        # Covered by the custom check_over_configs/check_over_forward above; the
        # common version calls step() without set_timesteps() and seeded residuals,
        # which PNDM does not support.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.ets = dummy_past_residuals[:]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
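        # Helper: run the Runge-Kutta warm-up phase (prk_timesteps) followed by the
        # linear multistep phase (plms_timesteps) and return the final sample.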
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.prk_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_prk(residual, t, sample).prev_sample

        for i, t in enumerate(scheduler.plms_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_plms(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
            scheduler.ets = dummy_past_residuals[:]

            output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(10)
        assert torch.equal(
            scheduler.timesteps,
            torch.LongTensor(
                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
            ),
        )

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for num_inference_steps in [10, 50, 100]:
            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # an earlier version of set_timesteps() raised an indexing error into the
        # alphas when num_inference_steps was a power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler.set_timesteps(num_inference_steps)

            # before the power-of-3 fix this errored on the first step, so two steps suffice
            for i, t in enumerate(scheduler.prk_timesteps[:2]):
                sample = scheduler.step_prk(residual, t, sample).prev_sample

    def test_inference_plms_no_past_residuals(self):
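        # step_plms() must raise if it is called before set_timesteps() has populated
        # the timestep schedule and past-residual history.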
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 198.1318) < 1e-2
        assert abs(result_mean.item() - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We use a different beta_start so that the first alpha is 0.99.
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 230.0399) < 1e-2
        assert abs(result_mean.item() - 0.2995) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We use a different beta_start so that the first alpha is 0.99.
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 186.9482) < 1e-2
        assert abs(result_mean.item() - 0.2434) < 1e-3


class ScoreSdeVeSchedulerTest(unittest.TestCase):
    # TODO: adapt to SchedulerCommonTest (the scheduler needs NumPy integration first)
    scheduler_classes = (ScoreSdeVeScheduler,)
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        sample = torch.rand((batch_size, num_channels, height, width))

        return sample

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = torch.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        sample = sample.permute(3, 0, 1, 2)

        return sample

    def dummy_model(self):
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 2000,
            "snr": 0.15,
            "sigma_min": 0.01,
            "sigma_max": 1348,
            "sampling_eps": 1e-5,
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
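        # Round-trip the config through save/load and check that both the predictor
        # (step_pred) and the corrector (step_correct) give identical outputs under a
        # fixed seed.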
        kwargs = dict(self.forward_default_kwargs)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            output = scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample
            new_output = new_scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
            new_output = new_scheduler.step_correct(
                residual, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            output = scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample
            new_output = new_scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
            new_output = new_scheduler.step_correct(
                residual, sample, generator=torch.manual_seed(0), **kwargs
            ).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"

    def test_timesteps(self):
        for timesteps in [10, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_sigmas(self):
        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)

    def test_time_indices(self):
        for t in [0.1, 0.5, 0.75]:
            self.check_over_forward(time_step=t)

    def test_full_loop_no_noise(self):
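        # Predictor-corrector sampling: at each timestep, run `correct_steps` Langevin
        # corrector updates (step_correct), then one predictor update (step_pred).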
        kwargs = dict(self.forward_default_kwargs)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 3

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_sigmas(num_inference_steps)
        scheduler.set_timesteps(num_inference_steps)
        generator = torch.manual_seed(0)

        for i, t in enumerate(scheduler.timesteps):
            sigma_t = scheduler.sigmas[i]

            for _ in range(scheduler.config.correct_steps):
                with torch.no_grad():
                    model_output = model(sample, sigma_t)
                sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample

            with torch.no_grad():
                model_output = model(sample, sigma_t)

            output = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs)
            sample, _ = output.prev_sample, output.prev_sample_mean

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert np.isclose(result_sum.item(), 14372758528.0)
        assert np.isclose(result_mean.item(), 18714530.0)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
            output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)


class LMSDiscreteSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (LMSDiscreteScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "trained_betas": None,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "scaled_linear"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [0, 500, 800]:
            self.check_over_forward(time_step=t)

    def test_full_loop_no_noise(self):
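        # LMS works in sigma space: the initial sample is scaled by init_noise_sigma,
        # and the model input must be rescaled per timestep via scale_model_input.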
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma

        for i, t in enumerate(scheduler.timesteps):
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 1006.388) < 1e-2
        assert abs(result_mean.item() - 1.31) < 1e-3
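

# Allows running this test file directly; the suite is normally driven by pytest.
if __name__ == "__main__":
    unittest.main()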