test_modeling_utils.py 37.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Patrick von Platen's avatar
Patrick von Platen committed
16

patil-suraj's avatar
patil-suraj committed
17
import inspect
18
import math
19
20
21
import tempfile
import unittest

22
import numpy as np
23
24
import torch

Patrick von Platen's avatar
Patrick von Platen committed
25
from diffusers import UNetConditionalModel  # TODO(Patrick) - need to write tests with it
Patrick von Platen's avatar
Patrick von Platen committed
26
from diffusers import (
patil-suraj's avatar
patil-suraj committed
27
    AutoencoderKL,
Patrick von Platen's avatar
Patrick von Platen committed
28
    DDIMPipeline,
29
    DDIMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
30
    DDPMPipeline,
31
    DDPMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
32
    GlidePipeline,
Patrick von Platen's avatar
Patrick von Platen committed
33
34
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
Patrick von Platen's avatar
Patrick von Platen committed
35
    LatentDiffusionPipeline,
patil-suraj's avatar
patil-suraj committed
36
    LatentDiffusionUncondPipeline,
Patrick von Platen's avatar
Patrick von Platen committed
37
    NCSNpp,
Patrick von Platen's avatar
Patrick von Platen committed
38
    PNDMPipeline,
39
    PNDMScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
40
41
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
Patrick von Platen's avatar
Patrick von Platen committed
42
43
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
anton-l's avatar
anton-l committed
44
    UNetLDMModel,
45
    UNetUnconditionalModel,
patil-suraj's avatar
patil-suraj committed
46
    VQModel,
47
)
48
from diffusers.configuration_utils import ConfigMixin
Patrick von Platen's avatar
Patrick von Platen committed
49
from diffusers.pipeline_utils import DiffusionPipeline
Patrick von Platen's avatar
Patrick von Platen committed
50
from diffusers.testing_utils import floats_tensor, slow, torch_device
51
from diffusers.training_utils import EMAModel
52
53


Patrick von Platen's avatar
Patrick von Platen committed
54
# Turn off TF32 for CUDA matmuls so float32 matrix multiplies run at full float32
# precision. Presumably this keeps the hard-coded expected output slices in the
# *_pretrained tests below reproducible across GPUs — confirm.
torch.backends.cuda.matmul.allow_tf32 = False
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class ConfigTester(unittest.TestCase):
    """Round-trip tests for objects whose constructor arguments are recorded via ``ConfigMixin``."""

    def test_load_not_from_mixin(self):
        # ``from_config`` called on the bare mixin (not a concrete subclass) must be rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_save_load(self):
        class SampleObject(ConfigMixin):
            config_name = "config.json"

            def __init__(
                self,
                a=2,
                b=5,
                c=(2, 5),
                d="for diffusion",
                e=[1, 3],
            ):
                self.register_to_config(a=a, b=b, c=c, d=d, e=e)

        original = SampleObject()
        original_config = original.config

        # Every constructor default must be visible in the recorded config.
        expected_defaults = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}
        for key, expected_value in expected_defaults.items():
            assert original_config[key] == expected_value

        with tempfile.TemporaryDirectory() as tmp_dir:
            original.save_config(tmp_dir)
            reloaded_config = SampleObject.from_config(tmp_dir).config

        # Configs are frozen; copy both into plain dicts so entries can be popped.
        original_config = dict(original_config)
        reloaded_config = dict(reloaded_config)

        # "c" is a tuple in memory but comes back as a list after the JSON round trip.
        assert original_config.pop("c") == (2, 5)
        assert reloaded_config.pop("c") == [2, 5]
        assert original_config == reloaded_config


patil-suraj's avatar
patil-suraj committed
99
class ModelTesterMixin:
    """Model tests shared by the per-model ``unittest.TestCase`` classes in this file.

    A concrete test class must provide:
      * ``model_class`` -- the model class under test;
      * ``prepare_init_args_and_inputs_for_common()`` -- returns ``(init_dict, inputs_dict)``;
        ``inputs_dict`` is passed as keyword arguments to the model and must contain ``"sample"``;
      * ``output_shape`` -- per-example output shape, used to build training targets.
    """

    @staticmethod
    def _unwrap_sample(output):
        """Return ``output["sample"]`` if the model returned a dict, otherwise ``output`` unchanged."""
        if isinstance(output, dict):
            return output["sample"]
        return output

    def test_from_pretrained_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
        # Compare both models in inference mode so train-only behaviour in the reloaded
        # model cannot make the forward passes differ.
        new_model.eval()

        with torch.no_grad():
            image = self._unwrap_sample(model(**inputs_dict))
            new_image = self._unwrap_sample(new_model(**inputs_dict))

        # Summed (not max) absolute difference: a deliberately strict comparison.
        diff_sum = (image - new_image).abs().sum().item()
        self.assertLessEqual(diff_sum, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = self._unwrap_sample(model(**inputs_dict))
            second = self._unwrap_sample(model(**inputs_dict))

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()

        # Use one shared NaN mask so elements stay paired position-by-position. Masking
        # each array separately would mis-align (or shape-mismatch) the comparison
        # whenever the NaN positions differ between the two runs.
        nan_mask = np.isnan(out_1)
        self.assertTrue(np.array_equal(nan_mask, np.isnan(out_2)), "NaN positions differ between runs")
        max_diff = np.amax(np.abs(out_1[~nan_mask] - out_2[~nan_mask]))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = self._unwrap_sample(model(**inputs_dict))

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            new_model = self.model_class.from_config(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # Check that all parameter shapes are the same. Build each state dict once;
        # calling state_dict() inside the loop would rebuild it for every parameter.
        state_dict_1 = model.state_dict()
        state_dict_2 = new_model.state_dict()
        for param_name, param_1 in state_dict_1.items():
            self.assertEqual(param_1.shape, state_dict_2[param_name].shape)

        with torch.no_grad():
            output_1 = self._unwrap_sample(model(**inputs_dict))
            output_2 = self._unwrap_sample(new_model(**inputs_dict))

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = self._unwrap_sample(model(**inputs_dict))

        # Random regression target with the declared output shape; only checks that backward runs.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = self._unwrap_sample(model(**inputs_dict))

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model)

patil-suraj's avatar
patil-suraj committed
240
241

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    """Common and pretrained-checkpoint tests for ``UNetUnconditionalModel`` loaded with ``ddpm=True``."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        batch, channels, height, width = 4, 3, 32, 32
        sample = floats_tensor((batch, channels, height, width)).to(torch_device)
        # A single-element timestep tensor for the whole batch.
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return 3, 32, 32

    @property
    def output_shape(self):
        return 3, 32, 32

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            ch=32,
            ch_mult=(1, 2),
            block_channels=(32, 64),
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            num_head_channels=None,
            out_channels=3,
            in_channels=3,
            num_res_blocks=2,
            attn_resolutions=(16,),
            resolution=32,
            image_size=32,
        )
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
        )
        self.assertIsNotNone(model)
        # Disabled check, kept for reference:
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
        model.eval()

        # Seed the CPU generator (and every CUDA generator, when available) for a reproducible input.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        side = model.config.image_size
        noise = torch.randn(1, model.config.in_channels, side, side)
        time_step = torch.tensor([10])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        # Bottom-right 3x3 patch of the last channel of the first example.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

313

Patrick von Platen's avatar
Patrick von Platen committed
314
315
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
    """Common and pretrained-checkpoint tests for ``GlideSuperResUNetModel``."""

    model_class = GlideSuperResUNetModel

    @property
    def dummy_input(self):
        batch_size = 4
        # Matches ``in_channels`` in the init dict. The sample itself carries only half
        # of these channels (3); presumably the model supplies the rest from ``low_res``
        # (3 channels) -- confirm against the model.
        num_channels = 6
        sizes = (32, 32)
        low_res_size = (4, 4)

        noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
        low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        return {"sample": noise, "timestep": time_step, "low_res": low_res}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (6, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 6,
            "out_channels": 6,
            "model_channels": 32,
            "num_head_channels": 8,
            "num_heads_upsample": 1,
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        """The first 3 of the 6 output channels must have the same shape as the input sample."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        output, _ = torch.split(output, 3, dim=1)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GlideSuperResUNetModel.from_pretrained(
            "fusing/glide-super-res-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        # Disabled check, kept for reference:
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
        # Compare reference values in inference mode, like the other *_pretrained tests.
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, 3, 64, 64)
        low_res = torch.randn(1, 3, 4, 4)
        time_step = torch.tensor([42] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step, low_res)

        # Keep only the first 3 channels, then take the bottom-right 3x3 patch of the last one.
        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, -10.4370])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
patil-suraj's avatar
patil-suraj committed
402

anton-l's avatar
anton-l committed
403

Patrick von Platen's avatar
Patrick von Platen committed
404
405
class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
    """Common and pretrained-checkpoint tests for ``GlideTextToImageUNetModel``."""

    model_class = GlideTextToImageUNetModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)
        transformer_dim = 32  # kept equal to ``transformer_dim`` in the init dict below
        seq_len = 16

        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
        emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        return {"sample": noise, "timestep": time_step, "transformer_out": emb}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (6, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 3,
            "out_channels": 6,
            "model_channels": 32,
            "num_head_channels": 8,
            "num_heads_upsample": 1,
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
            "transformer_dim": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        """The first 3 of the 6 output channels must have the same shape as the input sample."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        output, _ = torch.split(output, 3, dim=1)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GlideTextToImageUNetModel.from_pretrained(
            "fusing/unet-glide-text2im-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        # Disabled check, kept for reference:
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")
        # Compare reference values in inference mode, like the other *_pretrained tests.
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to(
            torch_device
        )
        emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        model.to(torch_device)
        with torch.no_grad():
            output = model(noise, time_step, emb)

        # Keep only the first 3 channels, then take the bottom-right 3x3 patch of the last one.
        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].cpu().flatten()
        # fmt: off
        expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


patil-suraj's avatar
patil-suraj committed
499
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for ``UNetUnconditionalModel`` in LDM mode (``ldm=True``) and the spatial-transformer ``UNetLDMModel``."""

    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):
        batch, channels = 4, 4
        sample = floats_tensor((batch, channels) + (32, 32)).to(torch_device)
        # A single-element timestep tensor for the whole batch.
        timestep = torch.tensor([10]).to(torch_device)
        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return 4, 32, 32

    @property
    def output_shape(self):
        return 4, 32, 32

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            image_size=32,
            in_channels=4,
            out_channels=4,
            num_res_blocks=2,
            attention_resolutions=(16,),
            block_channels=(32, 64),
            num_head_channels=32,
            conv_resample=True,
            down_blocks=("UNetResDownBlock2D", "UNetResDownBlock2D"),
            up_blocks=("UNetResUpBlock2D", "UNetResUpBlock2D"),
            ldm=True,
        )
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/unet-ldm-dummy", output_loading_info=True, ldm=True
        )
        self.assertIsNotNone(model)
        # Disabled check, kept for reference:
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
        model.eval()

        # Seed the CPU generator (and every CUDA generator, when available) for a reproducible input.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        side = model.config.image_size
        noise = torch.randn(1, model.config.in_channels, side, side)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        # Bottom-right 3x3 patch of the last channel of the first example.
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

    def test_output_pretrained_spatial_transformer(self):
        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        side = model.config.image_size
        noise = torch.randn(1, model.config.in_channels, side, side)
        # Constant conditioning context fed to the spatial transformer layers.
        context = torch.ones((1, 16, 64), dtype=torch.float32)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step, context=context)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

patil-suraj's avatar
patil-suraj committed
593

594
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for ``UNetUnconditionalModel`` configured in the NCSN++ style.

    ``ModelTesterMixin`` runs its shared checks against a tiny model built from
    ``UNetResSkip*`` / ``UNetResAttnSkip*`` blocks with Fourier time embeddings.
    The ``test_output_pretrained_*`` methods load hub checkpoints and compare a
    3x3 slice of the output with values recorded from a reference run.
    """

    model_class = UNetUnconditionalModel

    @staticmethod
    def _reset_rng(seed=0):
        """Seed the default torch generator and, if CUDA is present, all CUDA devices."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _assert_matches_reference(self, output, reference_values):
        """Check ``output[0, -3:, -3:, -1]`` against ``reference_values`` with rtol=1e-2."""
        observed = output[0, -3:, -3:, -1].flatten().cpu()
        reference = torch.tensor(reference_values)
        self.assertTrue(torch.allclose(observed, reference, rtol=1e-2))

    @property
    def dummy_input(self):
        batch, channels, height, width = 4, 3, 32, 32

        sample = floats_tensor((batch, channels, height, width)).to(torch_device)
        timestep = torch.tensor([10] * batch).to(torch_device)

        return {"sample": sample, "timestep": timestep}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            block_channels=[32, 64, 64, 64],
            in_channels=3,
            num_res_blocks=1,
            out_channels=3,
            time_embedding_type="fourier",
            resnet_eps=1e-6,
            mid_block_scale_factor=math.sqrt(2.0),
            resnet_num_groups=None,
            down_blocks=[
                "UNetResSkipDownBlock2D",
                "UNetResAttnSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
            ],
            up_blocks=[
                "UNetResSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
                "UNetResAttnSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
            ],
        )
        return init_dict, self.dummy_input

    def test_from_pretrained_hub(self):
        model, _loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ncsnpp-ffhq-ve-dummy", sde=True, output_loading_info=True
        )
        self.assertIsNotNone(model)
        # NOTE(review): missing-keys check is disabled; confirm why before re-enabling.
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        prediction = model(**self.dummy_input)

        assert prediction is not None, "Make sure output is not None"

    def test_output_pretrained_ve_small(self):
        model = NCSNpp.from_pretrained("fusing/ncsnpp-cifar10-ve-dummy")
        model.eval()
        model.to(torch_device)

        self._reset_rng()

        batch = 4
        sample = torch.ones((batch, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor([1e-4] * batch).to(torch_device)

        with torch.no_grad():
            output = model(sample, timestep)

        # fmt: off
        self._assert_matches_reference(output, [0.1315, 0.0741, 0.0393, 0.0455, 0.0556, 0.0180, -0.0832, -0.0644, -0.0856])
        # fmt: on

    def test_output_pretrained_ve_mid(self):
        # NOTE(review): unlike the *_small test this one never calls model.eval();
        # confirm the reference values were recorded in the same mode.
        model = UNetUnconditionalModel.from_pretrained("fusing/celebahq_256-ncsnpp-ve", sde=True)
        model.to(torch_device)

        self._reset_rng()

        batch = 4
        sample = torch.ones((batch, 3, 256, 256)).to(torch_device)
        timestep = torch.tensor([1e-4] * batch).to(torch_device)

        with torch.no_grad():
            output = model(sample, timestep)["sample"]

        # fmt: off
        self._assert_matches_reference(output, [-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

    def test_output_pretrained_ve_large(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy", sde=True)
        model.to(torch_device)

        self._reset_rng()

        batch = 4
        sample = torch.ones((batch, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor([1e-4] * batch).to(torch_device)

        with torch.no_grad():
            output = model(sample, timestep)["sample"]

        # fmt: off
        self._assert_matches_reference(output, [-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

    def test_output_pretrained_vp(self):
        model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
        model.to(torch_device)

        self._reset_rng()

        batch = 4
        # Noise is drawn on the CPU generator and then moved, so values do not depend on the device.
        sample = torch.randn((batch, 3, 32, 32)).to(torch_device)
        timestep = torch.tensor([9.0] * batch).to(torch_device)

        with torch.no_grad():
            output = model(sample, timestep)

        # fmt: off
        self._assert_matches_reference(output, [0.3303, -0.2275, -2.8872, -0.1309, -1.2861, 3.4567, -1.0083, 2.5325, -1.3866])
        # fmt: on


patil-suraj's avatar
patil-suraj committed
756
757
758
759
760
761
762
763
764
765
766
class VQModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for ``VQModel``: shared mixin checks plus a pretrained regression check."""

    model_class = VQModel

    @property
    def dummy_input(self):
        images = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": images}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            ch=64,
            out_ch=3,
            num_res_blocks=1,
            attn_resolutions=[],
            in_channels=3,
            resolution=32,
            z_channels=3,
            n_embed=256,
            embed_dim=3,
            sane_index_shape=False,
            ch_mult=(1,),
            dropout=0.0,
            double_z=False,
        )
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        """Overrides the mixin check with a no-op for this model."""

    def test_training(self):
        """Overrides the mixin check with a no-op for this model."""

    def test_from_pretrained_hub(self):
        model, _loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        # NOTE(review): missing-keys check is disabled; confirm why before re-enabling.
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        reconstruction = model(**self.dummy_input)

        assert reconstruction is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        side = model.config.resolution
        image = torch.randn(1, model.config.in_channels, side, side)
        with torch.no_grad():
            reconstruction = model(image)

        observed = reconstruction[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
        # fmt: on
        self.assertTrue(torch.allclose(observed, expected, rtol=1e-2))


Patrick von Platen's avatar
Patrick von Platen committed
831
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    """Tests for ``AutoencoderKL``: shared mixin checks plus a pretrained regression check."""

    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        images = floats_tensor((4, 3, 32, 32)).to(torch_device)
        return {"sample": images}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = dict(
            ch=64,
            ch_mult=(1,),
            embed_dim=4,
            in_channels=3,
            num_res_blocks=1,
            out_ch=3,
            resolution=32,
            z_channels=4,
            attn_resolutions=[],
        )
        return init_dict, self.dummy_input

    def test_forward_signature(self):
        """Overrides the mixin check with a no-op for this model."""

    def test_training(self):
        """Overrides the mixin check with a no-op for this model."""

    def test_from_pretrained_hub(self):
        model, _loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        # NOTE(review): missing-keys check is disabled; confirm why before re-enabling.
        # self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        reconstruction = model(**self.dummy_input)

        assert reconstruction is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        # The seed fixes both the input below and any randomness consumed by
        # sample_posterior=True (presumably a posterior draw — confirm in the model).
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        side = model.config.resolution
        image = torch.randn(1, model.config.in_channels, side, side)
        with torch.no_grad():
            reconstruction = model(image, sample_posterior=True)

        observed = reconstruction[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
        # fmt: on
        self.assertTrue(torch.allclose(observed, expected, rtol=1e-2))


902
903
904
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        """A DDPM pipeline reloaded via save/from_pretrained must reproduce the original's output."""
        # 1. Build a small, freshly constructed pipeline.
        unet = UNetUnconditionalModel(
            block_channels=(32, 64),
            num_res_blocks=2,
            attn_resolutions=(16,),
            image_size=32,
            in_channels=3,
            out_channels=3,
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
        )
        pipeline = DDPMPipeline(unet, DDPMScheduler(num_train_timesteps=10))

        # 2. Round-trip it through a temporary directory.
        with tempfile.TemporaryDirectory() as tmpdirname:
            pipeline.save_pretrained(tmpdirname)
            reloaded = DDPMPipeline.from_pretrained(tmpdirname)

        # 3. Sample from both with the same seed and compare.
        rng = torch.manual_seed(0)
        original_sample = pipeline(generator=rng)
        rng.manual_seed(0)
        reloaded_sample = reloaded(generator=rng)

        assert (original_sample - reloaded_sample).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        """Loading via ``DDPMPipeline`` and via generic ``DiffusionPipeline`` must give identical samples."""
        model_path = "google/ddpm-cifar10"

        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        # NOTE(review): confirm the scheduler/pipeline actually reads ``num_timesteps``;
        # if the step count comes from the config instead, these assignments are no-ops.
        for pipe in (ddpm, ddpm_from_hub):
            pipe.scheduler.num_timesteps = 10

        rng = torch.manual_seed(0)
        image = ddpm(generator=rng)
        rng.manual_seed(0)
        new_image = ddpm_from_hub(generator=rng)

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_ddpm_cifar10(self):
        """Seeded DDPM sampling with ``google/ddpm-cifar10`` matches a recorded 3x3 slice."""
        model_id = "google/ddpm-cifar10"

        pipeline = DDPMPipeline(
            unet=UNetUnconditionalModel.from_pretrained(model_id),
            scheduler=DDPMScheduler.from_config(model_id).set_format("pt"),
        )

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 32, 32)

        reference = torch.tensor(
            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
        )
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

    @slow
    def test_ddim_lsun(self):
        """Seeded DDIM sampling with ``google/ddpm-ema-bedroom-256`` matches a recorded 3x3 slice."""
        model_id = "google/ddpm-ema-bedroom-256"

        pipeline = DDIMPipeline(
            unet=UNetUnconditionalModel.from_pretrained(model_id),
            scheduler=DDIMScheduler.from_config(model_id),
        )

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)["sample"]

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 256, 256)

        reference = torch.tensor(
            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
        )
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

    @slow
    def test_ddim_cifar10(self):
        """Deterministic (eta=0) DDIM sampling on ``google/ddpm-cifar10`` matches a recorded slice."""
        model_id = "google/ddpm-cifar10"

        pipeline = DDIMPipeline(
            unet=UNetUnconditionalModel.from_pretrained(model_id),
            scheduler=DDIMScheduler(tensor_format="pt"),
        )

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator, eta=0.0)["sample"]

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 32, 32)

        reference = torch.tensor(
            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
        )
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

Patrick von Platen's avatar
Patrick von Platen committed
1010
1011
    @slow
    def test_pndm_cifar10(self):
        """Seeded PNDM sampling on ``google/ddpm-cifar10`` matches a recorded 3x3 slice."""
        model_id = "google/ddpm-cifar10"

        pipeline = PNDMPipeline(
            unet=UNetUnconditionalModel.from_pretrained(model_id),
            scheduler=PNDMScheduler(tensor_format="pt"),
        )
        generator = torch.manual_seed(0)
        image = pipeline(generator=generator)

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 32, 32)

        reference = torch.tensor(
            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
        )
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

patil-suraj's avatar
patil-suraj committed
1029
1030
    @slow
    def test_ldm_text2img(self):
        """20-step text-to-image LDM sampling matches a recorded 3x3 slice."""
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large")

        prompts = ["A painting of a squirrel eating a burger"]
        generator = torch.manual_seed(0)
        image = pipeline(prompts, generator=generator, num_inference_steps=20)

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 256, 256)

        reference = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

patil-suraj's avatar
patil-suraj committed
1043
1044
    @slow
    def test_ldm_text2img_fast(self):
        """Single-step text-to-image LDM sampling matches a recorded 3x3 slice."""
        pipeline = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large")

        prompts = ["A painting of a squirrel eating a burger"]
        generator = torch.manual_seed(0)
        image = pipeline(prompts, generator=generator, num_inference_steps=1)

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 256, 256)

        reference = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

anton-l's avatar
anton-l committed
1057
1058
1059
    @slow
    def test_glide_text2img(self):
        """Seeded GLIDE text-to-image sampling matches a recorded slice of the channels-last output."""
        model_id = "fusing/glide-base"
        pipeline = GlidePipeline.from_pretrained(model_id)

        generator = torch.manual_seed(0)
        image = pipeline("a pencil sketch of a corgi", generator=generator, num_inference_steps_upscale=20)

        # The output is indexed channels-last here: (batch, height, width, channels).
        corner = image[0, :3, :3, -1].cpu()
        assert image.shape == (1, 256, 256, 3)

        reference = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2

Patrick von Platen's avatar
Patrick von Platen committed
1072
1073
    @slow
    def test_score_sde_ve_pipeline(self):
        """Two-step score-SDE (VE) sampling on ``fusing/ffhq_ncsnpp`` matches recorded sum/mean statistics."""
        model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")
        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

        # Re-seed immediately before sampling.
        torch.manual_seed(0)
        image = sde_ve(num_inference_steps=2)

        # CPU references were recorded on patrick's cpu; an M1 MacBook Pro gave
        # sum=3384805376.0 and mean=1076.000610351562 instead.
        on_cpu = model.device.type == "cpu"
        expected_image_sum = 3384805888.0 if on_cpu else 3382849024.0
        expected_image_mean = 1076.00085 if on_cpu else 1075.3788

        # NOTE(review): assuming ``image`` is float32, representable values near 3.4e9 are
        # 256 apart, so the 1e-2 sum tolerance effectively demands a bit-exact result.
        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

Patrick von Platen's avatar
Patrick von Platen committed
1102
1103
    @slow
    def test_score_sde_vp_pipeline(self):
        """Ten-step score-SDE (VP) sampling on ``fusing/cifar10-ddpmpp-vp`` matches recorded statistics."""
        checkpoint = "fusing/cifar10-ddpmpp-vp"
        sde_vp = ScoreSdeVpPipeline(
            model=NCSNpp.from_pretrained(checkpoint),
            scheduler=ScoreSdeVpScheduler.from_config(checkpoint),
        )

        torch.manual_seed(0)
        image = sde_vp(num_inference_steps=10)

        # Recorded references; an M1 MacBook Pro gave sum=4318.6729, mean=1.4058 instead.
        expected_image_sum, expected_image_mean = 4183.2012, 1.3617

        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

patil-suraj's avatar
patil-suraj committed
1122
1123
    @slow
    def test_ldm_uncond(self):
        """Five-step unconditional LDM sampling on ``CompVis/ldm-celebahq-256`` matches a recorded slice."""
        pipeline = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = pipeline(generator=generator, num_inference_steps=5)["sample"]

        corner = image[0, -1, -3:, -3:].cpu()
        assert image.shape == (1, 3, 256, 256)

        reference = torch.tensor(
            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
        )
        max_abs_err = (corner.flatten() - reference).abs().max()
        assert max_abs_err < 1e-2