test_scheduler.py 19.3 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


Patrick von Platen's avatar
Patrick von Platen committed
17
import tempfile
Patrick von Platen's avatar
Patrick von Platen committed
18
import unittest
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
21
22
import numpy as np
import torch

Patrick von Platen's avatar
Patrick von Platen committed
23
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
Patrick von Platen's avatar
Patrick von Platen committed
24
25
26
27
28
29


# Disable TF32 for CUDA float32 matmuls, so any matmul that runs on an Ampere-or-newer GPU
# uses full float32 precision.
# NOTE(review): presumably set so the numeric reference values and tolerances in these tests
# are reproducible across GPU generations — confirm intent.
torch.backends.cuda.matmul.allow_tf32 = False


class SchedulerCommonTest(unittest.TestCase):
    """Checks shared by every scheduler test case.

    Subclasses set:
      * ``scheduler_classes`` -- tuple of scheduler types under test; every
        check below runs once per class.
      * ``forward_default_kwargs`` -- tuple of ``(name, value)`` pairs forwarded
        to ``step``. A ``num_inference_steps`` entry is special: it is passed to
        ``set_timesteps`` when the scheduler has that method, otherwise it is
        forwarded to ``step`` like the other kwargs.
    and implement ``get_scheduler_config``.
    """

    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        """Fresh random sample of shape ``(4, 3, 8, 8)`` with values in ``[0, 1)``."""
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        sample = np.random.rand(batch_size, num_channels, height, width)

        return sample

    @property
    def dummy_sample_deter(self):
        """Deterministic sample of shape ``(4, 3, 8, 8)`` with values in ``[0, 1)``."""
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = np.arange(num_elems)
        # Fill in (C, H, W, B) order, normalise, then move the batch axis first.
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        sample = sample.transpose(3, 0, 1, 2)

        return sample

    def get_scheduler_config(self):
        """Return the scheduler constructor kwargs; subclasses must override."""
        raise NotImplementedError

    def dummy_model(self):
        """Return a toy model ``f(sample, t) = sample * t / (t + 1)``."""

        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    @staticmethod
    def _save_and_reload(scheduler, scheduler_class):
        """Round-trip ``scheduler``'s config through a temporary directory.

        Returns a new ``scheduler_class`` instance built with ``from_config``.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            return scheduler_class.from_config(tmpdirname)

    @staticmethod
    def _prepare_timesteps(schedulers, num_inference_steps, step_kwargs):
        """Route ``num_inference_steps`` to ``set_timesteps`` or into ``step_kwargs``.

        Args:
            schedulers: sequence of schedulers of the same class. The first one
                decides whether ``set_timesteps`` is available.
            num_inference_steps: value to route. ``None`` means do nothing.
            step_kwargs: kwargs dict for ``step``. It is mutated in place when the
                schedulers have no ``set_timesteps``.
        """
        if num_inference_steps is None:
            return
        if hasattr(schedulers[0], "set_timesteps"):
            for scheduler in schedulers:
                scheduler.set_timesteps(num_inference_steps)
        else:
            step_kwargs["num_inference_steps"] = num_inference_steps

    def check_over_configs(self, time_step=0, **config):
        """Assert that a saved-and-reloaded scheduler steps identically to the original.

        Args:
            time_step: timestep passed to ``step``.
            **config: overrides forwarded to ``get_scheduler_config``.
        """
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            # Per-class copy so routing for one class cannot leak into the next.
            step_kwargs = dict(kwargs)
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            new_scheduler = self._save_and_reload(scheduler, scheduler_class)

            self._prepare_timesteps((scheduler, new_scheduler), num_inference_steps, step_kwargs)

            output = scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]
            new_output = new_scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Like ``check_over_configs``, but ``forward_kwargs`` override the ``step`` kwargs.

        The torch RNG is reseeded before each ``step``, so any torch-based
        randomness inside ``step`` is identical for both schedulers.
        """
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            step_kwargs = dict(kwargs)
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            new_scheduler = self._save_and_reload(scheduler, scheduler_class)

            self._prepare_timesteps((scheduler, new_scheduler), num_inference_steps, step_kwargs)

            torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]
            torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        """A scheduler rebuilt from its saved config produces the same step output at t=1."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            step_kwargs = dict(kwargs)
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            new_scheduler = self._save_and_reload(scheduler, scheduler_class)

            self._prepare_timesteps((scheduler, new_scheduler), num_inference_steps, step_kwargs)

            output = scheduler.step(residual, 1, sample, **step_kwargs)["prev_sample"]
            new_output = new_scheduler.step(residual, 1, sample, **step_kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
        """``step`` keeps the sample's shape at t=0 and t=1."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            step_kwargs = dict(kwargs)
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            self._prepare_timesteps((scheduler,), num_inference_steps, step_kwargs)

            output_0 = scheduler.step(residual, 0, sample, **step_kwargs)["prev_sample"]
            output_1 = scheduler.step(residual, 1, sample, **step_kwargs)["prev_sample"]

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_pytorch_equal_numpy(self):
        """NumPy and PyTorch (``tensor_format="pt"``) schedulers give matching outputs."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            step_kwargs = dict(kwargs)
            sample = self.dummy_sample
            residual = 0.1 * sample

            sample_pt = torch.tensor(sample)
            residual_pt = 0.1 * sample_pt

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)

            self._prepare_timesteps((scheduler, scheduler_pt), num_inference_steps, step_kwargs)

            output = scheduler.step(residual, 1, sample, **step_kwargs)["prev_sample"]
            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **step_kwargs)["prev_sample"]

            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"


class DDPMSchedulerTest(SchedulerCommonTest):
    """Scheduler checks specialised for ``DDPMScheduler``."""

    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Baseline DDPM config; keyword arguments override individual entries."""
        base_config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        base_config.update(**kwargs)
        return base_config

    def test_timesteps(self):
        for num_timesteps in (1, 5, 100, 1000):
            self.check_over_configs(timesteps=num_timesteps)

    def test_betas(self):
        starts = (0.0001, 0.001, 0.01, 0.1)
        ends = (0.002, 0.02, 0.2, 2)
        for start, end in zip(starts, ends):
            self.check_over_configs(beta_start=start, beta_end=end)

    def test_schedules(self):
        for beta_schedule in ("linear", "squaredcos_cap_v2"):
            self.check_over_configs(beta_schedule=beta_schedule)

    def test_variance_type(self):
        for variance_type in ("fixed_small", "fixed_large", "other"):
            self.check_over_configs(variance_type=variance_type)

    def test_clip_sample(self):
        for flag in (True, False):
            self.check_over_configs(clip_sample=flag)

    def test_time_indices(self):
        for time_step in (0, 500, 999):
            self.check_over_forward(time_step=time_step)

    def test_variance(self):
        """``get_variance`` matches reference values at the start, middle and end."""
        scheduler = self.scheduler_classes[0](**self.get_scheduler_config())

        expectations = ((0, 0.0), (487, 0.00979), (999, 0.02))
        for timestep, expected in expectations:
            assert np.sum(np.abs(scheduler.get_variance(timestep) - expected)) < 1e-5

    def test_full_loop_no_noise(self):
        """Run the full reverse chain with the toy model and compare against reference stats."""
        scheduler = self.scheduler_classes[0](**self.get_scheduler_config())

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]

            # NOTE(review): `variance` is only recomputed while t > 0, so the final
            # t == 0 iteration reuses the t == 1 noise term. The reference values
            # asserted below were recorded with this behaviour.
            if t > 0:
                noise = self.dummy_sample_deter
                variance = scheduler.get_variance(t) ** 0.5 * noise

            sample = pred_prev_sample + variance

        magnitudes = np.abs(sample)
        result_sum = np.sum(magnitudes)
        result_mean = np.mean(magnitudes)

        assert abs(result_sum.item() - 732.9947) < 1e-2
        assert abs(result_mean.item() - 0.9544) < 1e-3


class DDIMSchedulerTest(SchedulerCommonTest):
    """Scheduler checks specialised for ``DDIMScheduler``.

    The number of inference steps is configured through ``set_timesteps`` (see the
    routing in the base class), and ``eta`` is forwarded to ``step``.
    """

    scheduler_classes = (DDIMScheduler,)
    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))

    def get_scheduler_config(self, **kwargs):
        """Baseline DDIM config; keyword arguments override individual entries."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        # Each pair is (time_step, num_inference_steps). Previously `t` was unused,
        # so every case stepped at time_step=0.
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_eta(self):
        for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
            self.check_over_forward(time_step=t, eta=eta)

    def test_variance(self):
        """``_get_variance(t, prev_t)`` matches reference values."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
        assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        """Run 10 deterministic (eta=0) DDIM steps with the toy model; compare reference stats."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            residual = model(sample, t)

            sample = scheduler.step(residual, t, sample, eta)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 172.0067) < 1e-2
        assert abs(result_mean.item() - 0.223967) < 1e-3


class PNDMSchedulerTest(SchedulerCommonTest):
    """Scheduler checks specialised for ``PNDMScheduler``.

    ``num_inference_steps`` is forwarded to ``step``, because the routing in the
    base class only calls ``set_timesteps`` when the scheduler defines it.
    The ``*_pmls`` checks exercise the scheduler after ``set_plms_mode()``.
    """

    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Baseline PNDM config; keyword arguments override individual entries."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def _check_plms_roundtrip(self, time_step, config, step_kwargs):
        """Assert that a saved-and-reloaded scheduler in PLMS mode steps identically.

        Args:
            time_step: timestep passed to ``step``.
            config: overrides forwarded to ``get_scheduler_config``.
            step_kwargs: kwargs forwarded to ``step``.
        """
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            # Copy over dummy past residuals. Each scheduler gets its own list
            # object, so the two never share one list.
            scheduler.ets = dummy_past_residuals[:]
            scheduler.set_plms_mode()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                new_scheduler.ets = dummy_past_residuals[:]
                new_scheduler.set_plms_mode()

            output = scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]
            new_output = new_scheduler.step(residual, time_step, sample, **step_kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_configs_pmls(self, time_step=0, **config):
        """PLMS-mode round-trip check with scheduler-config overrides ``config``."""
        self._check_plms_roundtrip(time_step, config, dict(self.forward_default_kwargs))

    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
        """PLMS-mode round-trip check with ``forward_kwargs`` overriding the ``step`` kwargs."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        self._check_plms_roundtrip(time_step, {}, kwargs)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_timesteps_pmls(self):
        for timesteps in [100, 1000]:
            self.check_over_configs_pmls(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_betas_pmls(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_schedules_pmls(self):
        # Previously called the non-PLMS check, duplicating test_schedules.
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs_pmls(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_time_indices_pmls(self):
        for t in [1, 5, 10]:
            self.check_over_forward_pmls(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_steps_pmls(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_pmls_no_past_residuals(self):
        """Stepping in PLMS mode without any past residuals must raise ``ValueError``."""
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler.set_plms_mode()

            scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]

    def test_full_loop_no_noise(self):
        """Run the PRK warm-up steps and then the PLMS steps; compare reference stats."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter

        # The model sees the original timestep value; step_prk / step_plms receive
        # the loop index.
        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
        for t, t_orig in enumerate(prk_time_steps):
            residual = model(sample, t_orig)

            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]

        timesteps = scheduler.get_time_steps(num_inference_steps)
        for t, t_orig in enumerate(timesteps):
            residual = model(sample, t_orig)

            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 199.1169) < 1e-2
        assert abs(result_mean.item() - 0.2593) < 1e-3