test_modeling_utils.py 31.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

patil-suraj's avatar
patil-suraj committed
16
import inspect
17
import math
18
19
20
import tempfile
import unittest

21
import numpy as np
22
23
import torch

24
import PIL
Patrick von Platen's avatar
Patrick von Platen committed
25
from diffusers import UNet2DConditionModel  # noqa: F401 TODO(Patrick) - need to write tests with it
Patrick von Platen's avatar
Patrick von Platen committed
26
from diffusers import (
patil-suraj's avatar
patil-suraj committed
27
    AutoencoderKL,
Patrick von Platen's avatar
Patrick von Platen committed
28
    DDIMPipeline,
29
    DDIMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
30
    DDPMPipeline,
31
    DDPMScheduler,
32
33
    KarrasVePipeline,
    KarrasVeScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
34
35
    LDMPipeline,
    LDMTextToImagePipeline,
Patrick von Platen's avatar
Patrick von Platen committed
36
    PNDMPipeline,
37
    PNDMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
38
39
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
40
    UNet2DModel,
patil-suraj's avatar
patil-suraj committed
41
    VQModel,
42
)
43
from diffusers.configuration_utils import ConfigMixin, register_to_config
Patrick von Platen's avatar
Patrick von Platen committed
44
from diffusers.pipeline_utils import DiffusionPipeline
Patrick von Platen's avatar
Patrick von Platen committed
45
from diffusers.testing_utils import floats_tensor, slow, torch_device
46
from diffusers.training_utils import EMAModel
47
48


Patrick von Platen's avatar
Patrick von Platen committed
49
torch.backends.cuda.matmul.allow_tf32 = False
50
51


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class SampleObject(ConfigMixin):
    """Minimal ConfigMixin subclass used to exercise ``@register_to_config``.

    The decorator captures the ``__init__`` arguments (whether defaulted,
    passed positionally, or passed by keyword) into ``self.config``; the
    defaults below give the tests a mix of int, tuple, str and list values
    to round-trip through save/load.
    """

    # file name ConfigMixin uses when saving/loading this object's config
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),  # tuple on purpose: JSON serialization turns it into a list (see test_save_load)
        d="for diffusion",
        e=[1, 3],  # mutable default is deliberate here; the tests only read it
    ):
        pass

67
68
69
70
71
class ConfigTester(unittest.TestCase):
    """Tests for ConfigMixin loading and the @register_to_config registry."""

    def test_load_not_from_mixin(self):
        # Loading through the abstract base class (no concrete subclass)
        # must be rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        defaults = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}

        def check(obj, **overrides):
            # every registered value equals the default unless overridden
            expected = {**defaults, **overrides}
            config = obj.config
            for key, value in expected.items():
                assert config[key] == value

        # all defaults
        check(SampleObject())

        # private (underscore-prefixed) init kwargs are ignored by the registry
        check(SampleObject(_name_or_path="lalala"))

        # keyword argument overrides its default
        check(SampleObject(c=6), c=6)

        # positional arguments are registered too
        check(SampleObject(1, c=6), a=1, c=6)

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        expected = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}
        for key, value in expected.items():
            assert config[key] == value

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(tmpdirname)
            new_config = new_obj.config

        # unfreeze configs
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert config == new_config


patil-suraj's avatar
patil-suraj committed
132
class ModelTesterMixin:
    """Shared test-suite for diffusers model classes.

    Concrete subclasses (which also inherit ``unittest.TestCase``) must
    provide:
      * ``model_class`` — the model class under test,
      * ``prepare_init_args_and_inputs_for_common()`` — returns
        ``(init_dict, inputs_dict)`` used to build and call the model,
      * ``output_shape`` — per-sample output shape (used by the training
        tests to build a matching noise target).
    """

    def test_from_pretrained_save_pretrained(self):
        # Round-trip through save_pretrained/from_pretrained and verify the
        # reloaded model reproduces the original forward pass.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            # some models return a dict holding the tensor under "sample"
            if isinstance(image, dict):
                image = image["sample"]

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image["sample"]

        # NOTE(review): despite the name, this is the *summed* absolute
        # difference, not the max; the 5e-5 tolerance accounts for that.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        # Two eval-mode forward passes on identical inputs must agree
        # (catches unseeded randomness inside the model).
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first["sample"]

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second["sample"]

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # compare only non-NaN entries of each output
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        # Output must exist and keep the input sample's shape.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output["sample"]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        # Every model's forward() is expected to take ("sample", "timestep")
        # as its first two arguments.
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        # A model rebuilt from its saved config (fresh random weights) must
        # have identically shaped parameters and outputs.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            new_model = self.model_class.from_config(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameter shapes are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1["sample"]

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2["sample"]

        # weights differ, so only the shapes are comparable here
        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        # Smoke-test one training step: train-mode forward, MSE loss against
        # random noise, backward.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output["sample"]

        # noise target shaped (batch,) + per-sample output_shape
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        # Same as test_training, plus one EMA update of the shadow weights.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output["sample"]

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model)

patil-suraj's avatar
patil-suraj committed
273
274

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-suite tests for a tiny attention-equipped UNet2DModel."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        # Random 4x3x32x32 batch plus a single (broadcast) timestep.
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        # per-sample input shape (channels, height, width)
        return (3, 32, 32)

    @property
    def output_shape(self):
        # per-sample output shape; matches input for this model
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Small two-level UNet config so the shared tests run fast.
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": None,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
309

patil-suraj's avatar
patil-suraj committed
310

Patrick von Platen's avatar
upload  
Patrick von Platen committed
311
312
#    TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
#    def test_output_pretrained(self):
Patrick von Platen's avatar
Patrick von Platen committed
313
#        model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
Patrick von Platen's avatar
upload  
Patrick von Platen committed
314
315
316
317
318
319
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
Patrick von Platen's avatar
Patrick von Platen committed
320
#        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
Patrick von Platen's avatar
upload  
Patrick von Platen committed
321
322
323
324
325
326
327
328
329
330
#        time_step = torch.tensor([10])
#
#        with torch.no_grad():
#            output = model(noise, time_step)["sample"]
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
331
332


patil-suraj's avatar
patil-suraj committed
333
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-suite tests for a 4-channel (LDM-style) UNet2DModel."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        # Random 4x4x32x32 batch plus a single (broadcast) timestep.
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Attention-free two-level UNet over 4-channel latents.
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        # Load a dummy checkpoint from the Hub and check nothing is missing.
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        # Golden-value regression test against a fixed seed.
        # NOTE(review): model and inputs stay on CPU here (no .to(torch_device)).
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

Patrick von Platen's avatar
Patrick von Platen committed
401

Patrick von Platen's avatar
upload  
Patrick von Platen committed
402
403
404
405
406
407
408
409
410
#    TODO(Patrick) - Re-add this test after having cleaned up LDM
#    def test_output_pretrained_spatial_transformer(self):
#        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
Patrick von Platen's avatar
Patrick von Platen committed
411
#        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
Patrick von Platen's avatar
upload  
Patrick von Platen committed
412
413
414
415
416
417
418
419
420
421
422
423
424
#        context = torch.ones((1, 16, 64), dtype=torch.float32)
#        time_step = torch.tensor([10] * noise.shape[0])
#
#        with torch.no_grad():
#            output = model(noise, time_step, context=context)
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#
Patrick von Platen's avatar
Patrick von Platen committed
425

patil-suraj's avatar
patil-suraj committed
426

427
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-suite tests for a score-based (NCSN++ style) UNet2DModel."""

    model_class = UNet2DModel

    @property
    def dummy_input(self, sizes=(32, 32)):
        # NOTE(review): accessed as a property, so `sizes` always takes its
        # default — the parameter is effectively dead.
        batch_size = 4
        num_channels = 3

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        # one timestep per batch element (unlike the broadcast [10] above)
        time_step = torch.tensor(batch_size * [10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Four-level skip-connection UNet with Fourier time embeddings.
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        # Load the full-size checkpoint and run one 256x256 forward pass.
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        # replace the 32x32 dummy noise with the checkpoint's native 256x256
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        # Golden-value regression for the mid-size VE checkpoint.
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        # deterministic all-ones input; continuous (VE) timestep of 1e-4
        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        # Golden-value regression for the dummy VE checkpoint at 32x32.
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
536
537


patil-suraj's avatar
patil-suraj committed
538
539
540
541
class VQModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-suite tests for VQModel (no timestep input, hence the
    signature/training overrides below)."""

    model_class = VQModel

    @property
    def dummy_input(self, sizes=(32, 32)):
        # NOTE(review): accessed as a property, so `sizes` always takes its
        # default — the parameter is effectively dead.
        batch_size = 4
        num_channels = 3

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        # autoencoder-style model: only "sample", no "timestep"
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Minimal two-level encoder/decoder with a 3-channel codebook space.
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 3,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_forward_signature(self):
        # skipped: VQModel's forward does not take (sample, timestep)
        pass

    def test_training(self):
        # skipped for this model
        pass

    def test_from_pretrained_hub(self):
        # Load a dummy checkpoint from the Hub and check nothing is missing.
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        # Golden-value regression test against a fixed seed.
        # NOTE(review): model and inputs stay on CPU here.
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
603
604


Patrick von Platen's avatar
Patrick von Platen committed
605
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    """Common-suite tests for AutoencoderKL (no timestep input, hence the
    signature/training overrides below)."""

    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        # Random 4x3x32x32 image batch; autoencoders take only "sample".
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Minimal two-level KL autoencoder with a 4-channel latent space.
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_forward_signature(self):
        # skipped: AutoencoderKL's forward does not take (sample, timestep)
        pass

    def test_training(self):
        # skipped for this model
        pass

    def test_from_pretrained_hub(self):
        # Load a dummy checkpoint from the Hub and check nothing is missing.
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        # Golden-value regression test against a fixed seed; sampling from
        # the posterior is made deterministic by the manual seeds above it.
        # NOTE(review): model and inputs stay on CPU here.
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image, sample_posterior=True)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01,  2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
671
672


673
674
675
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        """Round-trip a DDPM pipeline through save_pretrained/from_pretrained and
        verify both copies generate the same image from the same seed."""
        # Build a tiny UNet + scheduler from scratch so no Hub download is needed.
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)
        ddpm = DDPMPipeline(unet, scheduler)

        # Serialize to disk and reload into a fresh pipeline object.
        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

        # Identical seed for both runs -> outputs must match almost exactly.
        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
700
701
702

    @slow
    def test_from_pretrained_hub(self):
        """A pipeline loaded via its concrete class must behave the same as one
        loaded via the generic DiffusionPipeline entry point."""
        model_path = "google/ddpm-cifar10-32"

        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        # Shorten sampling so the comparison stays cheap.
        ddpm.scheduler.num_timesteps = 10
        ddpm_from_hub.scheduler.num_timesteps = 10

        # Re-seed the same generator before each run so both start identically.
        generator = torch.manual_seed(0)
        images = []
        for pipeline in (ddpm, ddpm_from_hub):
            generator = generator.manual_seed(0)
            images.append(pipeline(generator=generator, output_type="numpy")["sample"])

        assert np.abs(images[0] - images[1]).sum() < 1e-5, "Models don't give the same forward pass"
Patrick von Platen's avatar
Patrick von Platen committed
718

719
720
    @slow
    def test_output_format(self):
        """DDIMPipeline should emit a numpy array, a list of PIL images, or PIL by default."""
        model_path = "google/ddpm-cifar10-32"
        pipe = DDIMPipeline.from_pretrained(model_path)
        generator = torch.manual_seed(0)

        # output_type="numpy" -> a single NHWC float array.
        images = pipe(generator=generator, output_type="numpy")["sample"]
        assert isinstance(images, np.ndarray)
        assert images.shape == (1, 32, 32, 3)

        # output_type="pil" -> a list holding one PIL image.
        images = pipe(generator=generator, output_type="pil")["sample"]
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
        images = pipe(generator=generator)["sample"]
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

Patrick von Platen's avatar
Patrick von Platen committed
740
741
    @slow
    def test_ddpm_cifar10(self):
        """Full DDPM sampling on the CIFAR-10 checkpoint; compare a corner slice
        of the generated image against frozen reference values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id).set_format("pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        assert image.shape == (1, 32, 32, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
758
759
760

    @slow
    def test_ddim_lsun(self):
        """DDIM sampling with the EMA bedroom checkpoint; check a 3x3 output slice."""
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)
        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
776
777
778

    @slow
    def test_ddim_cifar10(self):
        """Deterministic DDIM sampling (eta=0.0) on CIFAR-10 weights; compare a
        corner slice against frozen reference values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")
        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]

        assert image.shape == (1, 32, 32, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
patil-suraj's avatar
patil-suraj committed
794

Patrick von Platen's avatar
Patrick von Platen committed
795
796
    @slow
    def test_pndm_cifar10(self):
        """PNDM sampling on CIFAR-10 weights; compare a corner slice to frozen values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")
        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pndm(generator=generator, output_type="numpy")["sample"]

        assert image.shape == (1, 32, 32, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
811

patil-suraj's avatar
patil-suraj committed
812
813
    @slow
    def test_ldm_text2img(self):
        """Text-to-image latent diffusion with classifier-free guidance; compare a
        3x3 slice of the output against frozen reference values."""
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
            "sample"
        ]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
827

patil-suraj's avatar
patil-suraj committed
828
829
    @slow
    def test_ldm_text2img_fast(self):
        """Single-step text-to-image LDM run (fast smoke test); compare a 3x3 slice."""
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
patil-suraj's avatar
patil-suraj committed
841

Patrick von Platen's avatar
Patrick von Platen committed
842
843
    @slow
    def test_score_sde_ve_pipeline(self):
        """Score-based SDE (variance-exploding) sampling; compare a 3x3 output slice."""
        model_id = "google/ncsnpp-church-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = ScoreSdeVeScheduler.from_config(model_id)
        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)

        # Seed the global RNG; this pipeline call takes no generator argument.
        torch.manual_seed(0)
        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
859

patil-suraj's avatar
patil-suraj committed
860
861
    @slow
    def test_ldm_uncond(self):
        """Unconditional latent diffusion (CelebA-HQ); compare a 3x3 output slice."""
        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889

    @slow
    def test_ddpm_ddim_equality(self):
        """DDIM with eta=1.0 over all 1000 steps should closely match DDPM sampling."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler(tensor_format="pt"))
        ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler(tensor_format="pt"))

        generator = torch.manual_seed(0)
        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_image - ddim_image).max() < 1e-1

893
    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
    def test_ddpm_ddim_equality_batched(self):
        """Batched variant of the DDPM/DDIM equivalence check (currently skipped)."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler(tensor_format="pt"))
        ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler(tensor_format="pt"))

        generator = torch.manual_seed(0)
        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
            "sample"
        ]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929

    @slow
    def test_karras_ve_pipeline(self):
        """Karras VE sampling on the CelebA-HQ checkpoint; compare a 3x3 output slice."""
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler(tensor_format="pt")
        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

        assert image.shape == (1, 256, 256, 3)
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2