test_modeling_utils.py 27.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
16

patil-suraj's avatar
patil-suraj committed
17
import inspect
18
import math
19
20
21
import tempfile
import unittest

22
import numpy as np
23
24
import torch

Patrick von Platen's avatar
upload  
Patrick von Platen committed
25
from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
Patrick von Platen's avatar
Patrick von Platen committed
26
from diffusers import (
patil-suraj's avatar
patil-suraj committed
27
    AutoencoderKL,
Patrick von Platen's avatar
Patrick von Platen committed
28
    DDIMPipeline,
29
    DDIMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
30
    DDPMPipeline,
31
    DDPMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
32
    LatentDiffusionPipeline,
patil-suraj's avatar
patil-suraj committed
33
    LatentDiffusionUncondPipeline,
Patrick von Platen's avatar
Patrick von Platen committed
34
    PNDMPipeline,
35
    PNDMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
36
37
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
38
    UNetUnconditionalModel,
patil-suraj's avatar
patil-suraj committed
39
    VQModel,
40
)
41
from diffusers.configuration_utils import ConfigMixin
Patrick von Platen's avatar
Patrick von Platen committed
42
from diffusers.pipeline_utils import DiffusionPipeline
Patrick von Platen's avatar
Patrick von Platen committed
43
from diffusers.testing_utils import floats_tensor, slow, torch_device
44
from diffusers.training_utils import EMAModel
45
46


Patrick von Platen's avatar
Patrick von Platen committed
47
torch.backends.cuda.matmul.allow_tf32 = False
48
49


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class ConfigTester(unittest.TestCase):
    """Round-trip checks for the `ConfigMixin` save/load machinery."""

    def test_load_not_from_mixin(self):
        # Loading through the abstract mixin itself must be rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_save_load(self):
        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=[1, 3],
            ):
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

        sample = SampleObject()
        original_config = sample.config

        # Every default must have been captured by register_to_config.
        assert original_config["a"] == 2
        assert original_config["b"] == 5
        assert original_config["c"] == (2, 5)
        assert original_config["d"] == "for diffusion"
        assert original_config["e"] == [1, 3]

        with tempfile.TemporaryDirectory() as save_dir:
            sample.save_config(save_dir)
            reloaded = SampleObject.from_config(save_dir)
            reloaded_config = reloaded.config

        # unfreeze configs
        original_config = dict(original_config)
        reloaded_config = dict(reloaded_config)

        # JSON has no tuple type, so "c" comes back as a list after a round trip.
        assert original_config.pop("c") == (2, 5)  # instantiated as tuple
        assert reloaded_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert original_config == reloaded_config


patil-suraj's avatar
patil-suraj committed
92
class ModelTesterMixin:
    """Shared forward-pass, serialization and training smoke tests.

    Concrete test classes mix this in and must define ``model_class``,
    ``dummy_input``, ``output_shape`` and
    ``prepare_init_args_and_inputs_for_common``.
    """

    @staticmethod
    def _unwrap_sample(output):
        # Models may return either a tensor or a {"sample": tensor} dict.
        if isinstance(output, dict):
            output = output["sample"]
        return output

    def test_from_pretrained_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as save_dir:
            model.save_pretrained(save_dir)
            new_model = self.model_class.from_pretrained(save_dir)
            new_model.to(torch_device)

        with torch.no_grad():
            image = self._unwrap_sample(model(**inputs_dict))
            new_image = self._unwrap_sample(new_model(**inputs_dict))

        # The reloaded weights must reproduce the original forward pass.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = self._unwrap_sample(model(**inputs_dict))
            second = self._unwrap_sample(model(**inputs_dict))

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # Drop NaN entries so a stray NaN cannot poison the comparison.
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = self._unwrap_sample(model(**inputs_dict))

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as save_dir:
            model.save_config(save_dir)
            new_model = self.model_class.from_config(save_dir)
            new_model.to(torch_device)
            new_model.eval()

        # check if all paramters shape are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = self._unwrap_sample(model(**inputs_dict))
            output_2 = self._unwrap_sample(new_model(**inputs_dict))

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = self._unwrap_sample(model(**inputs_dict))

        # Regress against random noise purely to exercise the backward pass.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = self._unwrap_sample(model(**inputs_dict))

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        # Update the EMA shadow weights after the optimization step.
        ema_model.step(model)

patil-suraj's avatar
patil-suraj committed
233
234

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Smoke tests for the small unconditional DDPM-style UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # Four random 3x32x32 images sharing a single integer timestep.
        batch_size, num_channels = 4, 3
        spatial = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + spatial).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Tiny two-level UNet so the common tests stay fast on CPU.
        init_dict = {
            "ch": 32,
            "ch_mult": (1, 2),
            "block_channels": (32, 64),
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            "num_head_channels": None,
            "out_channels": 3,
            "in_channels": 3,
            "num_res_blocks": 2,
            "attn_resolutions": (16,),
            "resolution": 32,
            "image_size": 32,
        }
        return init_dict, self.dummy_input
273

patil-suraj's avatar
patil-suraj committed
274

Patrick von Platen's avatar
upload  
Patrick von Platen committed
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
#    TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
#    def test_output_pretrained(self):
#        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        time_step = torch.tensor([10])
#
#        with torch.no_grad():
#            output = model(noise, time_step)["sample"]
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
295
296


patil-suraj's avatar
patil-suraj committed
297
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for the latent-diffusion (LDM) flavor of the unconditional UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # Four random 4-channel latents with a single shared timestep.
        batch_size, num_channels = 4, 4
        spatial = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + spatial).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": timestep}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "image_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "num_res_blocks": 2,
            "attention_resolutions": (16,),
            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
            "ldm": True,
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        # Every checkpoint weight should have found a matching parameter.
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()

        # Fixed seeds keep the expected slice below reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        # Compare a fixed corner slice against a recorded reference.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

Patrick von Platen's avatar
Patrick von Platen committed
369

Patrick von Platen's avatar
upload  
Patrick von Platen committed
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
#    TODO(Patrick) - Re-add this test after having cleaned up LDM
#    def test_output_pretrained_spatial_transformer(self):
#        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        context = torch.ones((1, 16, 64), dtype=torch.float32)
#        time_step = torch.tensor([10] * noise.shape[0])
#
#        with torch.no_grad():
#            output = model(noise, time_step, context=context)
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#
Patrick von Platen's avatar
Patrick von Platen committed
393

patil-suraj's avatar
patil-suraj committed
394

395
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for the score-based (variance-exploding) NCSN++ UNet configuration."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        batch_size, num_channels = 4, 3
        spatial = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + spatial).to(torch_device)
        # One timestep per batch element, unlike the single shared step above.
        time_step = torch.tensor(batch_size * [10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # NCSN++-style config: Fourier time embedding and skip-connection blocks.
        init_dict = {
            "block_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "num_res_blocks": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "resnet_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "resnet_num_groups": None,
            "down_blocks": [
                "UNetResSkipDownBlock2D",
                "UNetResAttnSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
            ],
            "up_blocks": [
                "UNetResSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
                "UNetResAttnSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
            ],
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size, num_channels = 4, 3
        spatial = (256, 256)

        noise = torch.ones((batch_size, num_channels) + spatial).to(torch_device)
        # VE models take a (near-zero) continuous noise level, not an integer step.
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size, num_channels = 4, 3
        spatial = (32, 32)

        noise = torch.ones((batch_size, num_channels) + spatial).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
504
505


patil-suraj's avatar
patil-suraj committed
506
507
508
509
510
511
512
513
514
515
516
class VQModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for the VQ-GAN style autoencoder."""

    model_class = VQModel

    @property
    def dummy_input(self):
        # Plain images — the autoencoder forward takes no timestep.
        batch_size, num_channels = 4, 3
        spatial = (32, 32)

        image = floats_tensor((batch_size, num_channels) + spatial).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Minimal single-level VQGAN so the common tests stay fast.
        init_dict = {
            "ch": 64,
            "out_ch": 3,
            "num_res_blocks": 1,
            "attn_resolutions": [],
            "in_channels": 3,
            "resolution": 32,
            "z_channels": 3,
            "n_embed": 256,
            "embed_dim": 3,
            "sane_index_shape": False,
            "ch_mult": (1,),
            "dropout": 0.0,
            "double_z": False,
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # Not applicable: the autoencoder forward has no "timestep" argument.
        pass

    def test_training(self):
        # Not applicable for this model in the common-test suite.
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        # Fixed seeds keep the reference slice below reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            output = model(image)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
579
580


Patrick von Platen's avatar
Patrick von Platen committed
581
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    """Tests for the KL-regularized autoencoder used by latent diffusion."""

    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        # Plain images — the autoencoder forward takes no timestep.
        batch_size, num_channels = 4, 3
        spatial = (32, 32)

        image = floats_tensor((batch_size, num_channels) + spatial).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Minimal single-level KL autoencoder so the common tests stay fast.
        init_dict = {
            "ch": 64,
            "ch_mult": (1,),
            "embed_dim": 4,
            "in_channels": 3,
            "num_res_blocks": 1,
            "out_ch": 3,
            "resolution": 32,
            "z_channels": 4,
            "attn_resolutions": [],
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # Not applicable: the autoencoder forward has no "timestep" argument.
        pass

    def test_training(self):
        # Not applicable for this model in the common-test suite.
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        # Fixed seeds keep the reference slice below reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            output = model(image, sample_posterior=True)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
650
651


652
653
654
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        """Round-trip a DDPM pipeline through save_pretrained/from_pretrained and
        check that the reloaded pipeline reproduces the original output exactly.

        Fix: local variable was misspelled ``schedular`` -> ``scheduler``.
        """
        # 1. Load models
        model = UNetUnconditionalModel(
            block_channels=(32, 64),
            num_res_blocks=2,
            attn_resolutions=(16,),
            image_size=32,
            in_channels=3,
            out_channels=3,
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        )
        # Short schedule keeps the round-trip test fast.
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(model, scheduler)

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

        generator = torch.manual_seed(0)

        image = ddpm(generator=generator)
        # Re-seed so both pipelines denoise from identical noise.
        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator)

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        """Loading through the generic DiffusionPipeline must match the concrete DDPMPipeline."""
        model_path = "google/ddpm-cifar10-32"

        pipeline = DDPMPipeline.from_pretrained(model_path)
        hub_pipeline = DiffusionPipeline.from_pretrained(model_path)

        # Truncate the schedule so this slow test stays cheap.
        pipeline.scheduler.num_timesteps = 10
        hub_pipeline.scheduler.num_timesteps = 10

        generator = torch.manual_seed(0)

        image = pipeline(generator=generator)
        # Re-seed so both pipelines denoise from identical noise.
        generator = generator.manual_seed(0)
        new_image = hub_pipeline(generator=generator)

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
Patrick von Platen's avatar
Patrick von Platen committed
698
699
700

    @slow
    def test_ddpm_cifar10(self):
        """Full DDPM sampling with the CIFAR-10 checkpoint, checked against a reference slice."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id).set_format("pt")

        pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        """DDIM sampling with the EMA LSUN-bedroom checkpoint against a reference slice."""
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        pipeline = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)["sample"]

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor(
            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        """Deterministic (eta=0) DDIM sampling on the CIFAR-10 checkpoint."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")

        pipeline = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator, eta=0.0)["sample"]

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
patil-suraj's avatar
patil-suraj committed
759

Patrick von Platen's avatar
Patrick von Platen committed
760
761
    @slow
    def test_pndm_cifar10(self):
        """PNDM sampling on the CIFAR-10 checkpoint, checked against a reference slice."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")

        pipeline = PNDMPipeline(unet=unet, scheduler=scheduler)
        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

patil-suraj's avatar
patil-suraj committed
779
780
    @slow
    def test_ldm_text2img(self):
        """Text-conditioned latent-diffusion sampling (20 steps) against a reference slice."""
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = pipeline([prompt], generator=generator, num_inference_steps=20)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
792

patil-suraj's avatar
patil-suraj committed
793
794
    @slow
    def test_ldm_text2img_fast(self):
        """Single-step variant of the text2img test — cheap sanity check of the pipeline."""
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = pipeline([prompt], generator=generator, num_inference_steps=1)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

Patrick von Platen's avatar
Patrick von Platen committed
807
808
    @slow
    def test_score_sde_ve_pipeline(self):
        """Two-step score-SDE (VE) sampling; compares aggregate image statistics.

        Expected sums/means differ slightly per device, hence the CPU/GPU branch.
        """
        unet = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")

        pipeline = ScoreSdeVePipeline(model=unet, scheduler=scheduler)

        # Re-seed immediately before sampling so the run is reproducible.
        torch.manual_seed(0)
        image = pipeline(num_inference_steps=2)

        if unet.device.type == "cpu":
            # patrick's cpu
            expected_image_sum = 3384805888.0
            expected_image_mean = 1076.00085

            # m1 mbp
            # expected_image_sum = 3384805376.0
            # expected_image_mean = 1076.000610351562
        else:
            expected_image_sum = 3382849024.0
            expected_image_mean = 1075.3788

        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

patil-suraj's avatar
patil-suraj committed
837
838
    @slow
    def test_ldm_uncond(self):
        """Unconditional latent diffusion on CelebA-HQ weights (5 steps) against a reference slice."""
        pipeline = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator, num_inference_steps=5)["sample"]

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor(
            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2