test_modeling_utils.py 28.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
16

patil-suraj's avatar
patil-suraj committed
17
import inspect
18
import math
19
20
import tempfile
import unittest
21
from atexit import register
22

23
import numpy as np
24
25
import torch

Patrick von Platen's avatar
upload  
Patrick von Platen committed
26
from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
Patrick von Platen's avatar
Patrick von Platen committed
27
from diffusers import (
patil-suraj's avatar
patil-suraj committed
28
    AutoencoderKL,
Patrick von Platen's avatar
Patrick von Platen committed
29
    DDIMPipeline,
30
    DDIMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
31
    DDPMPipeline,
32
    DDPMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
33
    LatentDiffusionPipeline,
patil-suraj's avatar
patil-suraj committed
34
    LatentDiffusionUncondPipeline,
Patrick von Platen's avatar
Patrick von Platen committed
35
    PNDMPipeline,
36
    PNDMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
37
38
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
39
    UNetUnconditionalModel,
patil-suraj's avatar
patil-suraj committed
40
    VQModel,
41
)
42
from diffusers.configuration_utils import ConfigMixin, register_to_config
Patrick von Platen's avatar
Patrick von Platen committed
43
from diffusers.pipeline_utils import DiffusionPipeline
Patrick von Platen's avatar
Patrick von Platen committed
44
from diffusers.testing_utils import floats_tensor, slow, torch_device
45
from diffusers.training_utils import EMAModel
46
47


Patrick von Platen's avatar
Patrick von Platen committed
48
torch.backends.cuda.matmul.allow_tf32 = False
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class SampleObject(ConfigMixin):
    """Minimal ``ConfigMixin`` subclass used to exercise ``register_to_config``."""

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]):
        # The decorator captures the arguments into the config; nothing else
        # needs to happen here. The mutable list default is intentional: the
        # tests assert that it round-trips through save/load unchanged.
        pass


66
67
68
69
70
class ConfigTester(unittest.TestCase):
    """Tests for ``ConfigMixin`` argument registration and config (de)serialization."""

    def test_load_not_from_mixin(self):
        # Loading through the abstract base class directly must be rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        defaults = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}

        # plain instantiation records the declared defaults
        cfg = SampleObject().config
        for key, value in defaults.items():
            assert cfg[key] == value

        # private (underscore-prefixed) kwargs are ignored by the decorator
        cfg = SampleObject(_name_or_path="lalala").config
        for key, value in defaults.items():
            assert cfg[key] == value

        # defaults can be overridden via keyword
        cfg = SampleObject(c=6).config
        for key, value in {**defaults, "c": 6}.items():
            assert cfg[key] == value

        # positional arguments are registered as well
        cfg = SampleObject(1, c=6).config
        for key, value in {**defaults, "a": 1, "c": 6}.items():
            assert cfg[key] == value

    def test_save_load(self):
        sample = SampleObject()
        cfg = sample.config

        assert cfg["a"] == 2
        assert cfg["b"] == 5
        assert cfg["c"] == (2, 5)
        assert cfg["d"] == "for diffusion"
        assert cfg["e"] == [1, 3]

        # round-trip the config through disk
        with tempfile.TemporaryDirectory() as tmpdirname:
            sample.save_config(tmpdirname)
            reloaded = SampleObject.from_config(tmpdirname)
            reloaded_cfg = reloaded.config

        # copy into plain mutable dicts so entries can be popped below
        cfg = dict(cfg)
        reloaded_cfg = dict(reloaded_cfg)

        # tuples survive instantiation but come back as lists after the JSON round-trip
        assert cfg.pop("c") == (2, 5)
        assert reloaded_cfg.pop("c") == [2, 5]
        assert cfg == reloaded_cfg


patil-suraj's avatar
patil-suraj committed
131
class ModelTesterMixin:
    """Shared forward/save/load/training checks for diffusers model classes.

    Concrete subclasses must define ``model_class``, the ``dummy_input`` /
    ``input_shape`` / ``output_shape`` properties, and
    ``prepare_init_args_and_inputs_for_common``.
    """

    @staticmethod
    def _to_tensor(model_output):
        # A forward pass may return either a raw tensor or a {"sample": tensor} dict.
        if isinstance(model_output, dict):
            return model_output["sample"]
        return model_output

    def test_from_pretrained_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        original = self.model_class(**init_dict)
        original.to(torch_device)
        original.eval()

        # round-trip the weights through save_pretrained/from_pretrained
        with tempfile.TemporaryDirectory() as tmpdirname:
            original.save_pretrained(tmpdirname)
            restored = self.model_class.from_pretrained(tmpdirname)
            restored.to(torch_device)

        with torch.no_grad():
            image = self._to_tensor(original(**inputs_dict))
            restored_image = self._to_tensor(restored(**inputs_dict))

        # NOTE(review): despite the name, this is the *summed* absolute
        # difference, kept as-is to preserve the original tolerance semantics.
        max_diff = (image - restored_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # two forward passes with identical inputs must agree
        with torch.no_grad():
            first = self._to_tensor(model(**inputs_dict))
            second = self._to_tensor(model(**inputs_dict))

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # ignore NaN entries when comparing
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = self._to_tensor(model(**inputs_dict))

        self.assertIsNotNone(output)
        # these models are expected to preserve the input sample's shape
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        # signature.parameters preserves declaration order, so this check is deterministic
        parameters = inspect.signature(model.forward).parameters
        self.assertListEqual(list(parameters)[:2], ["sample", "timestep"])

    def test_model_from_config(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # a model rebuilt from only the saved config must have the same architecture
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            rebuilt = self.model_class.from_config(tmpdirname)
            rebuilt.to(torch_device)
            rebuilt.eval()

        # every parameter tensor must keep its shape
        rebuilt_state = rebuilt.state_dict()
        for param_name, param in model.state_dict().items():
            self.assertEqual(param.shape, rebuilt_state[param_name].shape)

        with torch.no_grad():
            output_1 = self._to_tensor(model(**inputs_dict))
            output_2 = self._to_tensor(rebuilt(**inputs_dict))

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = self._to_tensor(model(**inputs_dict))

        # a simple regression loss must backpropagate without error
        batch = inputs_dict["sample"].shape[0]
        noise = torch.randn((batch,) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = self._to_tensor(model(**inputs_dict))

        batch = inputs_dict["sample"].shape[0]
        noise = torch.randn((batch,) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        # the EMA shadow weights must accept an update from the stepped model
        ema_model.step(model)

patil-suraj's avatar
patil-suraj committed
272
273

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Runs the common model checks against a small unconditional UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # four random 3x32x32 images plus a single shared timestep
        sample = floats_tensor((4, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # deliberately tiny two-block UNet so the tests stay fast
        init_dict = {
            "block_channels": (32, 64),
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            "num_head_channels": None,
            "out_channels": 3,
            "in_channels": 3,
            "num_res_blocks": 2,
            "image_size": 32,
        }
        return init_dict, self.dummy_input
308

patil-suraj's avatar
patil-suraj committed
309

Patrick von Platen's avatar
upload  
Patrick von Platen committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#    TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
#    def test_output_pretrained(self):
#        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        time_step = torch.tensor([10])
#
#        with torch.no_grad():
#            output = model(noise, time_step)["sample"]
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
330
331


patil-suraj's avatar
patil-suraj committed
332
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Common model checks plus Hub-loading tests for the LDM-style UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # four random 4-channel 32x32 latents plus a single shared timestep
        sample = floats_tensor((4, 4, 32, 32)).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # small latent-diffusion flavoured configuration
        init_dict = {
            "image_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "num_res_blocks": 2,
            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        # every checkpoint weight should have been matched to a module parameter
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()

        # seed both CPU and (if present) CUDA RNGs for reproducibility
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        # compare a small corner slice against a known-good reference
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

Patrick von Platen's avatar
Patrick von Platen committed
402

Patrick von Platen's avatar
upload  
Patrick von Platen committed
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
#    TODO(Patrick) - Re-add this test after having cleaned up LDM
#    def test_output_pretrained_spatial_transformer(self):
#        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
#        context = torch.ones((1, 16, 64), dtype=torch.float32)
#        time_step = torch.tensor([10] * noise.shape[0])
#
#        with torch.no_grad():
#            output = model(noise, time_step, context=context)
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#
Patrick von Platen's avatar
Patrick von Platen committed
426

patil-suraj's avatar
patil-suraj committed
427

428
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Common model checks plus pretrained-output tests for the NCSN++ (score-SDE) UNet."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        # four random 3x32x32 images with one timestep per batch element
        batch_size = 4
        sample = floats_tensor((batch_size, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor(batch_size * [10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        # score-SDE flavoured configuration: fourier time embedding and
        # skip-connection down/up blocks
        init_dict = {
            "block_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "num_res_blocks": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "resnet_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "resnet_num_groups": None,
            "down_blocks": [
                "UNetResSkipDownBlock2D",
                "UNetResAttnSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
            ],
            "up_blocks": [
                "UNetResSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
                "UNetResAttnSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
            ],
        }
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
        )
        self.assertIsNotNone(model)
        # every checkpoint weight should have been matched
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        # constant input with a near-zero continuous timestep
        batch_size = 4
        noise = torch.ones((batch_size, 3, 256, 256)).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        noise = torch.ones((batch_size, 3, 32, 32)).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
537
538


patil-suraj's avatar
patil-suraj committed
539
540
541
542
543
544
545
546
547
548
549
class VQModelTests(ModelTesterMixin, unittest.TestCase):
    """Common model checks (minus timestep-specific ones) for the VQ-VAE."""

    model_class = VQModel

    @property
    def dummy_input(self):
        # four random 3x32x32 images; the VQ-VAE takes no timestep
        image = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "ch": 64,
            "out_ch": 3,
            "num_res_blocks": 1,
            "in_channels": 3,
            "attn_resolutions": [],
            "resolution": 32,
            "z_channels": 3,
            "n_embed": 256,
            "embed_dim": 3,
            "sane_index_shape": False,
            "ch_mult": (1,),
            "dropout": 0.0,
            "double_z": False,
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # skipped: the VQ-VAE forward does not follow the (sample, timestep) convention
        pass

    def test_training(self):
        # skipped: training is not exercised for the VQ-VAE here
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        # seed both CPU and (if present) CUDA RNGs for reproducibility
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            output = model(image)

        # compare a small corner slice against a known-good reference
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
612
613


Patrick von Platen's avatar
Patrick von Platen committed
614
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    """Common model checks (minus timestep-specific ones) for the KL autoencoder."""

    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        # four random 3x32x32 images; the autoencoder takes no timestep
        image = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "ch": 64,
            "ch_mult": (1,),
            "embed_dim": 4,
            "in_channels": 3,
            "attn_resolutions": [],
            "num_res_blocks": 1,
            "out_ch": 3,
            "resolution": 32,
            "z_channels": 4,
        }
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        # skipped: the autoencoder forward does not follow the (sample, timestep) convention
        pass

    def test_training(self):
        # skipped: training is not exercised for the autoencoder here
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        # seed both CPU and (if present) CUDA RNGs for reproducibility
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        with torch.no_grad():
            # sample_posterior=True draws the latent from the posterior using the seeded RNG
            output = model(image, sample_posterior=True)

        # compare a small corner slice against a known-good reference
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
patil-suraj's avatar
patil-suraj committed
683
684


685
686
687
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
688
        model = UNetUnconditionalModel(
Patrick von Platen's avatar
Patrick von Platen committed
689
690
691
692
693
694
695
            block_channels=(32, 64),
            num_res_blocks=2,
            image_size=32,
            in_channels=3,
            out_channels=3,
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
696
        )
Patrick von Platen's avatar
Patrick von Platen committed
697
        schedular = DDPMScheduler(num_train_timesteps=10)
698

Patrick von Platen's avatar
Patrick von Platen committed
699
        ddpm = DDPMPipeline(model, schedular)
700
701
702

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
Patrick von Platen's avatar
Patrick von Platen committed
703
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
Patrick von Platen's avatar
Patrick von Platen committed
704
705

        generator = torch.manual_seed(0)
706

Patrick von Platen's avatar
Patrick von Platen committed
707
        image = ddpm(generator=generator)["sample"]
Patrick von Platen's avatar
Patrick von Platen committed
708
        generator = generator.manual_seed(0)
Patrick von Platen's avatar
Patrick von Platen committed
709
        new_image = new_ddpm(generator=generator)["sample"]
710
711
712
713
714

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
715
        model_path = "google/ddpm-cifar10-32"
716

Patrick von Platen's avatar
Patrick von Platen committed
717
        ddpm = DDPMPipeline.from_pretrained(model_path)
718
719
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

720
721
        ddpm.scheduler.num_timesteps = 10
        ddpm_from_hub.scheduler.num_timesteps = 10
722

Patrick von Platen's avatar
Patrick von Platen committed
723
        generator = torch.manual_seed(0)
724

Patrick von Platen's avatar
Patrick von Platen committed
725
        image = ddpm(generator=generator)["sample"]
Patrick von Platen's avatar
Patrick von Platen committed
726
        generator = generator.manual_seed(0)
Patrick von Platen's avatar
Patrick von Platen committed
727
        new_image = ddpm_from_hub(generator=generator)["sample"]
728
729

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
Patrick von Platen's avatar
Patrick von Platen committed
730
731
732

    @slow
    def test_ddpm_cifar10(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
733
        model_id = "google/ddpm-cifar10-32"
Patrick von Platen's avatar
Patrick von Platen committed
734

Lysandre Debut's avatar
Lysandre Debut committed
735
        unet = UNetUnconditionalModel.from_pretrained(model_id)
736
737
        scheduler = DDPMScheduler.from_config(model_id)
        scheduler = scheduler.set_format("pt")
Patrick von Platen's avatar
Patrick von Platen committed
738

739
        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
Patrick von Platen's avatar
Patrick von Platen committed
740
741

        generator = torch.manual_seed(0)
Patrick von Platen's avatar
Patrick von Platen committed
742
        image = ddpm(generator=generator)["sample"]
Patrick von Platen's avatar
Patrick von Platen committed
743

744
        image_slice = image[0, -3:, -3:, -1]
Patrick von Platen's avatar
Patrick von Platen committed
745

746
747
748
        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
749
750
751

    @slow
    def test_ddim_lsun(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
752
        model_id = "google/ddpm-ema-bedroom-256"
753

Lysandre Debut's avatar
Lysandre Debut committed
754
        unet = UNetUnconditionalModel.from_pretrained(model_id)
755
        scheduler = DDIMScheduler.from_config(model_id)
756

757
        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
758
759

        generator = torch.manual_seed(0)
760
        image = ddpm(generator=generator)["sample"]
761

762
        image_slice = image[0, -3:, -3:, -1]
763

764
765
766
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
767
768
769

    @slow
    def test_ddim_cifar10(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
770
        model_id = "google/ddpm-cifar10-32"
Patrick von Platen's avatar
Patrick von Platen committed
771

Lysandre Debut's avatar
Lysandre Debut committed
772
        unet = UNetUnconditionalModel.from_pretrained(model_id)
773
        scheduler = DDIMScheduler(tensor_format="pt")
Patrick von Platen's avatar
Patrick von Platen committed
774

775
        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
776
777

        generator = torch.manual_seed(0)
778
        image = ddim(generator=generator, eta=0.0)["sample"]
Patrick von Platen's avatar
Patrick von Platen committed
779

780
        image_slice = image[0, -3:, -3:, -1]
Patrick von Platen's avatar
Patrick von Platen committed
781

782
783
784
        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
patil-suraj's avatar
patil-suraj committed
785

Patrick von Platen's avatar
Patrick von Platen committed
786
787
    @slow
    def test_pndm_cifar10(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
788
        model_id = "google/ddpm-cifar10-32"
Patrick von Platen's avatar
Patrick von Platen committed
789

Patrick von Platen's avatar
Patrick von Platen committed
790
        unet = UNetUnconditionalModel.from_pretrained(model_id)
791
        scheduler = PNDMScheduler(tensor_format="pt")
Patrick von Platen's avatar
Patrick von Platen committed
792

793
        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
Patrick von Platen's avatar
Patrick von Platen committed
794
        generator = torch.manual_seed(0)
Patrick von Platen's avatar
Patrick von Platen committed
795
        image = pndm(generator=generator)["sample"]
Patrick von Platen's avatar
Patrick von Platen committed
796

797
        image_slice = image[0, -3:, -3:, -1]
Patrick von Platen's avatar
Patrick von Platen committed
798

799
800
801
        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
802

patil-suraj's avatar
patil-suraj committed
803
804
    @slow
    def test_ldm_text2img(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
805
        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
patil-suraj's avatar
patil-suraj committed
806
807
808

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
809
        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"]
patil-suraj's avatar
patil-suraj committed
810

811
        image_slice = image[0, -3:, -3:, -1]
patil-suraj's avatar
patil-suraj committed
812

813
814
815
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
816

patil-suraj's avatar
patil-suraj committed
817
818
    @slow
    def test_ldm_text2img_fast(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
819
        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
patil-suraj's avatar
patil-suraj committed
820
821
822

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
823
        image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"]
patil-suraj's avatar
patil-suraj committed
824

825
        image_slice = image[0, -3:, -3:, -1]
patil-suraj's avatar
patil-suraj committed
826

827
828
829
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
patil-suraj's avatar
patil-suraj committed
830

Patrick von Platen's avatar
Patrick von Platen committed
831
832
    @slow
    def test_score_sde_ve_pipeline(self):
833
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
834
835
836
837
838

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

839
        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
Patrick von Platen's avatar
Patrick von Platen committed
840
841
842

        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

843
        torch.manual_seed(0)
844
        image = sde_ve(num_inference_steps=300)["sample"]
Nathan Lambert's avatar
Nathan Lambert committed
845

846
        image_slice = image[0, -3:, -3:, -1]
Patrick von Platen's avatar
Patrick von Platen committed
847

848
849
850
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Patrick von Platen's avatar
Patrick von Platen committed
851

patil-suraj's avatar
patil-suraj committed
852
853
    @slow
    def test_ldm_uncond(self):
Patrick von Platen's avatar
upload  
Patrick von Platen committed
854
        ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
patil-suraj's avatar
patil-suraj committed
855
856

        generator = torch.manual_seed(0)
857
        image = ldm(generator=generator, num_inference_steps=5)["sample"]
patil-suraj's avatar
patil-suraj committed
858

859
        image_slice = image[0, -3:, -3:, -1]
patil-suraj's avatar
patil-suraj committed
860

861
862
863
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2