"docs/references/deepseek_usage.md" did not exist on "c0bb9eb3b3fb603247648a5e7a2ab29822f05440"
test_pipelines.py 31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import random
import shutil
import tempfile
import unittest

import numpy as np
import torch

import PIL
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    UNet2DModel,
    logging,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer


# Disable TF32 matmuls so GPU runs stay numerically reproducible against the reference values below.
torch.backends.cuda.matmul.allow_tf32 = False


def test_progress_bar(capsys):
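    # `capsys` is pytest's capture fixture; the pipeline's tqdm progress bar writes to stderr, which we assert on below.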
    model = UNet2DModel(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    scheduler = DDPMScheduler(num_train_timesteps=10)

    ddpm = DDPMPipeline(model, scheduler).to(torch_device)
    ddpm(output_type="numpy").images
    captured = capsys.readouterr()
    assert "10/10" in captured.err, "Progress bar has to be displayed"

    ddpm.set_progress_bar_config(disable=True)
    ddpm(output_type="numpy").images
    captured = capsys.readouterr()
    assert captured.err == "", "Progress bar should be disabled"


class DownloadTests(unittest.TestCase):
    def test_download_only_pytorch(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has Flax weights
            _ = DiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
            files = [item for sublist in all_root_files for item in sublist]

            # None of the downloaded files should be a flax file even if we have some here:
            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
            assert not any(f.endswith(".msgpack") for f in files)

    def test_download_no_safety_checker(self):
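        """Loading with `safety_checker=None` must produce the same images as a default load."""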
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
        if torch_device == "mps":
            # device type MPS is not supported for torch.Generator() api.
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe_2 = pipe_2.to(torch_device)
        if torch_device == "mps":
            # device type MPS is not supported for torch.Generator() api.
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)
        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_load_no_safety_checker_explicit_locally(self):
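        """A pipeline saved without a safety checker can be reloaded locally with an explicit `safety_checker=None`."""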
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
        if torch_device == "mps":
            # device type MPS is not supported for torch.Generator() api.
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
            pipe_2 = pipe_2.to(torch_device)

            if torch_device == "mps":
                # device type MPS is not supported for torch.Generator() api.
                generator = torch.manual_seed(0)
            else:
                generator = torch.Generator(device=torch_device).manual_seed(0)

            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_load_no_safety_checker_default_locally(self):
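        """A pipeline saved without a safety checker can be reloaded locally with default arguments."""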
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe = pipe.to(torch_device)
        if torch_device == "mps":
            # device type MPS is not supported for torch.Generator() api.
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
            pipe_2 = pipe_2.to(torch_device)

            if torch_device == "mps":
                # device type MPS is not supported for torch.Generator() api.
                generator = torch.manual_seed(0)
            else:
                generator = torch.Generator(device=torch_device).manual_seed(0)

            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3


class CustomPipelineTests(unittest.TestCase):
    def test_load_custom_pipeline(self):
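        """`custom_pipeline` can reference a pipeline class hosted in another Hub repository."""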
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        pipeline = pipeline.to(torch_device)
        # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
        # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
        assert pipeline.__class__.__name__ == "CustomPipeline"

    def test_run_custom_pipeline(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert images[0].shape == (1, 32, 32, 3)

        # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
        assert output_str == "This is a test"

    def test_local_custom_pipeline_repo(self):
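        """`custom_pipeline` can also point to a local directory containing a `pipeline.py`."""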
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
        assert images[0].shape == (1, 32, 32, 3)
        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
        assert output_str == "This is a local test"

    def test_local_custom_pipeline_file(self):
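        """`custom_pipeline` accepts a direct path to a Python file, whatever the file is named."""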
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py")
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
        assert images[0].shape == (1, 32, 32, 3)
        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
        assert output_str == "This is a local test"

    @slow
    @require_torch_gpu
    def test_load_pipeline_from_git(self):
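        """`custom_pipeline="clip_guided_stable_diffusion"` resolves to the community examples folder on GitHub."""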
        clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

        pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion",
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
            revision="fp16",
        )
        pipeline.enable_attention_slicing()
        pipeline = pipeline.to(torch_device)

        # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the PyPI package of the library, but solely in the community examples folder of GitHub under:
        # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
        assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"

        image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
        assert image.shape == (512, 512, 3)


class PipelineFastTests(unittest.TestCase):
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    def dummy_uncond_unet(self, sample_size=32):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=sample_size,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    def dummy_cond_unet(self, sample_size=32):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=sample_size,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
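        # Minimal stand-in for a feature extractor: callable returning an object with `pixel_values` and `to()`.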
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    self.pixel_values = self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    @parameterized.expand(
        [
            [DDIMScheduler, DDIMPipeline, 32],
            [DDPMScheduler, DDPMPipeline, 32],
            [DDIMScheduler, DDIMPipeline, (32, 64)],
            [DDPMScheduler, DDPMPipeline, (64, 32)],
        ]
    )
    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
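        # The generated image shape must follow the UNet sample size, whether it is an int or an (h, w) tuple.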
        unet = self.dummy_uncond_unet(sample_size)
        scheduler = scheduler_fn()
        pipeline = pipeline_fn(unet, scheduler).to(torch_device)

        # Device type MPS is not supported for torch.Generator() api.
        if torch_device == "mps":
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)

        out_image = pipeline(
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images
        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
        assert out_image.shape == (1, *sample_size, 3)

    def test_stable_diffusion_components(self):
        """Test that components property works correctly"""
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        inpaint = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        ).to(torch_device)
        img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
        text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)

        prompt = "A painting of a squirrel eating a burger"

        # Device type MPS is not supported for torch.Generator() api.
        if torch_device == "mps":
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)

        image_inpaint = inpaint(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            init_image=init_image,
            mask_image=mask_image,
        ).images
        image_img2img = img2img(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            init_image=init_image,
        ).images
        image_text2img = text2img(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images

        assert image_inpaint.shape == (1, 32, 32, 3)
        assert image_img2img.shape == (1, 32, 32, 3)
        assert image_text2img.shape == (1, 64, 64, 3)

    def test_set_scheduler(self):
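        """Any compatible scheduler can be swapped in via `from_config(pipe.scheduler.config)`."""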
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DDIMScheduler)
        sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DDPMScheduler)
        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, PNDMScheduler)
        sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, LMSDiscreteScheduler)
        sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, EulerDiscreteScheduler)
        sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
        sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)

    def test_set_scheduler_consistency(self):
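        """Swapping a scheduler out and back again must leave its original config values untouched."""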
        unet = self.dummy_cond_unet()
        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=pndm,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        pndm_config = sd.scheduler.config
        sd.scheduler = DDPMScheduler.from_config(pndm_config)
        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
        pndm_config_2 = sd.scheduler.config
        pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}

        assert dict(pndm_config) == dict(pndm_config_2)

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=ddim,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        ddim_config = sd.scheduler.config
        sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
        ddim_config_2 = sd.scheduler.config
        ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}

        assert dict(ddim_config) == dict(ddim_config_2)

    def test_optional_components(self):
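        """Optional components (safety checker, feature extractor) can be disabled, omitted from the
        config, or overridden at load time, and each choice must survive a save/load round-trip."""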
        unet = self.dummy_cond_unet()
        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        orig_sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=pndm,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=unet,
            feature_extractor=self.dummy_extractor,
        )
        sd = orig_sd

        assert sd.config.requires_safety_checker is True

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that passing None works
            sd = StableDiffusionPipeline.from_pretrained(
                tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False
            )

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that loading previous None works
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

            orig_sd.save_pretrained(tmpdirname)

            # Test that loading without any directory works
            shutil.rmtree(os.path.join(tmpdirname, "safety_checker"))
            with open(os.path.join(tmpdirname, sd.config_name)) as f:
                config = json.load(f)
                config["safety_checker"] = [None, None]
            with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
                json.dump(config, f)

            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False)
            sd.save_pretrained(tmpdirname)
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

            # Test that loading from deleted model index works
            with open(os.path.join(tmpdirname, sd.config_name)) as f:
                config = json.load(f)
                del config["safety_checker"]
                del config["feature_extractor"]
            with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
                json.dump(config, f)

            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that partially loading works
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor != (None, None)

            # Test that partially loading works
            sd = StableDiffusionPipeline.from_pretrained(
                tmpdirname,
                feature_extractor=self.dummy_extractor,
                safety_checker=unet,
                requires_safety_checker=[True, True],
            )

            assert sd.config.requires_safety_checker == [True, True]
            assert sd.config.safety_checker != (None, None)
            assert sd.config.feature_extractor != (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)

            assert sd.config.requires_safety_checker == [True, True]
            assert sd.config.safety_checker != (None, None)
            assert sd.config.feature_extractor != (None, None)


@slow
class PipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_smart_download(self):
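        """Downloading a pipeline should fetch every required config/weight file but skip unrelated repo files."""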
        model_id = "hf-internal-testing/unet-pipeline-dummy"
        with tempfile.TemporaryDirectory() as tmpdirname:
            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
            local_repo_name = "--".join(["models"] + model_id.split("/"))
            snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
            snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])

            # inspect all downloaded files to make sure that everything is included
            assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
            assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
            # let's make sure the super large numpy file:
            # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
            # is not downloaded, but all the expected ones
            assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))

    def test_warning_unused_kwargs(self):
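        """Unknown keyword arguments passed to `from_pretrained` are reported instead of silently dropped."""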
        model_id = "hf-internal-testing/unet-pipeline-dummy"
        logger = logging.get_logger("diffusers.pipeline_utils")
        with tempfile.TemporaryDirectory() as tmpdirname:
            with CaptureLogger(logger) as cap_logger:
                DiffusionPipeline.from_pretrained(
                    model_id,
                    not_used=True,
                    cache_dir=tmpdirname,
                    force_download=True,
                )

        assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"

    def test_from_pretrained_save_pretrained(self):
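        """A pipeline reloaded from `save_pretrained` output must reproduce the original's images."""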
        # 1. Load models
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(model, scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
            new_ddpm.to(torch_device)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy").images

        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_from_pretrained_hub(self):
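        """Loading via the concrete pipeline class and via `DiffusionPipeline` must be equivalent."""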
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm = ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub = ddpm_from_hub.to(torch_device)
        ddpm_from_hub.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy").images

        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_from_pretrained_hub_pass_model(self):
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        # pass unet into DiffusionPipeline
        unet = UNet2DModel.from_pretrained(model_path)
        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
        ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
        ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub = ddpm_from_hub.to(torch_device)
        ddpm_from_hub.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_output_format(self):
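        """`output_type` switches between numpy arrays and PIL images; PIL is the default."""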
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDIMScheduler.from_pretrained(model_path)
        pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        images = pipe(generator=generator, output_type="numpy").images
        assert images.shape == (1, 32, 32, 3)
        assert isinstance(images, np.ndarray)

        images = pipe(generator=generator, output_type="pil", num_inference_steps=4).images
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
        images = pipe(generator=generator, num_inference_steps=4).images
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

    def test_ddpm_ddim_equality_batched(self):
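        """With eta=1.0 and the full 1000 steps, DDIM sampling should closely match DDPM."""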
        seed = 0
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler()
        ddim_scheduler = DDIMScheduler()

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)
        ddim.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(seed)
        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(seed)
        ddim_images = ddim(
            batch_size=2,
            generator=generator,
            num_inference_steps=1000,
            eta=1.0,
            output_type="numpy",
            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
        ).images

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1