test_modeling_utils.py 25.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
16

patil-suraj's avatar
patil-suraj committed
17
import inspect
18
19
20
import tempfile
import unittest

21
import numpy as np
22
23
import torch

24
from diffusers import (
Patrick von Platen's avatar
Patrick von Platen committed
25
26
    BDDMPipeline,
    DDIMPipeline,
27
    DDIMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
28
    DDPMPipeline,
29
    DDPMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
30
    GlidePipeline,
Patrick von Platen's avatar
Patrick von Platen committed
31
32
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
Patrick von Platen's avatar
Patrick von Platen committed
33
    GradTTSPipeline,
34
    GradTTSScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
35
    LatentDiffusionPipeline,
Patrick von Platen's avatar
Patrick von Platen committed
36
    NCSNpp,
Patrick von Platen's avatar
Patrick von Platen committed
37
    PNDMPipeline,
38
    PNDMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
39
40
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
patil-suraj's avatar
patil-suraj committed
41
    UNetGradTTSModel,
anton-l's avatar
anton-l committed
42
43
    UNetLDMModel,
    UNetModel,
44
)
45
from diffusers.configuration_utils import ConfigMixin
Patrick von Platen's avatar
Patrick von Platen committed
46
from diffusers.pipeline_utils import DiffusionPipeline
47
from diffusers.pipelines.pipeline_bddm import DiffWave
Patrick von Platen's avatar
Patrick von Platen committed
48
from diffusers.testing_utils import floats_tensor, slow, torch_device
49
50


Patrick von Platen's avatar
Patrick von Platen committed
51
# Disable TF32 for both cuBLAS matmuls and cuDNN convolutions so that, on GPUs that support it,
# model outputs are computed in full fp32 precision and the tolerance-based comparisons in these
# tests are not affected by TF32 rounding. cuDNN TF32 is enabled by default in PyTorch, so turning
# off only the matmul flag would still let the convolution layers run in TF32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class ConfigTester(unittest.TestCase):
    """Tests for saving and re-loading `ConfigMixin` configurations."""

    def test_load_not_from_mixin(self):
        # Calling `from_config` directly on the base `ConfigMixin` is expected to raise `ValueError`.
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_save_load(self):
        """Values registered via `register_to_config` survive a `save_config` / `from_config` round trip."""

        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=None,
            ):
                # `None` sentinel instead of a mutable list default, so instances never share
                # (and can never accidentally mutate) one list object; the effective default
                # is still `[1, 3]`.
                if e is None:
                    e = [1, 3]
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

        obj = SampleObject()
        config = obj.config

        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(tmpdirname)
            new_config = new_obj.config

        # Copy into plain dicts so entries can be popped below (the config objects may be immutable).
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert config == new_config


class ModelTesterMixin:
    """Common model tests, mixed into the per-model ``unittest.TestCase`` classes.

    A concrete test class must provide:

    * ``model_class`` -- the model class under test;
    * ``prepare_init_args_and_inputs_for_common()`` -- returns ``(init_dict, inputs_dict)``, where
      ``init_dict`` is passed to ``model_class(**init_dict)`` and ``inputs_dict`` is passed to the
      forward call and must hold the input sample under the key ``"x"``;
    * ``get_output_shape`` -- the output shape of a single sample (without the batch dimension),
      used by ``test_training`` to build a regression target.
    """

    def test_from_pretrained_save_pretrained(self):
        """A model saved with ``save_pretrained`` reloads into a model with the same forward pass."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            # Run both models in the same (inference) mode, as `test_model_from_config` does.
            new_model.eval()

        with torch.no_grad():
            image = model(**inputs_dict)
            new_image = new_model(**inputs_dict)

        # Sum (not max) of absolute differences over all elements: a strict bound.
        total_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(total_diff, 1e-5, "Models give different forward passes")

    def test_determinism(self):
        """Two forward passes in eval mode on the same inputs give the same output."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = model(**inputs_dict)
            second = model(**inputs_dict)

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()

        # NaN never compares equal, so exclude it -- but with ONE shared mask, after checking both
        # runs put NaNs in the same places. Masking each array independently could misalign the
        # remaining elements or produce arrays of different lengths.
        nan_mask = np.isnan(out_1)
        self.assertTrue(np.array_equal(nan_mask, np.isnan(out_2)), "NaN positions differ between runs")
        # `initial=0.0` keeps the reduction defined even if every element was NaN.
        max_diff = np.max(np.abs(out_1[~nan_mask] - out_2[~nan_mask]), initial=0.0)
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        """The forward pass returns a tensor with the same shape as the input sample ``x``."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["x"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        """``forward`` takes ``x`` and ``timesteps`` as its first two parameters."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["x", "timesteps"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        """A model rebuilt from the saved config has the same parameters and output shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            new_model = self.model_class.from_config(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check that both models have the same parameters with the same shapes;
        # fetch each state dict once instead of rebuilding it on every iteration
        model_state = model.state_dict()
        new_model_state = new_model.state_dict()
        self.assertEqual(set(model_state.keys()), set(new_model_state.keys()))
        for param_name, param_1 in model_state.items():
            self.assertEqual(param_1.shape, new_model_state[param_name].shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)
            output_2 = new_model(**inputs_dict)

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        """One forward/backward pass in training mode against a random target runs without error."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        # Random regression target of shape (batch,) + get_output_shape.
        noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()


class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Common model tests plus hub checks for `UNetModel` on 3-channel 32x32 inputs."""

    model_class = UNetModel

    @property
    def dummy_input(self):
        # (batch, channels, height, width) = (4, 3, 32, 32), with a single timestep value.
        sample = floats_tensor((4, 3, 32, 32)).to(torch_device)
        timesteps = torch.tensor([10]).to(torch_device)
        return {"x": sample, "timesteps": timesteps}

    @property
    def get_input_shape(self):
        return (3, 32, 32)

    @property
    def get_output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(info["missing_keys"]), 0)

        model.to(torch_device)
        result = model(**self.dummy_input)
        assert result is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
        model.eval()

        # Seed the torch RNGs so the random input below is reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        config = model.config
        noise = torch.randn(1, config.in_channels, config.resolution, config.resolution)
        timesteps = torch.tensor([10])

        with torch.no_grad():
            output = model(noise, timesteps)

        # Compare the bottom-right 3x3 patch of the last channel of the first sample.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
    """Tests for `GlideSuperResUNetModel`, which additionally takes a ``low_res`` image input."""

    model_class = GlideSuperResUNetModel

    @property
    def dummy_input(self):
        batch_size = 4
        # Matches the model's `in_channels` (6); the noisy sample itself has half of them (3).
        # Presumably the other half comes from the upsampled `low_res` image -- TODO confirm.
        num_channels = 6
        sizes = (32, 32)
        low_res_size = (4, 4)

        noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
        low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        return {"x": noise, "timesteps": time_step, "low_res": low_res}

    @property
    def get_input_shape(self):
        return (3, 32, 32)

    @property
    def get_output_shape(self):
        return (6, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 6,
            "out_channels": 6,
            "model_channels": 32,
            "num_head_channels": 8,
            "num_heads_upsample": 1,
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        # The model emits `out_channels` (6) channels; only the first 3 are compared
        # against the shape of the 3-channel input sample.
        output, _ = torch.split(output, 3, dim=1)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["x"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GlideSuperResUNetModel.from_pretrained(
            "fusing/glide-super-res-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
        # Put the model in inference mode explicitly, like the other `test_output_pretrained` tests.
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, 3, 64, 64)
        low_res = torch.randn(1, 3, 4, 4)
        time_step = torch.tensor([42] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step, low_res)

        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, -10.4370])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for `GlideTextToImageUNetModel`, which additionally takes a ``transformer_out`` sequence.

    In ``dummy_input`` that sequence has shape ``(batch, seq_len, transformer_dim)``.
    """

    model_class = GlideTextToImageUNetModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)
        transformer_dim = 32
        seq_len = 16

        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
        emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        return {"x": noise, "timesteps": time_step, "transformer_out": emb}

    @property
    def get_input_shape(self):
        return (3, 32, 32)

    @property
    def get_output_shape(self):
        return (6, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 3,
            "out_channels": 6,
            "model_channels": 32,
            "num_head_channels": 8,
            "num_heads_upsample": 1,
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
            "transformer_dim": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        # The model emits `out_channels` (6) channels; only the first 3 are compared
        # against the shape of the 3-channel input sample.
        output, _ = torch.split(output, 3, dim=1)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["x"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GlideTextToImageUNetModel.from_pretrained(
            "fusing/unet-glide-text2im-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        # Keep everything on the CPU, like the other `test_output_pretrained` tests. The model is
        # never moved to `torch_device`, so moving only the inputs there fails with a device
        # mismatch on GPU machines; `expected_output_slice` below is a CPU tensor as well.
        # The inputs are drawn from the CPU RNG in the same order as before.
        noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution))
        emb = torch.randn((1, 16, model.config.transformer_dim))
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step, emb)

        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Common model tests plus hub checks for `UNetLDMModel` on 4-channel 32x32 inputs."""

    model_class = UNetLDMModel

    @property
    def dummy_input(self):
        # (batch, channels, height, width) = (4, 4, 32, 32), with a single timestep value.
        sample = floats_tensor((4, 4, 32, 32)).to(torch_device)
        timesteps = torch.tensor([10]).to(torch_device)
        return {"x": sample, "timesteps": timesteps}

    @property
    def get_input_shape(self):
        return (4, 32, 32)

    @property
    def get_output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            image_size=32,
            in_channels=4,
            out_channels=4,
            model_channels=32,
            num_res_blocks=2,
            attention_resolutions=(16,),
            channel_mult=(1, 2),
            num_heads=2,
            conv_resample=True,
        )
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(info["missing_keys"]), 0)

        model.to(torch_device)
        result = model(**self.dummy_input)
        assert result is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy")
        model.eval()

        # Seed the torch RNGs so the random input below is reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        config = model.config
        noise = torch.randn(1, config.in_channels, config.image_size, config.image_size)
        timesteps = torch.tensor([10] * len(noise))

        with torch.no_grad():
            output = model(noise, timesteps)

        # Compare the bottom-right 3x3 patch of the last channel of the first sample.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for `UNetGradTTSModel`, whose inputs are 3-D ``(batch, features, seq_len)`` tensors."""

    model_class = UNetGradTTSModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_features = 32
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}

    # Per-sample shapes, WITHOUT the batch dimension, as in the other model tests:
    # `ModelTesterMixin.test_training` prepends the batch size to `get_output_shape` itself.
    # Including the batch size here made the training target (4, 4, 32, 16) while the model
    # output is (4, 32, 16), which `mse_loss` broadcast with only a warning.
    @property
    def get_input_shape(self):
        return (32, 16)

    @property
    def get_output_shape(self):
        return (32, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "dim": 64,
            "groups": 4,
            "dim_mults": (1, 2),
            "n_feats": 32,
            "pe_scale": 1000,
            "n_spks": 1,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = model.config.n_feats
        seq_len = 16
        noise = torch.randn((1, num_features, seq_len))
        condition = torch.randn((1, num_features, seq_len))
        mask = torch.randn((1, 1, seq_len))
        time_step = torch.tensor([10])

        with torch.no_grad():
            output = model(noise, time_step, condition, mask)

        # Bottom-right 3x3 corner (last 3 features x last 3 positions) of the first sample.
        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        """Round-trip a tiny DDPM pipeline through save/load and compare samples.

        Both pipelines are sampled with the default generator re-seeded to 0, so
        the reloaded pipeline has to reproduce the original sample.
        """
        # 1. Build a small UNet and scheduler and wrap them in a pipeline
        unet = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
        scheduler = DDPMScheduler(timesteps=10)
        pipe = DDPMPipeline(unet, scheduler)

        # 2. Save it and load it back from disk
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            reloaded = DDPMPipeline.from_pretrained(tmpdir)

        # 3. Sample both pipelines from the same seed
        gen = torch.manual_seed(0)
        sample = pipe(generator=gen)
        gen = gen.manual_seed(0)
        sample_reloaded = reloaded(generator=gen)

        assert (sample - sample_reloaded).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        """Loading via ``DDPMPipeline`` and via ``DiffusionPipeline`` must yield the same sample."""
        repo = "fusing/ddpm-cifar10"

        pipe_direct = DDPMPipeline.from_pretrained(repo)
        pipe_generic = DiffusionPipeline.from_pretrained(repo)

        # Apply the same scheduler override to both pipelines
        # (presumably to shorten sampling).
        for p in (pipe_direct, pipe_generic):
            p.noise_scheduler.num_timesteps = 10

        gen = torch.manual_seed(0)
        sample_direct = pipe_direct(generator=gen)
        gen = gen.manual_seed(0)
        sample_generic = pipe_generic(generator=gen)

        assert (sample_direct - sample_generic).abs().sum() < 1e-5, "Models don't give the same forward pass"
Patrick von Platen's avatar
Patrick von Platen committed
621
622
623
624
625
626

    @slow
    def test_ddpm_cifar10(self):
        """Regression check of DDPM sampling with the pretrained CIFAR-10 UNet."""
        # Seed before loading and keep this order. Anything that draws from the
        # default generator before sampling (possibly model construction) would
        # shift the reference values.
        gen = torch.manual_seed(0)
        repo = "fusing/ddpm-cifar10"

        unet = UNetModel.from_pretrained(repo)
        scheduler = DDPMScheduler.from_config(repo).set_format("pt")
        pipe = DDPMPipeline(unet=unet, noise_scheduler=scheduler)

        sample = pipe(generator=gen)
        corner = sample[0, -1, -3:, -3:].cpu().flatten()

        assert sample.shape == (1, 3, 32, 32)
        reference = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
        assert (corner - reference).abs().max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        """Regression check of DDIM sampling (``eta=0.0``) with the pretrained CIFAR-10 UNet."""
        # Seeded before loading; keep this ordering (see test_ddpm_cifar10).
        gen = torch.manual_seed(0)
        repo = "fusing/ddpm-cifar10"

        unet = UNetModel.from_pretrained(repo)
        scheduler = DDIMScheduler(tensor_format="pt")
        pipe = DDIMPipeline(unet=unet, noise_scheduler=scheduler)

        sample = pipe(generator=gen, eta=0.0)
        corner = sample[0, -1, -3:, -3:].cpu().flatten()

        assert sample.shape == (1, 3, 32, 32)
        reference = torch.tensor([-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068])
        assert (corner - reference).abs().max() < 1e-2
patil-suraj's avatar
patil-suraj committed
658

Patrick von Platen's avatar
Patrick von Platen committed
659
660
661
662
663
664
665
666
    @slow
    def test_pndm_cifar10(self):
        """Regression check of PNDM sampling with the pretrained CIFAR-10 UNet."""
        # Seeded before loading; keep this ordering.
        gen = torch.manual_seed(0)
        repo = "fusing/ddpm-cifar10"

        unet = UNetModel.from_pretrained(repo)
        scheduler = PNDMScheduler(tensor_format="pt")
        pipe = PNDMPipeline(unet=unet, noise_scheduler=scheduler)

        sample = pipe(generator=gen)
        corner = sample[0, -1, -3:, -3:].cpu().flatten()

        assert sample.shape == (1, 3, 32, 32)
        reference = torch.tensor([-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471])
        assert (corner - reference).abs().max() < 1e-2

patil-suraj's avatar
patil-suraj committed
678
679
680
    @slow
    def test_ldm_text2img(self):
        """Regression check of the large latent-diffusion text-to-image pipeline (20 steps)."""
        pipe = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

        prompts = ["A painting of a squirrel eating a burger"]
        gen = torch.manual_seed(0)
        sample = pipe(prompts, generator=gen, num_inference_steps=20)

        corner = sample[0, -1, -3:, -3:].cpu().flatten()

        assert sample.shape == (1, 3, 256, 256)
        reference = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
        assert (corner - reference).abs().max() < 1e-2
692

anton-l's avatar
anton-l committed
693
694
695
    @slow
    def test_glide_text2img(self):
        """Regression check of the GLIDE text-to-image pipeline.

        The output here is channels-last, ``(1, 256, 256, 3)``, so the slice is
        taken from the top-left corner of the last channel.
        """
        pipe = GlidePipeline.from_pretrained("fusing/glide-base")

        prompt = "a pencil sketch of a corgi"
        gen = torch.manual_seed(0)
        sample = pipe(prompt, generator=gen, num_inference_steps_upscale=20)

        corner = sample[0, :3, :3, -1].cpu().flatten()

        assert sample.shape == (1, 256, 256, 3)
        reference = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
        assert (corner - reference).abs().max() < 1e-2

Patrick von Platen's avatar
Patrick von Platen committed
708
709
710
    @slow
    def test_grad_tts(self):
        """Regression check of Grad-TTS mel-spectrogram generation from text."""
        pipe = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
        # Replace the loaded scheduler with a default-constructed GradTTSScheduler.
        pipe.noise_scheduler = GradTTSScheduler()

        text = "Hello world, I missed you so much."
        gen = torch.manual_seed(0)

        # generate mel spectrograms from the text
        mel = pipe(text, generator=gen)

        assert mel.shape == (1, 80, 143)
        reference = torch.tensor([-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890])
        assert (mel[0, :3, :3].cpu().flatten() - reference).abs().max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
726

Patrick von Platen's avatar
Patrick von Platen committed
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
    @slow
    def test_score_sde_ve_pipeline(self):
        """Regression check of the score-SDE (VE) pipeline on the FFHQ NCSN++ checkpoint.

        The test compares the sum and mean of ``|image|`` after two inference
        steps against reference values.

        The previous absolute tolerances (1e-2 on a sum of ~3.4e9, 1e-4 on a mean
        of ~1.1e3) are at or below float32 spacing at those magnitudes (256 and
        ~1.2e-4), assuming the image uses the default float32 dtype. Those checks
        therefore demanded bit-exact reproduction. A relative tolerance keeps the
        check strict but tolerates harmless reduction-order differences.
        """
        torch.manual_seed(0)

        model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
        scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")

        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

        image = sde_ve(num_inference_steps=2)

        expected_image_sum = 3382810112.0
        expected_image_mean = 1075.366455078125

        image_sum = image.abs().sum().cpu().item()
        image_mean = image.abs().mean().cpu().item()

        # Relative tolerance on the aggregated statistics.
        rel_tol = 1e-4
        assert abs(image_sum - expected_image_sum) <= rel_tol * abs(expected_image_sum), image_sum
        assert abs(image_mean - expected_image_mean) <= rel_tol * abs(expected_image_mean), image_mean

744
745
746
747
    def test_module_from_pipeline(self):
        model = DiffWave(num_res_layers=4)
        noise_scheduler = DDPMScheduler(timesteps=12)

Patrick von Platen's avatar
Patrick von Platen committed
748
        bddm = BDDMPipeline(model, noise_scheduler)
749
750
751
752
753
754
755

        # check if the library name for the diffwave moduel is set to pipeline module
        self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")

        # check if we can save and load the pipeline
        with tempfile.TemporaryDirectory() as tmpdirname:
            bddm.save_pretrained(tmpdirname)
Patrick von Platen's avatar
Patrick von Platen committed
756
            _ = BDDMPipeline.from_pretrained(tmpdirname)
757
            # check if the same works using the DifusionPipeline class
758
            _ = DiffusionPipeline.from_pretrained(tmpdirname)