"tools/distpartitioning/utils.py" did not exist on "7c598aac6c25fbee53e52f6bd54c2fd04bad2151"
test_scheduler.py 28.9 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import pdb
Patrick von Platen's avatar
Patrick von Platen committed
16
import tempfile
Patrick von Platen's avatar
Patrick von Platen committed
17
import unittest
Patrick von Platen's avatar
Patrick von Platen committed
18

Patrick von Platen's avatar
Patrick von Platen committed
19
20
21
import numpy as np
import torch

Nathan Lambert's avatar
Nathan Lambert committed
22
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
Patrick von Platen's avatar
Patrick von Platen committed
23
24
25
26
27
28


torch.backends.cuda.matmul.allow_tf32 = False


class SchedulerCommonTest(unittest.TestCase):
Patrick von Platen's avatar
Patrick von Platen committed
29
30
    scheduler_classes = ()
    forward_default_kwargs = ()
Patrick von Platen's avatar
Patrick von Platen committed
31
32

    @property
    def dummy_sample(self):
        """Return a random numpy sample of shape (batch, channels, height, width).

        Note: intentionally unseeded, so it differs between accesses.
        """
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        sample = np.random.rand(batch_size, num_channels, height, width)

        return sample
Patrick von Platen's avatar
Patrick von Platen committed
42
43

    @property
    def dummy_sample_deter(self):
        """Return a deterministic numpy sample in [0, 1) of shape (batch, channels, height, width).

        Built from ``arange`` so full-loop tests can assert exact sums/means.
        """
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = np.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        # move the batch axis to the front
        sample = sample.transpose(3, 0, 1, 2)

        return sample
Patrick von Platen's avatar
Patrick von Platen committed
57
58
59
60
61

    def get_scheduler_config(self):
        """Return the kwargs dict used to build the scheduler under test; subclasses must override."""
        raise NotImplementedError

    def dummy_model(self):
        """Return a cheap stand-in for a denoising model.

        The returned callable scales the sample by ``t / (t + 1)`` and ignores
        any extra positional arguments.
        """

        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

Patrick von Platen's avatar
Patrick von Platen committed
67
68
69
    def check_over_configs(self, time_step=0, **config):
        """Save/reload each scheduler with ``config`` overrides and assert step() output is unchanged."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            # schedulers without set_timesteps take num_inference_steps via step() kwargs
            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
Patrick von Platen's avatar
Patrick von Platen committed
93
94
95
96
97

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Save/reload each scheduler and assert step() output is unchanged for the given forward kwargs.

        Unlike check_over_configs, the torch RNG is reseeded before each step so
        schedulers that draw noise stay comparable.
        """
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
            torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
Patrick von Platen's avatar
Patrick von Platen committed
123

Patrick von Platen's avatar
Patrick von Platen committed
124
    def test_from_pretrained_save_pretrained(self):
        """Saving then reloading a scheduler config must not change step() output."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
Patrick von Platen's avatar
Patrick von Platen committed
150
151
152
153

    def test_step_shape(self):
        """step() must preserve the sample shape, regardless of the timestep."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
            output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

Patrick von Platen's avatar
Patrick von Platen committed
174
175
176
    def test_pytorch_equal_numpy(self):
        """The torch ("pt") tensor-format scheduler must match the numpy one numerically."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            sample_pt = torch.tensor(sample)
            residual_pt = 0.1 * sample_pt

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                scheduler_pt.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]

            # looser tolerance than the save/load checks: float32 torch vs float64 numpy
            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
Patrick von Platen's avatar
Patrick von Platen committed
201

Patrick von Platen's avatar
Patrick von Platen committed
202
203

class DDPMSchedulerTest(SchedulerCommonTest):
Patrick von Platen's avatar
Patrick von Platen committed
204
    scheduler_classes = (DDPMScheduler,)
Patrick von Platen's avatar
Patrick von Platen committed
205
206
207

    def get_scheduler_config(self, **kwargs):
        """Return the default DDPM scheduler config, with keyword overrides merged in."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config
Patrick von Platen's avatar
update  
Patrick von Platen committed
218

Patrick von Platen's avatar
Patrick von Platen committed
219
220
    def test_timesteps(self):
        """Config round-trip holds for several training-timestep counts."""
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)
Patrick von Platen's avatar
Patrick von Platen committed
222
223
224
225
226
227
228
229
230
231
232
233
234

    def test_betas(self):
        """Config round-trip holds across paired (beta_start, beta_end) ranges."""
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Config round-trip holds for each supported beta schedule."""
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        """Config round-trip holds for each variance type (including the fallback "other")."""
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

235
    def test_clip_sample(self):
        """Config round-trip holds with sample clipping enabled and disabled."""
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)
Patrick von Platen's avatar
Patrick von Platen committed
238
239
240
241
242
243
244
245
246
247

    def test_time_indices(self):
        """Forward round-trip holds at the first, middle and last timestep."""
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        """Spot-check DDPM variance at the start, middle and end of the schedule."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        # reference values computed for the default linear beta schedule
        assert np.sum(np.abs(scheduler.get_variance(0) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(487) - 0.00979)) < 1e-5
        assert np.sum(np.abs(scheduler.get_variance(999) - 0.02)) < 1e-5
Patrick von Platen's avatar
Patrick von Platen committed
251
252
253

    def test_full_loop_no_noise(self):
        """Run the full reverse DDPM loop with deterministic inputs and pin the result."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]

            if t > 0:
                noise = self.dummy_sample_deter
                variance = scheduler.get_variance(t) ** (0.5) * noise

            # NOTE(review): at t == 0 this reuses the variance from the previous
            # iteration rather than adding zero noise — the pinned sums below
            # depend on this behavior, so it is kept as-is. TODO confirm intent.
            sample = pred_prev_sample + variance

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 732.9947) < 1e-2
        assert abs(result_mean.item() - 0.9544) < 1e-3
Patrick von Platen's avatar
Patrick von Platen committed
280

Patrick von Platen's avatar
update  
Patrick von Platen committed
281

Patrick von Platen's avatar
Patrick von Platen committed
282
283
class DDIMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDIMScheduler,)
284
    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
Patrick von Platen's avatar
update  
Patrick von Platen committed
285

Patrick von Platen's avatar
Patrick von Platen committed
286
287
    def get_scheduler_config(self, **kwargs):
        """Return the default DDIM scheduler config, with keyword overrides merged in."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Config round-trip holds for several training-timestep counts."""
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)
Patrick von Platen's avatar
Patrick von Platen committed
301
302
303
304
305
306
307
308
309

    def test_betas(self):
        """Config round-trip holds across paired (beta_start, beta_end) ranges."""
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Config round-trip holds for each supported beta schedule."""
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

310
    def test_clip_sample(self):
        """Config round-trip holds with sample clipping enabled and disabled."""
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)
Patrick von Platen's avatar
Patrick von Platen committed
313
314
315
316
317
318
319

    def test_time_indices(self):
        """Forward round-trip holds at several inference timesteps."""
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        """Forward round-trip holds for several inference-step counts.

        The original zipped an unused timestep list alongside the step counts;
        the unused loop variable has been removed.
        """
        for num_inference_steps in [10, 50, 500]:
            self.check_over_forward(num_inference_steps=num_inference_steps)
Patrick von Platen's avatar
Patrick von Platen committed
321
322
323
324
325
326
327

    def test_eta(self):
        """Forward round-trip holds for several (timestep, eta) pairs."""
        for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
            self.check_over_forward(time_step=t, eta=eta)

    def test_variance(self):
        """Spot-check DDIM variance for (timestep, prev_timestep) pairs along the schedule."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        # widely-spaced prev timesteps (as used with few inference steps)
        assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
        # adjacent prev timesteps; (0, 0) is re-checked here as in the original
        assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
Patrick von Platen's avatar
Patrick von Platen committed
337
338
339
340
341
342

    def test_full_loop_no_noise(self):
        """Run the full DDIM sampling loop with eta=0 (deterministic) and pin the result."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            residual = model(sample, t)

            sample = scheduler.step(residual, t, sample, eta)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 172.0067) < 1e-2
        assert abs(result_mean.item() - 0.223967) < 1e-3
Patrick von Platen's avatar
Patrick von Platen committed
359
360
361
362
363
364
365
366


class PNDMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Return the default PNDM scheduler config, with keyword overrides merged in."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

376
    def check_over_configs(self, time_step=0, **config):
        """PNDM-specific save/reload check: seeds past residuals (ets) and exercises both step_prk and step_plms."""
        kwargs = dict(self.forward_default_kwargs)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(kwargs["num_inference_steps"])
            # copy over dummy past residuals
            scheduler.ets = dummy_past_residuals[:]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
                # copy over dummy past residuals
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        # Intentionally disabled for PNDM: the common-test version does not seed
        # the past residuals (ets) that step() requires; the class-specific
        # check_over_configs covers save/reload instead.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """PNDM-specific forward check: seeds past residuals (ets) and compares step_prk/step_plms after save/reload."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(kwargs["num_inference_steps"])

            # copy over dummy past residuals
            scheduler.ets = dummy_past_residuals[:]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.ets = dummy_past_residuals[:]
                new_scheduler.set_timesteps(kwargs["num_inference_steps"])

            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
    def test_pytorch_equal_numpy(self):
        """The torch ("pt") PNDM scheduler must numerically match the numpy one for both step variants."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample
            # PNDM keeps a window of past residuals (ets); seed both schedulers identically
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

            sample_pt = torch.tensor(sample)
            residual_pt = 0.1 * sample_pt
            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            # copy over dummy past residuals
            scheduler.ets = dummy_past_residuals[:]

            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
            # copy over dummy past residuals
            scheduler_pt.ets = dummy_past_residuals_pt[:]

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                scheduler_pt.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"

            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"

    def test_step_shape(self):
        """Both PNDM step variants must preserve the sample shape."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler = scheduler_class(**self.get_scheduler_config())

            sample = self.dummy_sample
            residual = 0.1 * sample
            # copy over dummy past residuals
            scheduler.ets = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

            has_set_timesteps = hasattr(scheduler, "set_timesteps")
            if num_inference_steps is not None:
                if has_set_timesteps:
                    scheduler.set_timesteps(num_inference_steps)
                else:
                    kwargs["num_inference_steps"] = num_inference_steps

            for step_fn in (scheduler.step_prk, scheduler.step_plms):
                output_0 = step_fn(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
                output_1 = step_fn(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]

                self.assertEqual(output_0.shape, sample.shape)
                self.assertEqual(output_0.shape, output_1.shape)

Patrick von Platen's avatar
Patrick von Platen committed
511
512
    def test_timesteps(self):
        """Config round-trip holds for several training-timestep counts."""
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)
Patrick von Platen's avatar
Patrick von Platen committed
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530

    def test_betas(self):
        """Config round-trip holds across paired (beta_start, beta_end) ranges."""
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Config round-trip holds for each supported beta schedule."""
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        """Forward round-trip holds at several timesteps."""
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        """Forward round-trip holds for (timestep, inference-step count) pairs."""
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

531
    def test_inference_plms_no_past_residuals(self):
        """step_plms must raise ValueError when called without seeded past residuals (ets)."""
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
Patrick von Platen's avatar
Patrick von Platen committed
538
539
540
541
542
543
544
545

    def test_full_loop_no_noise(self):
        """Run the full PNDM loop (PRK warm-up then PLMS) on deterministic input and pin the result."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        # Runge-Kutta warm-up steps fill the past-residual window
        for i, t in enumerate(scheduler.prk_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]

        # linear multistep for the remainder
        for i, t in enumerate(scheduler.plms_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 199.1169) < 1e-2
        assert abs(result_mean.item() - 0.2593) < 1e-3
Nathan Lambert's avatar
Nathan Lambert committed
562
563


564
565
class ScoreSdeVeSchedulerTest(unittest.TestCase):
    # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
    scheduler_classes = (ScoreSdeVeScheduler,)
    forward_default_kwargs = (("seed", 0),)

    @property
    def dummy_sample(self):
        """A random sample tensor of shape (batch, channels, height, width) = (4, 3, 8, 8)."""
        return torch.rand((4, 3, 8, 8))

    @property
    def dummy_sample_deter(self):
        """A deterministic (4, 3, 8, 8) sample tensor with evenly spaced values in [0, 1)."""
        batch_size, num_channels, height, width = 4, 3, 8, 8
        num_elems = batch_size * num_channels * height * width

        values = torch.arange(num_elems).reshape(num_channels, height, width, batch_size)
        return (values / num_elems).permute(3, 0, 1, 2)

    def dummy_model(self):
        """Return a trivial model stand-in: f(sample, t) = sample * t / (t + 1)."""

        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def get_scheduler_config(self, **kwargs):
        """Default ScoreSdeVeScheduler config; any entry can be overridden via kwargs."""
        config = {
            "num_train_timesteps": 2000,
            "snr": 0.15,
            "sigma_min": 0.01,
            "sigma_max": 1348,
            "sampling_eps": 1e-5,
            "tensor_format": "pt",  # TODO add test for tensor formats
        }
        config.update(kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Saving and reloading a scheduler must not change step_pred/step_correct outputs."""
        kwargs = dict(self.forward_default_kwargs)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler = scheduler_class(**self.get_scheduler_config(**config))

            # Round-trip the config through disk.
            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Forward kwargs must produce identical results for a saved-and-reloaded scheduler."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler = scheduler_class(**self.get_scheduler_config())

            # Round-trip the config through disk.
            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

    def test_timesteps(self):
        for num_train_timesteps in (10, 100, 1000):
            self.check_over_configs(num_train_timesteps=num_train_timesteps)

    def test_sigmas(self):
        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)

    def test_time_indices(self):
        for time_step in (0.1, 0.5, 0.75):
            self.check_over_forward(time_step=time_step)

    def test_full_loop_no_noise(self):
        """Run the full predictor-corrector loop and compare against reference statistics."""
        kwargs = dict(self.forward_default_kwargs)

        scheduler = self.scheduler_classes[0](**self.get_scheduler_config())

        num_inference_steps = 3
        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_sigmas(num_inference_steps)
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            sigma_t = scheduler.sigmas[i]

            # Corrector steps at this noise level.
            for _ in range(scheduler.correct_steps):
                with torch.no_grad():
                    model_output = model(sample, sigma_t)
                sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]

            # Predictor step.
            with torch.no_grad():
                model_output = model(sample, sigma_t)
            output = scheduler.step_pred(model_output, t, sample, **kwargs)
            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # Reference statistics recorded from a known-good run.
        assert abs(result_sum.item() - 14224664576.0) < 1e-2
        assert abs(result_mean.item() - 18521698.0) < 1e-3

    def test_step_shape(self):
        """step_pred output must keep the input sample's shape."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler = scheduler_class(**self.get_scheduler_config())

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
            output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)