# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from parameterized import parameterized

from diffusers import (
    AsymmetricAutoencoderKL,
    AutoencoderKL,
    AutoencoderTiny,
    ConsistencyDecoderVAE,
    StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
    require_torch_gpu,
    slow,
    torch_all_close,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


# Called once at import time, before any test runs. Presumably (per its name)
# this switches torch into deterministic mode so the hard-coded expected slices
# below are reproducible -- confirm against diffusers.utils.testing_utils.
enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Fast unit tests for `AutoencoderKL` built from a tiny two-block config."""

    model_class = AutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        """A random batch of 4 three-channel 32x32 images under the ``"sample"`` key."""
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: the tiny model config and matching dummy inputs."""
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    # The inherited mixin tests below do not apply to this model. Skipping them
    # explicitly reports them as skipped instead of as silently passing.
    @unittest.skip("Unsupported test.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Unsupported test.")
    def test_training(self):
        pass

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # Gradient checkpointing must change neither the loss nor any parameter
        # gradient compared with a plain forward/backward pass on the same weights.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # Backprop a simple surrogate loss: the mean difference between the output
        # and random targets. The same `labels` are reused for the second model.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # Re-instantiate the model with gradient checkpointing enabled and copy the
        # weights so both models start from identical parameters.
        model_2 = self.model_class(**init_dict)
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # Same surrogate loss, same targets, now with checkpointing enabled.
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the losses and the parameter gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_from_pretrained_hub(self):
        # The dummy Hub checkpoint must load with no missing keys and run a forward pass.
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        # Compare a small output slice of the dummy checkpoint against per-device references.
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model = model.to(torch_device)
        model.eval()

        # The posterior-sampling generator: the global CPU generator on MPS,
        # otherwise a generator living on the test device.
        if torch_device == "mps":
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)

        # The input image itself is always drawn from a seeded CPU generator.
        image = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        image = image.to(torch_device)
        with torch.no_grad():
            output = model(image, sample_posterior=True, generator=generator).sample

        output_slice = output[0, -1, -3:, -3:].flatten().cpu()

        # Since the VAE Gaussian prior's generator is seeded on the appropriate device,
        # the expected output slices are not the same for CPU and GPU.
        if torch_device == "mps":
            expected_output_slice = torch.tensor(
                [
                    -4.0078e-01,
                    -3.8323e-04,
                    -1.2681e-01,
                    -1.1462e-01,
                    2.0095e-01,
                    1.0893e-01,
                    -8.8247e-02,
                    -3.0361e-01,
                    -9.8644e-03,
                ]
            )
        elif torch_device == "cpu":
            expected_output_slice = torch.tensor(
                [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
            )
        else:
            expected_output_slice = torch.tensor(
                [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
            )

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Fast unit tests for `AsymmetricAutoencoderKL` built from a tiny config."""

    model_class = AsymmetricAutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        """A random batch of 4 three-channel 32x32 images plus an all-ones single-channel mask."""
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        mask = torch.ones((batch_size, 1) + sizes).to(torch_device)

        return {"sample": image, "mask": mask}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: the tiny model config and matching dummy inputs."""
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "down_block_out_channels": [32, 64],
            "layers_per_down_block": 1,
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "up_block_out_channels": [32, 64],
            "layers_per_up_block": 1,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 32,
            "scaling_factor": 0.18215,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    # The inherited mixin tests below do not apply to this model. Skipping them
    # explicitly reports them as skipped instead of as silently passing.
    @unittest.skip("Unsupported test.")
    def test_forward_signature(self):
        pass

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass


class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
    """Fast unit tests for `AutoencoderTiny` built from a tiny config."""

    model_class = AutoencoderTiny
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        """A random batch of 4 three-channel 32x32 images under the ``"sample"`` key."""
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: the tiny model config and matching dummy inputs."""
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "encoder_block_out_channels": (32, 32),
            "decoder_block_out_channels": (32, 32),
            "num_encoder_blocks": (1, 2),
            "num_decoder_blocks": (2, 1),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    # Not applicable to this model. Skipping explicitly reports the test as
    # skipped instead of as silently passing.
    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass


class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
    """Fast unit tests for `ConsistencyDecoderVAE` built from a tiny config."""

    model_class = ConsistencyDecoderVAE
    main_input_name = "sample"
    base_precision = 1e-2
    # The forward pass consumes the returned generator, so the common tests must
    # request freshly built inputs (via `inputs_dict`) for every forward call.
    forward_requires_fresh_args = True

    def inputs_dict(self, seed=None):
        """Build fresh forward kwargs: a (4, 3, 32, 32) noise ``sample`` and the CPU ``generator`` used.

        When ``seed`` is given, the generator is seeded with that value, so calls with the
        same seed produce identical inputs. Otherwise the generator keeps torch's default
        initial state.
        """
        generator = torch.Generator("cpu")
        if seed is not None:
            # Previously hard-coded to 0, which ignored the caller's seed.
            generator.manual_seed(seed)
        image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))

        return {"sample": image, "generator": generator}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    @property
    def init_dict(self):
        """Tiny encoder / consistency-decoder configuration for fast tests."""
        return {
            "encoder_block_out_channels": [32, 64],
            "encoder_in_channels": 3,
            "encoder_out_channels": 4,
            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "decoder_add_attention": False,
            "decoder_block_out_channels": [32, 64],
            "decoder_down_block_types": [
                "ResnetDownsampleBlock2D",
                "ResnetDownsampleBlock2D",
            ],
            "decoder_downsample_padding": 1,
            "decoder_in_channels": 7,
            "decoder_layers_per_block": 1,
            "decoder_norm_eps": 1e-05,
            "decoder_norm_num_groups": 32,
            "decoder_num_train_timesteps": 1024,
            "decoder_out_channels": 6,
            "decoder_resnet_time_scale_shift": "scale_shift",
            "decoder_time_embedding_type": "learned",
            "decoder_up_block_types": [
                "ResnetUpsampleBlock2D",
                "ResnetUpsampleBlock2D",
            ],
            "scaling_factor": 1,
            "latent_channels": 4,
        }

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` with freshly built, unseeded inputs."""
        return self.init_dict, self.inputs_dict()

    # An explicit reason string works on every Python version and shows up in reports.
    @unittest.skip("Unsupported test.")
    def test_training(self):
        ...

    @unittest.skip("Unsupported test.")
    def test_ema_training(self):
        ...


@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
    """Slow integration tests for the pretrained TAESD (`AutoencoderTiny`) checkpoint."""

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_file_format(self, seed, shape):
        """File name of the precomputed noise array for ``seed`` and ``shape`` (fetched via `load_hf_numpy`)."""
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        """Load the stored noise tensor for ``seed``/``shape`` onto the test device (fp16 or fp32)."""
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
        """Load ``model_id`` as an `AutoencoderTiny` on the test device, in eval mode."""
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
        model.to(torch_device).eval()
        return model

    @parameterized.expand(
        [
            [(1, 4, 73, 97), (1, 3, 584, 776)],
            [(1, 4, 97, 73), (1, 3, 776, 584)],
            [(1, 4, 49, 65), (1, 3, 392, 520)],
            [(1, 4, 65, 49), (1, 3, 520, 392)],
            [(1, 4, 49, 49), (1, 3, 392, 392)],
        ]
    )
    def test_tae_tiling(self, in_shape, out_shape):
        # With tiling enabled, decoding these odd-sized latents must still yield
        # an image 8x the latent height and width (see the parameter pairs above).
        model = self.get_sd_vae_model()
        model.enable_tiling()
        with torch.no_grad():
            zeros = torch.zeros(in_shape).to(torch_device)
            dec = model.decode(zeros).sample
            assert dec.shape == out_shape

    def test_stable_diffusion(self):
        # Full encode/decode pass on stored noise, compared with a reference slice.
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed=33)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand([(True,), (False,)])
    def test_tae_roundtrip(self, enable_tiling):
        # load the autoencoder
        model = self.get_sd_vae_model()
        if enable_tiling:
            model.enable_tiling()

        # make a black image with a white square in the middle,
        # which is large enough to split across multiple tiles
        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
        image[..., 256:768, 256:768] = 1.0

        # round-trip the image through the autoencoder
        with torch.no_grad():
            sample = model(image).sample

        # the autoencoder reconstruction should match original image, sorta
        def downscale(x):
            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)

        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
    """Slow integration tests for the Stable Diffusion v1-4 `AutoencoderKL` checkpoint."""

    def get_file_format(self, seed, shape):
        """File name of the precomputed noise array for ``seed`` and ``shape`` (fetched via `load_hf_numpy`)."""
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        """Load the stored noise tensor for ``seed``/``shape`` onto the test device (fp16 or fp32)."""
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
        """Load the ``vae`` subfolder of ``model_id``. ``fp16`` selects both the "fp16" revision and the dtype."""
        revision = "fp16" if fp16 else None
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = AutoencoderKL.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device)

        return model

    def get_generator(self, seed=0):
        """Return a seeded generator: the global CPU one on MPS, otherwise one on the test device."""
        if torch_device == "mps":
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
            [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        # Forward pass with a sampled (not mode) posterior; MPS has its own reference slice.
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]],
            [47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_fp16(self, seed, expected_slice):
        # Same as `test_stable_diffusion`, but with fp16 weights and inputs (GPU only).
        model = self.get_sd_vae_model(fp16=True)
        image = self.get_sd_image(seed, fp16=True)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
            [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        # Forward pass with default arguments (no posterior sampling, no generator).
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]],
            [37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode(self, seed, expected_slice):
        # Decode stored 4x64x64 latents to 3x512x512 images and check a slice.
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]],
            [16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
        # fp16 variant of `test_stable_diffusion_decode`, with a looser tolerance.
        model = self.get_sd_vae_model(fp16=True)
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand([(13,), (16,), (27,)])
    @require_torch_gpu
    @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
    def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
        # The xformers attention path must match the default attention path (fp16).
        model = self.get_sd_vae_model(fp16=True)
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=1e-1)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        # The xformers attention path must match the default attention path (fp32).
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=1e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        # Encode, then draw a latent from the posterior; latents are 4 channels at 1/8 resolution.
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)

    def test_stable_diffusion_model_local(self):
        # A diffusers-format checkpoint and the same weights loaded from a single
        # original-format file (fetched by URL) must produce matching outputs.
        model_id = "stabilityai/sd-vae-ft-mse"
        model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)

        url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
        model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
        image = self.get_sd_image(33)

        with torch.no_grad():
            sample_1 = model_1(image).sample
            sample_2 = model_2(image).sample

        assert sample_1.shape == sample_2.shape

        output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
        output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()

        assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)


@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
    """Slow integration tests for the pretrained `AsymmetricAutoencoderKL` checkpoint."""

    def get_file_format(self, seed, shape):
        """File name of the precomputed noise array for ``seed`` and ``shape`` (fetched via `load_hf_numpy`)."""
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        """Load the stored noise tensor for ``seed``/``shape`` onto the test device (fp16 or fp32)."""
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
        """Load ``model_id`` on the test device in eval mode.

        The revision stays pinned to "main". Unlike the `AutoencoderKL` helper, there is
        no "fp16" revision branch here, so ``fp16`` only selects the dtype the weights
        are loaded in. Previously ``fp16`` was silently ignored and fp32 was always used.
        """
        revision = "main"
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = AsymmetricAutoencoderKL.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        """Return a seeded generator: the global CPU one on MPS, otherwise one on the test device."""
        if torch_device == "mps":
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078], [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]],
            [47, [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529], [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        # Forward pass with a sampled (not mode) posterior; MPS has its own reference slice.
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097], [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078]],
            [47, [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531], [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        # Forward pass with default arguments (no posterior sampling, no generator).
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.0521, -0.2939,  0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
            [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode(self, seed, expected_slice):
        # Decode stored 4x64x64 latents to 3x512x512 images and check a slice.
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        # The xformers attention path must match the default attention path.
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=5e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        # Encode, then draw a latent from the posterior; latents are 4 channels at 1/8 resolution.
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)


@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
    """Slow integration tests for ``ConsistencyDecoderVAE`` loaded from ``openai/consistency-decoder``.

    Each test compares a small slice of the output against hard-coded reference values. The model
    is exercised both standalone (encode -> decode) and as the VAE of a Stable Diffusion pipeline,
    in fp32 and in fp16.
    """

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _load_mountain_image(dtype=torch.float32):
        """Return the sketch-mountains test image as a batched tensor on ``torch_device``.

        The image is resized to 256x256. Pixel values are mapped from ``[0, 255]`` to ``[-1, 1]``,
        and the HWC array is transposed to CHW. A leading batch dimension of 1 is added. The result
        is cast to ``dtype`` and moved to ``torch_device``.
        """
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        ).resize((256, 256))
        pixels = np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1
        # Use torch_device rather than a hard-coded .cuda() so the input lands wherever the model was moved.
        return torch.from_numpy(pixels)[None, :, :, :].to(device=torch_device, dtype=dtype)

    @torch.no_grad()
    def test_encode_decode(self):
        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update
        vae.to(torch_device)

        image = self._load_mountain_image()

        # Decode from the posterior mean; only the decoder receives a seeded generator.
        latent = vae.encode(image).latent_dist.mean

        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample

        actual_output = sample[0, :2, :2, :2].flatten().cpu()
        expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    def test_sd(self):
        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")  # TODO - update
        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
        pipe.to(torch_device)

        out = pipe(
            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        actual_output = out[:2, :2, :2].flatten().cpu()
        expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    # no_grad here mirrors test_encode_decode; without it encode/decode would record an autograd graph.
    @torch.no_grad()
    def test_encode_decode_f16(self):
        vae = ConsistencyDecoderVAE.from_pretrained(
            "openai/consistency-decoder", torch_dtype=torch.float16
        )  # TODO - update
        vae.to(torch_device)

        image = self._load_mountain_image(dtype=torch.float16)

        latent = vae.encode(image).latent_dist.mean

        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample

        actual_output = sample[0, :2, :2, :2].flatten().cpu()
        expected_output = torch.tensor(
            [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
        )

        assert torch_all_close(actual_output, expected_output, atol=5e-3)

    def test_sd_f16(self):
        vae = ConsistencyDecoderVAE.from_pretrained(
            "openai/consistency-decoder", torch_dtype=torch.float16
        )  # TODO - update
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
        )
        pipe.to(torch_device)

        out = pipe(
            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]

        actual_output = out[:2, :2, :2].flatten().cpu()
        expected_output = torch.tensor(
            [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
        )

        assert torch_all_close(actual_output, expected_output, atol=5e-3)