# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import numpy as np
import torch

import PIL
from diffusers import (
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    KarrasVePipeline,
    KarrasVeScheduler,
    LDMPipeline,
    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    StableDiffusionPipeline,
    UNet2DModel,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import slow, torch_device


# Turn off TF32 for CUDA matmuls so they run in full FP32.
# NOTE(review): presumably this keeps GPU outputs within the tolerances of the
# hard-coded expected slices asserted in the tests below — confirm.
_cuda_matmul = torch.backends.cuda.matmul
_cuda_matmul.allow_tf32 = False


def test_progress_bar(capsys):
    """Check that DDPMPipeline writes a progress bar to stderr and that it can be turned off.

    A tiny randomly initialised UNet and a 10-timestep DDPM scheduler keep the run cheap.
    ``capsys`` is the pytest fixture used to read what was written to stderr.
    """
    unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    pipeline = DDPMPipeline(unet, DDPMScheduler(num_train_timesteps=10)).to(torch_device)

    # Default configuration: the bar is expected to report "10/10" on stderr.
    pipeline(output_type="numpy")["sample"]
    stderr_text = capsys.readouterr().err
    assert "10/10" in stderr_text, "Progress bar has to be displayed"

    # After disabling, nothing at all should reach stderr.
    pipeline.set_progress_bar_config(disable=True)
    pipeline(output_type="numpy")["sample"]
    stderr_text = capsys.readouterr().err
    assert stderr_text == "", "Progress bar should be disabled"


class PipelineTesterMixin(unittest.TestCase):
    """End-to-end tests for the diffusers pipelines.

    Most tests are decorated with ``@slow`` and load pretrained checkpoints by model id
    (e.g. ``google/ddpm-cifar10-32``). Many of them compare a 3x3 patch (bottom-right corner,
    last channel) of the generated ``(batch, height, width, channels)`` image against
    hard-coded reference values. The stable-diffusion tests are additionally skipped on CPU.
    """

    def test_from_pretrained_save_pretrained(self):
        """A pipeline round-tripped through ``save_pretrained``/``from_pretrained`` gives the same samples."""
        # 1. Build a tiny, randomly initialised pipeline.
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(model, scheduler)
        ddpm.to(torch_device)

        # 2. Save it to a temporary directory and load it back.
        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
            new_ddpm.to(torch_device)

        # 3. Re-seed before each run so both pipelines start from the same generator state.
        generator = torch.manual_seed(0)

        image = ddpm(generator=generator, output_type="numpy")["sample"]
        generator = generator.manual_seed(0)
        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub(self):
        """``DDPMPipeline.from_pretrained`` and ``DiffusionPipeline.from_pretrained`` agree."""
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub.to(torch_device)

        generator = torch.manual_seed(0)

        image = ddpm(generator=generator, output_type="numpy")["sample"]
        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub_pass_model(self):
        """Passing an explicitly loaded ``unet`` to ``from_pretrained`` gives the same samples."""
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        # pass unet into DiffusionPipeline
        unet = UNet2DModel.from_pretrained(model_path)
        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
        ddpm_from_hub_custom_model.to(torch_device)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub.to(torch_device)

        generator = torch.manual_seed(0)

        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_output_format(self):
        """``output_type="numpy"`` yields an ndarray; ``"pil"`` and the default yield a list of PIL images."""
        model_path = "google/ddpm-cifar10-32"

        pipe = DDIMPipeline.from_pretrained(model_path)
        pipe.to(torch_device)

        generator = torch.manual_seed(0)
        images = pipe(generator=generator, output_type="numpy")["sample"]
        assert images.shape == (1, 32, 32, 3)
        assert isinstance(images, np.ndarray)

        images = pipe(generator=generator, output_type="pil")["sample"]
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
        images = pipe(generator=generator)["sample"]
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

    @slow
    def test_ddpm_cifar10(self):
        """DDPM on ``google/ddpm-cifar10-32`` reproduces the reference pixel values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id)
        scheduler = scheduler.set_format("pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        """DDIM on ``google/ddpm-ema-bedroom-256`` reproduces the reference pixel values."""
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        """DDIM with ``eta=0.0`` on the CIFAR-10 checkpoint reproduces the reference pixel values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
        ddim.to(torch_device)

        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_pndm_cifar10(self):
        """PNDM on the CIFAR-10 checkpoint reproduces the reference pixel values."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
        pndm.to(torch_device)
        generator = torch.manual_seed(0)
        image = pndm(generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img(self):
        """Text-to-image LDM (20 steps, ``guidance_scale=6.0``) reproduces the reference pixel values."""
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm.to(torch_device)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
            "sample"
        ]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img_fast(self):
        """Single-step text-to-image LDM with a bare-string prompt reproduces the reference values."""
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm.to(torch_device)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion(self):
        """Stable Diffusion v1-1 with its default scheduler (20 steps, under CUDA autocast) matches the reference."""
        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast("cuda"):
            output = sd_pipe(
                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
            )

        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_fast_ddim(self):
        """Stable Diffusion with a swapped-in DDIM scheduler (2 steps) matches the reference to 1e-3."""
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)

        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        sd_pipe.scheduler = scheduler

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)

        with torch.autocast("cuda"):
            output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @slow
    def test_score_sde_ve_pipeline(self):
        """Score-SDE-VE on ``google/ncsnpp-church-256`` (300 steps) reproduces the reference values."""
        model_id = "google/ncsnpp-church-256"
        model = UNet2DModel.from_pretrained(model_id)

        scheduler = ScoreSdeVeScheduler.from_config(model_id)

        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        sde_ve.to(torch_device)

        # No generator is passed to this pipeline, so the global torch RNG is seeded instead.
        torch.manual_seed(0)
        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)

        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_uncond(self):
        """Unconditional LDM on ``CompVis/ldm-celebahq-256`` (5 steps) reproduces the reference values."""
        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
        ldm.to(torch_device)

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddpm_ddim_equality(self):
        """DDIM with ``eta=1.0`` and 1000 steps closely tracks DDPM for the same seed and UNet."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)
        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)

        generator = torch.manual_seed(0)
        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_image - ddim_image).max() < 1e-1

    # Marked @slow like the other hub-downloading tests so it stays gated once un-skipped.
    @slow
    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
    def test_ddpm_ddim_equality_batched(self):
        """Batched (``batch_size=4``) variant of ``test_ddpm_ddim_equality``; currently skipped."""
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)

        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)

        generator = torch.manual_seed(0)
        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
            "sample"
        ]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
        """Karras-VE on ``google/ncsnpp-celebahq-256`` (20 steps) reproduces the reference values."""
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler(tensor_format="pt")

        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
        pipe.to(torch_device)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_lms_stable_diffusion_pipeline(self):
        """Stable Diffusion with the LMS discrete scheduler (10 steps) reproduces the reference values."""
        model_id = "CompVis/stable-diffusion-v1-1"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
        scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
        pipe.scheduler = scheduler

        prompt = "a photograph of an astronaut riding a horse"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
            "sample"
        ]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2