"vscode:/vscode.git/clone" did not exist on "d40da7b68a3116bb18ee9725b8dc78d39c502473"
test_modeling_utils.py 28.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
16

patil-suraj's avatar
patil-suraj committed
17
import inspect
import math

import tempfile
import unittest
from atexit import register

import numpy as np

import torch

from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


# Disable TF32 matmuls so CUDA numerics stay bit-stable and match the
# hard-coded expected output slices used by the tests below.
torch.backends.cuda.matmul.allow_tf32 = False
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class SampleObject(ConfigMixin):
    """Minimal ``ConfigMixin`` fixture used to exercise ``register_to_config``.

    The decorated ``__init__`` arguments (including the deliberately mutable
    list default ``e``) are captured into ``self.config`` by the decorator.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass


66
67
68
69
70
class ConfigTester(unittest.TestCase):
    """Tests for ``ConfigMixin`` registration, overrides and (de)serialization."""

    def _check_defaults(self, config, **overrides):
        # Compare every registered field against SampleObject's declared
        # defaults, letting keyword overrides take precedence.
        expected = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}
        expected.update(overrides)
        for key, value in expected.items():
            assert config[key] == value

    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        self._check_defaults(SampleObject().config)

        # init ignores private (underscore-prefixed) arguments
        self._check_defaults(SampleObject(_name_or_path="lalala").config)

        # keyword arguments override the declared default
        self._check_defaults(SampleObject(c=6).config, c=6)

        # positional arguments are registered as well
        self._check_defaults(SampleObject(1, c=6).config, a=1, c=6)

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        self._check_defaults(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(tmpdirname)
            new_config = new_obj.config

        # unfreeze configs
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert config == new_config


patil-suraj's avatar
patil-suraj committed
131
class ModelTesterMixin:
    """Shared test-suite mixed into the per-model ``unittest.TestCase`` classes.

    Concrete subclasses must provide ``model_class``, the ``output_shape``
    property and ``prepare_init_args_and_inputs_for_common()``.
    """

    def test_from_pretrained_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            # Fix: put the reloaded model into eval mode too — otherwise any
            # stochastic layers (e.g. dropout) would make the two forward
            # passes diverge and the comparison below meaningless.
            new_model.eval()

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image["sample"]

            new_image = new_model(**inputs_dict)
            if isinstance(new_image, dict):
                new_image = new_image["sample"]

        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first["sample"]

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second["sample"]

        # Compare only finite positions; NaNs are masked out of both outputs.
        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)
            if isinstance(output, dict):
                output = output["sample"]

        self.assertIsNotNone(output)
        # All tested models map the input to an output of identical shape.
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            new_model = self.model_class.from_config(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check that all parameter shapes are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)
            if isinstance(output_1, dict):
                output_1 = output_1["sample"]

            output_2 = new_model(**inputs_dict)
            if isinstance(output_2, dict):
                output_2 = output_2["sample"]

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)
        if isinstance(output, dict):
            output = output["sample"]

        # Regress the output against random noise purely to verify that
        # gradients flow end-to-end through the model.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = model(**inputs_dict)
        if isinstance(output, dict):
            output = output["sample"]

        # Same dummy regression as test_training, followed by one EMA update.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model)

patil-suraj's avatar
patil-suraj committed
272
273

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Runs the shared model test-suite against a tiny unconditional UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # A batch of four random 3x32x32 "images" plus a single timestep.
        sample = floats_tensor((4, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Miniature two-level UNet so the common tests stay fast on CPU.
        init_dict = {
            "image_size": 32,
            "in_channels": 3,
            "out_channels": 3,
            "num_res_blocks": 2,
            "block_channels": (32, 64),
            "num_head_channels": None,
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        }
        return init_dict, self.dummy_input
308

patil-suraj's avatar
patil-suraj committed
309

Patrick von Platen's avatar
upload  
Patrick von Platen committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#    TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
#    def test_output_pretrained(self):
#        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        time_step = torch.tensor([10])
#
#        with torch.no_grad():
#            output = model(noise, time_step)["sample"]
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
330
331


patil-suraj's avatar
patil-suraj committed
332
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for a small LDM-style (4-channel) unconditional UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # Four random 4x32x32 latents and one shared timestep.
        sample = floats_tensor((4, 4, 32, 32)).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Small attention-free UNet operating on 4-channel latents.
        init_dict = {
            "image_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "num_res_blocks": 2,
            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()

        # Seed every RNG so the expected slice below stays reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        # Check a fixed corner slice of the output against known values.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

Patrick von Platen's avatar
Patrick von Platen committed
402

Patrick von Platen's avatar
upload  
Patrick von Platen committed
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
#    TODO(Patrick) - Re-add this test after having cleaned up LDM
#    def test_output_pretrained_spatial_transformer(self):
#        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        context = torch.ones((1, 16, 64), dtype=torch.float32)
#        time_step = torch.tensor([10] * noise.shape[0])
#
#        with torch.no_grad():
#            output = model(noise, time_step, context=context)
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#
Patrick von Platen's avatar
Patrick von Platen committed
426

patil-suraj's avatar
patil-suraj committed
427

428
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for a small score-based (VE) NCSN++ UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # Four random 3x32x32 samples with one timestep per batch entry.
        sample = floats_tensor((4, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor(4 * [10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Miniature NCSN++: Fourier time embeddings and skip-style blocks.
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "num_res_blocks": 1,
            "block_channels": [32, 64, 64, 64],
            "time_embedding_type": "fourier",
            "resnet_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "resnet_num_groups": None,
            "down_blocks": [
                "UNetResSkipDownBlock2D",
                "UNetResAttnSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
            ],
            "up_blocks": [
                "UNetResSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
                "UNetResAttnSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
            ],
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        # Deterministic all-ones input at a tiny continuous timestep.
        sample = torch.ones((4, 3, 256, 256)).to(torch_device)
        sigma_t = torch.tensor(4 * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(sample, sigma_t)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.ones((4, 3, 32, 32)).to(torch_device)
        sigma_t = torch.tensor(4 * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(sample, sigma_t)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
537
538


patil-suraj's avatar
patil-suraj committed
539
540
541
542
543
544
545
546
547
548
549
class VQModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for a tiny VQ-GAN style autoencoder."""

    model_class = VQModel

    @property
    def dummy_input(self):
        # Four random 3x32x32 images; the VQ autoencoder takes no timestep.
        image = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Single-scale codebook autoencoder kept deliberately tiny.
        init_dict = {
            "ch": 64,
            "out_ch": 3,
            "num_res_blocks": 1,
            "in_channels": 3,
            "attn_resolutions": [],
            "resolution": 32,
            "z_channels": 3,
            "n_embed": 256,
            "embed_dim": 3,
            "sane_index_shape": False,
            "ch_mult": (1,),
            "dropout": 0.0,
            "double_z": False,
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # Forward takes only ``sample`` — the generic signature check does not apply.
        pass

    def test_training(self):
        # Training is not exercised for the VQ autoencoder.
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        # Seed all RNGs so the expected slice stays reproducible.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            output = model(image)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
612
613


Patrick von Platen's avatar
Patrick von Platen committed
614
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for a tiny KL-regularized autoencoder."""

    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        # Four random 3x32x32 images; the KL autoencoder takes no timestep.
        image = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # Single-scale KL autoencoder kept deliberately tiny.
        init_dict = {
            "ch": 64,
            "ch_mult": (1,),
            "embed_dim": 4,
            "in_channels": 3,
            "attn_resolutions": [],
            "num_res_blocks": 1,
            "out_ch": 3,
            "resolution": 32,
            "z_channels": 4,
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # Forward takes only ``sample`` — the generic signature check does not apply.
        pass

    def test_training(self):
        # Training is not exercised for the KL autoencoder.
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        # Seed all RNGs: sample_posterior below draws from the latent distribution.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            output = model(image, sample_posterior=True)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
683
684


685
686
687
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # Build a tiny UNet + short scheduler so the round-trip stays fast.
        unet = UNetUnconditionalModel(
            block_channels=(32, 64),
            num_res_blocks=2,
            image_size=32,
            in_channels=3,
            out_channels=3,
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(unet, scheduler)

        # Serialize the pipeline to disk and load it back.
        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

        # Identical seeds -> both pipelines must produce identical samples.
        generator = torch.manual_seed(0)
        image = ddpm(generator=generator)["sample"]
        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator)["sample"]

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        model_path = "google/ddpm-cifar10-32"

        # Load the same checkpoint via the concrete pipeline class and via
        # the generic DiffusionPipeline entry point.
        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        # Shorten sampling to keep the test quick.
        ddpm.scheduler.num_timesteps = 10
        ddpm_from_hub.scheduler.num_timesteps = 10

        # Same seed for both loads -> the outputs must agree.
        generator = torch.manual_seed(0)
        image = ddpm(generator=generator)["sample"]
        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator)["sample"]

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_ddpm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        # set_format returns the scheduler, so the calls can be chained.
        scheduler = DDPMScheduler.from_config(model_id).set_format("pt")

        pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)["sample"]

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        pipeline = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)["sample"]

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor(
            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

        # eta=0.0 -> fully deterministic DDIM sampling.
        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0)["sample"]

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_pndm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pndm(generator=generator)["sample"]

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ldm_text2img(self):
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = pipeline([prompt], generator=generator, num_inference_steps=20)

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ldm_text2img_fast(self):
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        # Single inference step keeps this variant fast.
        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = pipeline([prompt], generator=generator, num_inference_steps=1)

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_score_sde_ve_pipeline(self):
        score_model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")

        # Seed all RNGs before building the pipeline for reproducibility.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")

        pipeline = ScoreSdeVePipeline(model=score_model, scheduler=scheduler)

        # Re-seed immediately before sampling.
        torch.manual_seed(0)
        image = pipeline(num_inference_steps=2)

        # Reference statistics differ between CPU and accelerator runs.
        if score_model.device.type == "cpu":
            # patrick's cpu
            expected_image_sum = 3384805888.0
            expected_image_mean = 1076.00085

            # m1 mbp
            # expected_image_sum = 3384805376.0
            # expected_image_mean = 1076.000610351562
        else:
            expected_image_sum = 3382849024.0
            expected_image_mean = 1075.3788

        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

    @slow
    def test_ldm_uncond(self):
        pipeline = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator, num_inference_steps=5)["sample"]

        # Check shape and a reference corner of the generated sample.
        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor(
            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2