# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import tempfile
import unittest

import numpy as np
import torch

import PIL
from diffusers import UNet2DConditionModel  # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    KarrasVePipeline,
    KarrasVeScheduler,
    LDMPipeline,
    LDMTextToImagePipeline,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    StableDiffusionPipeline,
    UNet2DModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


# TF32 matmuls trade precision for speed on Ampere GPUs; keep them disabled so
# the tight numerical tolerances in the tests below stay meaningful on GPU.
torch.backends.cuda.matmul.allow_tf32 = False


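# Minimal ConfigMixin subclass used to exercise `register_to_config`: every
# __init__ argument below should end up recorded in `obj.config`.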
class SampleObject(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass



class ConfigTester(unittest.TestCase):
    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # init ignores private arguments
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(tmpdirname)
            new_config = new_obj.config

        # unfreeze configs
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert config == new_config


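# Common tests shared by all model classes below; concrete subclasses provide
# `model_class`, `dummy_input`, `input_shape`, `output_shape`, and
# `prepare_init_args_and_inputs_for_common`.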
class ModelTesterMixin:
    def test_from_pretrained_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image["sample"]

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image["sample"]

        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first["sample"]

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second["sample"]

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output["sample"]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_config(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test whether the model can be loaded from the config
        # and has all the expected shapes
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_config(tmpdirname)
            new_model = self.model_class.from_config(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check that all parameter shapes are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1["sample"]

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2["sample"]

        self.assertEqual(output_1.shape, output_2.shape)

    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output["sample"]

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model, device=torch_device)

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output["sample"]

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        # fold the current model parameters into the EMA shadow copy
        ema_model.step(model)


class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": None,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict


#    TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
#    def test_output_pretrained(self):
#        model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
#        time_step = torch.tensor([10])
#
#        with torch.no_grad():
#            output = model(noise, time_step)["sample"]
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)["sample"]

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


#    TODO(Patrick) - Re-add this test after having cleaned up LDM
#    def test_output_pretrained_spatial_transformer(self):
#        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
#        model.eval()
#
#        torch.manual_seed(0)
#        if torch.cuda.is_available():
#            torch.cuda.manual_seed_all(0)
#
#        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
#        context = torch.ones((1, 16, 64), dtype=torch.float32)
#        time_step = torch.tensor([10] * noise.shape[0])
#
#        with torch.no_grad():
#            output = model(noise, time_step, context=context)
#
#        output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
#        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
# fmt: on
#
#        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
#


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

    @property
    def dummy_input(self, sizes=(32, 32)):
        batch_size = 4
        num_channels = 3

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
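        # this checkpoint operates on 256x256 images, so replace the 32x32 dummy sample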
        inputs = self.dummy_input
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step)["sample"]

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


class VQModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = VQModel

    @property
    def dummy_input(self, sizes=(32, 32)):
        batch_size = 4
        num_channels = 3

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 3,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_forward_signature(self):
        pass

    def test_training(self):
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_forward_signature(self):
        pass

    def test_training(self):
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image, sample_posterior=True)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


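# End-to-end pipeline tests: save/load round-trips, hub loading, output formats,
# and slow full-generation runs compared against reference image slices.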
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(model, scheduler)

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

        generator = torch.manual_seed(0)

        image = ddpm(generator=generator, output_type="numpy")["sample"]
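        # reseeding the generator reproduces the exact same sample from the reloaded pipeline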
        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        model_path = "google/ddpm-cifar10-32"

        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        ddpm.scheduler.num_timesteps = 10
        ddpm_from_hub.scheduler.num_timesteps = 10

        generator = torch.manual_seed(0)

        image = ddpm(generator=generator, output_type="numpy")["sample"]
        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_output_format(self):
        model_path = "google/ddpm-cifar10-32"

        pipe = DDIMPipeline.from_pretrained(model_path)

        generator = torch.manual_seed(0)
        images = pipe(generator=generator, output_type="numpy")["sample"]
        assert images.shape == (1, 32, 32, 3)
        assert isinstance(images, np.ndarray)

        images = pipe(generator=generator, output_type="pil")["sample"]
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
        images = pipe(generator=generator)["sample"]
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

    @slow
    def test_ddpm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id)
        scheduler = scheduler.set_format("pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_pndm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
        generator = torch.manual_seed(0)
        image = pndm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img(self):
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
            "sample"
        ]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img_fast(self):
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast("cuda"):
            output = sd_pipe(
                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
            )

        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_fast_ddim(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            clip_alpha_at_one=False,
        )
        sd_pipe.scheduler = scheduler

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)

        with torch.autocast("cuda"):
            output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @slow
    def test_score_sde_ve_pipeline(self):
        model_id = "google/ncsnpp-church-256"
        model = UNet2DModel.from_pretrained(model_id)

        scheduler = ScoreSdeVeScheduler.from_config(model_id)

        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)

        torch.manual_seed(0)
        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)

        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_uncond(self):
        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddpm_ddim_equality(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)

        generator = torch.manual_seed(0)
        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
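        # DDIM with eta=1.0 and the full 1000 steps is equivalent to ancestral DDPM sampling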
        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_image - ddim_image).max() < 1e-1

    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
    def test_ddpm_ddim_equality_batched(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)

        generator = torch.manual_seed(0)
        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
            "sample"
        ]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler(tensor_format="pt")

        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2