# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DiffusionPipeline,
    UnCLIPImageVariationPipeline,
    UnCLIPScheduler,
    UNet2DConditionModel,
    UNet2DModel,
)
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_image,
    load_numpy,
    nightly,
    require_torch_gpu,
    skip_mps,
    torch_device,
)

from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference


enable_full_determinism()


class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = UnCLIPImageVariationPipeline
    params = IMAGE_VARIATION_PARAMS - {"height", "width", "guidance_scale"}
    batch_params = IMAGE_VARIATION_BATCH_PARAMS

    required_optional_params = [
        "generator",
        "return_dict",
        "decoder_num_inference_steps",
        "super_res_num_inference_steps",
    ]
    test_xformers_attention = False
    supports_dduf = False

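    # The properties below build tiny, seeded dummy sub-models so the fast tests
    # stay lightweight and deterministic.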
    @property
    def text_embedder_hidden_size(self):
        return 32

    @property
    def time_input_dim(self):
        return 32

    @property
    def block_out_channels_0(self):
        return self.time_input_dim

    @property
    def time_embed_dim(self):
        return self.time_input_dim * 4

    @property
    def cross_attention_dim(self):
        return 100

    @property
    def dummy_tokenizer(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            projection_dim=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModelWithProjection(config)

    @property
    def dummy_image_encoder(self):
        torch.manual_seed(0)
        config = CLIPVisionConfig(
            hidden_size=self.text_embedder_hidden_size,
            projection_dim=self.text_embedder_hidden_size,
            num_hidden_layers=5,
            num_attention_heads=4,
            image_size=32,
            intermediate_size=37,
            patch_size=1,
        )
        return CLIPVisionModelWithProjection(config)

    @property
    def dummy_text_proj(self):
        torch.manual_seed(0)

        model_kwargs = {
            "clip_embeddings_dim": self.text_embedder_hidden_size,
            "time_embed_dim": self.time_embed_dim,
            "cross_attention_dim": self.cross_attention_dim,
        }

        model = UnCLIPTextProjModel(**model_kwargs)
        return model

    @property
    def dummy_decoder(self):
        torch.manual_seed(0)

        model_kwargs = {
            "sample_size": 32,
            # RGB in channels
            "in_channels": 3,
            # Out channels is double the in channels because the model predicts both mean and variance
            "out_channels": 6,
            "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
            "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
            "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "layers_per_block": 1,
            "cross_attention_dim": self.cross_attention_dim,
            "attention_head_dim": 4,
            "resnet_time_scale_shift": "scale_shift",
            "class_embed_type": "identity",
        }

        model = UNet2DConditionModel(**model_kwargs)
        return model

    @property
    def dummy_super_res_kwargs(self):
        return {
            "sample_size": 64,
            "layers_per_block": 1,
            "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
            "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "in_channels": 6,
            "out_channels": 3,
        }

    @property
    def dummy_super_res_first(self):
        torch.manual_seed(0)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model

    @property
    def dummy_super_res_last(self):
        # seeded differently to get a different unet than `self.dummy_super_res_first`
        torch.manual_seed(1)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model

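    # Assembles every sub-model the image-variation pipeline expects: decoder UNet,
    # text/image encoders, text projection, the two super-resolution UNets, their
    # schedulers, and a 32x32 CLIP feature extractor matching the dummy image encoder.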
    def get_dummy_components(self):
        decoder = self.dummy_decoder
        text_proj = self.dummy_text_proj
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        super_res_first = self.dummy_super_res_first
        super_res_last = self.dummy_super_res_last

        decoder_scheduler = UnCLIPScheduler(
            variance_type="learned_range",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        super_res_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

        image_encoder = self.dummy_image_encoder

        return {
            "decoder": decoder,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_proj": text_proj,
            "feature_extractor": feature_extractor,
            "image_encoder": image_encoder,
            "super_res_first": super_res_first,
            "super_res_last": super_res_last,
            "decoder_scheduler": decoder_scheduler,
            "super_res_scheduler": super_res_scheduler,
        }

    def get_dummy_inputs(self, device, seed=0, pil_image=True):
        input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

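        # When a PIL input is requested, map the tensor from [-1, 1] to [0, 1] and
        # convert it with the pipeline's numpy_to_pil helper.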
        if pil_image:
            input_image = input_image * 0.5 + 0.5
            input_image = input_image.clamp(0, 1)
            input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy()
            input_image = DiffusionPipeline.numpy_to_pil(input_image)[0]

        return {
            "image": input_image,
            "generator": generator,
            "decoder_num_inference_steps": 2,
            "super_res_num_inference_steps": 2,
            "output_type": "np",
        }

    def test_unclip_image_variation_input_tensor(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

        output = pipe(**pipeline_inputs)
        image = output.images

        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

        image_from_tuple = pipe(
            **tuple_pipeline_inputs,
            return_dict=False,
        )[0]

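        # Compare the bottom-right 3x3 patch of the last channel against reference values.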
        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array(
            [
                0.9997,
                0.0002,
                0.9997,
                0.9997,
                0.9969,
                0.0023,
                0.9997,
                0.9969,
                0.9970,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_unclip_image_variation_input_image(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)

        output = pipe(**pipeline_inputs)
        image = output.images

        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)

        image_from_tuple = pipe(
            **tuple_pipeline_inputs,
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_unclip_image_variation_input_list_images(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
        pipeline_inputs["image"] = [
            pipeline_inputs["image"],
            pipeline_inputs["image"],
        ]

        output = pipe(**pipeline_inputs)
        image = output.images

        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
        tuple_pipeline_inputs["image"] = [
            tuple_pipeline_inputs["image"],
            tuple_pipeline_inputs["image"],
        ]

        image_from_tuple = pipe(
            **tuple_pipeline_inputs,
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (2, 64, 64, 3)

        expected_slice = np.array(
            [
                0.9997,
                0.9989,
                0.0008,
                0.0021,
                0.9960,
                0.0018,
                0.0014,
                0.0002,
                0.9933,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_unclip_passed_image_embed(self):
        device = torch.device("cpu")

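        # Minimal stand-in for the scheduler passed to `prepare_latents`; only
        # `init_noise_sigma` is accessed when the freshly sampled latents are scaled.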
        class DummyScheduler:
            init_noise_sigma = 1

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        dtype = pipe.decoder.dtype
        batch_size = 1

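        # Pre-generate fixed decoder and super-resolution latents so the two pipeline
        # calls below are deterministic and can be compared element-wise.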
        shape = (
            batch_size,
            pipe.decoder.config.in_channels,
            pipe.decoder.config.sample_size,
            pipe.decoder.config.sample_size,
        )
        decoder_latents = pipe.prepare_latents(
            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
        )

        shape = (
            batch_size,
            pipe.super_res_first.config.in_channels // 2,
            pipe.super_res_first.config.sample_size,
            pipe.super_res_first.config.sample_size,
        )
        generator = torch.Generator(device=device).manual_seed(0)
        super_res_latents = pipe.prepare_latents(
            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
        )

        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

        img_out_1 = pipe(
            **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
        ).images

        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
        # Don't pass image, instead pass embedding
        image = pipeline_inputs.pop("image")
        image_embeddings = pipe.image_encoder(image).image_embeds

        img_out_2 = pipe(
            **pipeline_inputs,
            decoder_latents=decoder_latents,
            super_res_latents=super_res_latents,
            image_embeddings=image_embeddings,
        ).images

        # make sure passing image embeddings manually is identical
        assert np.abs(img_out_1 - img_out_2).max() < 1e-4

    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
    # because UnCLIP's non-determinism on GPU requires a looser check.
    @skip_mps
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device == "cpu"

        # Check is relaxed because there is no torch 2.0 sliced attention added kv processor
        expected_max_diff = 1e-2

        self._test_attention_slicing_forward_pass(
            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
        )

    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
    # because UnCLIP's non-determinism requires a looser check.
    @unittest.skip("UnCLIP produces very large differences. Test is not useful.")
    @skip_mps
    def test_inference_batch_single_identical(self):
        additional_params_copy_to_batched_inputs = [
            "decoder_num_inference_steps",
            "super_res_num_inference_steps",
        ]
        self._test_inference_batch_single_identical(
            additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
        )

    def test_inference_batch_consistent(self):
        additional_params_copy_to_batched_inputs = [
            "decoder_num_inference_steps",
            "super_res_num_inference_steps",
        ]

        if torch_device == "mps":
            # TODO: MPS errors with larger batch sizes
            batch_sizes = [2, 3]
            self._test_inference_batch_consistent(
                batch_sizes=batch_sizes,
                additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
            )
        else:
            self._test_inference_batch_consistent(
                additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
            )

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @unittest.skip("UnCLIP produces very large difference. Test is not useful.")
    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local(expected_max_difference=4e-3)

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @unittest.skip("UnCLIP produces very large difference in fp16 vs fp32. Test is not useful.")
    def test_float16_inference(self):
        super().test_float16_inference(expected_max_diff=1.0)


@nightly
@require_torch_gpu
class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

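    # Runs the full fp16 pipeline from the public karlo-v1-alpha image-variation
    # checkpoint and compares the output to a reference image via mean pixel difference.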
    def test_unclip_image_variation_karlo(self):
        input_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
        )

        pipeline = UnCLIPImageVariationPipeline.from_pretrained(
            "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
        )
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipeline(
            input_image,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)

        assert_mean_pixel_difference(image, expected_image, 15)