# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
25
    LMSDiscreteScheduler,
26
27
28
29
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
30
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
Patrick von Platen's avatar
Patrick von Platen committed
31
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
32
from diffusers.utils.testing_utils import require_torch_gpu
33
34
35
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

36
37
from ...test_pipelines_common import PipelineTesterMixin

38
39
40
41

# Turn off TF32 for CUDA matmuls for every test in this module.
# NOTE(review): presumably so the numeric tolerances asserted below are not affected by TF32 rounding — confirm.
torch.backends.cuda.matmul.allow_tf32 = False


42
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests that run `StableDiffusionInpaintPipeline` on CPU with tiny, randomly initialised components."""

    pipeline_class = StableDiffusionInpaintPipeline

    def get_dummy_components(self):
        # Returns the keyword arguments needed to construct the pipeline under test.
        # The global torch seed is reset to 0 before each randomly initialised model is built.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            # NOTE(review): presumably 4 latent + 1 mask + 4 masked-image latent channels — confirm against
            # the pipeline implementation.
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # The safety checker and feature extractor are deliberately left out of the dummy pipeline.
        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        # Returns the call arguments for the pipeline: a prompt, a 64x64 PIL image and mask, a seeded
        # generator and a short two-step schedule with numpy output.
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        # Drop the batch dimension and move channels last before converting to PIL.
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        # The mask is derived from the same data, offset by 4 before the uint8 cast.
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        if str(device).startswith("mps"):
            # NOTE(review): presumably device-bound generators are not usable on mps, so the default
            # generator returned by `torch.manual_seed` is used instead — confirm.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint(self):
        # Runs the dummy pipeline once and compares a 3x3 corner of the last channel with stored values.
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_image_tensor(self):
        # Checks that passing the image and mask as tensors gives (nearly) the same output as passing PIL images.
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs)
        out_pil = output.images

        # Rebuild the same inputs, then convert the PIL image to a [-1, 1] tensor and the mask to a
        # single-channel [0, 1] tensor, both with a leading batch dimension.
        inputs = self.get_dummy_inputs(device)
        inputs["image"] = torch.tensor(np.array(inputs["image"]) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0)
        inputs["mask_image"] = torch.tensor(np.array(inputs["mask_image"]) / 255).permute(2, 0, 1)[:1].unsqueeze(0)
        output = sd_pipe(**inputs)
        out_tensor = output.images

        assert out_pil.shape == (1, 64, 64, 3)
        assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

    def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
        # Checks that `num_images_per_prompt=2` yields two images for a single prompt.
        device = "cpu"
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        images = sd_pipe(**inputs, num_images_per_prompt=2).images

        # check if the output is a list of 2 images
        assert len(images) == 2

162
163

@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
    """Slow GPU tests that run the `runwayml/stable-diffusion-inpainting` checkpoint against stored reference images."""

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_inpaint_pipeline(self):
        # Full-precision run with the checkpoint's default scheduler, compared with a stored reference array.
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
            "/yellow_cat_sitting_on_a_park_bench.npy"
        )

        model_id = "runwayml/stable-diffusion-inpainting"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_inpaint_pipeline_fp16(self):
        # Same scenario as the full-precision test, but loading the "fp16" revision in float16.
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
            "/yellow_cat_sitting_on_a_park_bench_fp16.npy"
        )

        model_id = "runwayml/stable-diffusion-inpainting"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        # NOTE(review): this tolerance is far looser than the 1e-3 used by the full-precision test — confirm
        # that it is intentional.
        assert np.abs(expected_image - image).max() < 5e-1

    def test_stable_diffusion_inpaint_pipeline_pndm(self):
        # Same scenario, with a PNDM scheduler loaded from the checkpoint's "scheduler" subfolder.
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
            "/yellow_cat_sitting_on_a_park_bench_pndm.npy"
        )

        model_id = "runwayml/stable-diffusion-inpainting"
        pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-2

    def test_stable_diffusion_inpaint_pipeline_k_lms(self):
        # Same scenario, with the scheduler swapped for LMSDiscreteScheduler after loading.
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
            "/yellow_cat_sitting_on_a_park_bench_k_lms.npy"
        )

        model_id = "runwayml/stable-diffusion-inpainting"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
        pipe.to(torch_device)

        # switch to LMS
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-2

    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        # Runs a short fp16 inference with sequential CPU offloading and checks the peak CUDA memory.
        torch.cuda.empty_cache()
        # `reset_peak_memory_stats` alone resets the peak counter read back below; the deprecated
        # `reset_max_memory_allocated` call that used to precede it was redundant.
        torch.cuda.reset_peak_memory_stats()

        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )

        model_id = "runwayml/stable-diffusion-inpainting"
        pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            safety_checker=None,
            scheduler=pndm,
            device_map="auto",
            revision="fp16",
            torch_dtype=torch.float16,
        )
        # NOTE(review): the pipeline is moved to `torch_device` before offloading is enabled, and the peak
        # measured below covers everything after the reset above — confirm this ordering is intended.
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        _ = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            num_inference_steps=5,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9


class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    """Tests for `prepare_mask_and_masked_image` across PIL, numpy and torch inputs of various ranks."""

    def test_pil_inputs(self):
        # PIL image + PIL mask: check type, rank, shape, dtype and value range of both outputs.
        im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked = prepare_mask_and_masked_image(im, mask)

        self.assertTrue(isinstance(t_mask, torch.Tensor))
        self.assertTrue(isinstance(t_masked, torch.Tensor))

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, 32, 32))
        self.assertEqual(t_masked.shape, (1, 3, 32, 32))

        self.assertTrue(t_mask.dtype == torch.float32)
        self.assertTrue(t_masked.dtype == torch.float32)

        self.assertTrue(t_mask.min() >= 0.0)
        self.assertTrue(t_mask.max() <= 1.0)
        self.assertTrue(t_masked.min() >= -1.0)
        # The upper bound must be checked on the maximum; checking `min() <= 1.0` never verified it.
        self.assertTrue(t_masked.max() <= 1.0)

        self.assertTrue(t_mask.sum() > 0.0)

    def test_np_inputs(self):
        # numpy inputs must give exactly the same tensors as the equivalent PIL inputs.
        im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im_pil = Image.fromarray(im_np)
        mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)

        self.assertTrue((t_mask_np == t_mask_pil).all())
        self.assertTrue((t_masked_np == t_masked_pil).all())

    def test_torch_3D_2D_inputs(self):
        # (3, H, W) image tensor with (H, W) mask tensor must match the numpy path.
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        # Tensor images are passed already scaled to [-1, 1]; numpy images are passed as raw uint8.
        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_3D_3D_inputs(self):
        # (3, H, W) image tensor with (1, H, W) mask tensor must match the numpy path.
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_2D_inputs(self):
        # (1, 3, H, W) image tensor with (H, W) mask tensor must match the numpy path.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_3D_inputs(self):
        # (1, 3, H, W) image tensor with (1, H, W) mask tensor must match the numpy path.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_4D_inputs(self):
        # (1, 3, H, W) image tensor with (1, 1, H, W) mask tensor must match the numpy path.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0][0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_3D(self):
        # Batched (2, 3, H, W) images with (2, H, W) masks must equal the per-sample numpy results concatenated.
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy() for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_4D(self):
        # Batched (2, 3, H, W) images with (2, 1, H, W) masks must equal the per-sample numpy results concatenated.
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy()[0] for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_shape_mismatch(self):
        # Mismatched image/mask shapes are expected to raise AssertionError.
        # NOTE(review): the two "batch dim" cases also differ in height/width (32 vs 64), so they do not
        # isolate the batch-size check — consider using matching spatial sizes there.
        # test height and width
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
        # test batch dim
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
        # test batch dim
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))

    def test_type_mismatch(self):
        # Mixing a torch tensor with a numpy array (in either position) is expected to raise TypeError.
        # test tensors-only
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
        # test tensors-only
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))

    def test_channels_first(self):
        # A channels-last 3D image tensor is expected to be rejected with AssertionError.
        # test channels first for 3D tensors
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))

    def test_tensor_range(self):
        # Tensor inputs outside the accepted value ranges are expected to raise ValueError.
        # test im <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
        # test im >= -1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
        # test mask <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
        # test mask >= 0
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)