# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
18
import traceback
19
20
21
22
import unittest

import numpy as np
import torch
23
24
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
25
26
27

from diffusers import (
    AutoencoderKL,
28
    DPMSolverMultistepScheduler,
29
    LMSDiscreteScheduler,
30
31
32
33
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
34
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
35
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
36
37
38
39
40
41
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_2,
    require_torch_gpu,
    run_test_in_subprocess,
)
42

43
from ...models.test_models_unet_2d_condition import create_lora_layers
44
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
45
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
46

47

48
enable_full_determinism()
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Will be run via run_test_in_subprocess
def _test_inpaint_compile(in_queue, out_queue, timeout):
    """Run the inpainting pipeline with a `torch.compile`d UNet and report the outcome through `out_queue`.

    Args:
        in_queue: queue from which the pipeline call kwargs are read. The dict must also hold
            "torch_device" and "seed"; both are popped and used to rebuild the generator here.
        out_queue: queue that receives `{"error": <formatted traceback or None>}`.
        timeout: timeout passed to both the `get` and the `put` call.
    """
    failure = None
    try:
        call_kwargs = in_queue.get(timeout=timeout)
        device = call_kwargs.pop("torch_device")
        generator_seed = call_kwargs.pop("seed")
        # The generator is created on this side; the caller only sends the seed.
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(generator_seed)

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        images = pipe(**call_kwargs).images
        center_patch = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(reference - center_patch).max() < 3e-3
    except Exception:
        # Failed assertions land here as well; the traceback text is handed back instead of raising.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()


85
class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Fast tests that run `StableDiffusionInpaintPipeline` on tiny dummy components."""

    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
    image_params = frozenset()

    def get_dummy_components(self):
        """Return the keyword arguments needed to construct a tiny inpainting pipeline."""
        # The global torch seed is reset to 0 before each model is instantiated.
        torch.manual_seed(0)
        tiny_unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        pndm_scheduler = PNDMScheduler(skip_prk_steps=True)

        torch.manual_seed(0)
        tiny_vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        torch.manual_seed(0)
        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        tiny_text_encoder = CLIPTextModel(clip_config)
        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": tiny_unet,
            "scheduler": pndm_scheduler,
            "vae": tiny_vae,
            "text_encoder": tiny_text_encoder,
            "tokenizer": clip_tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return pipeline call kwargs: prompt, 64x64 PIL image and mask, seeded generator and 2 steps."""
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        pixels = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        pixels = pixels.cpu().permute(0, 2, 3, 1)[0]

        def as_pil(values):
            return Image.fromarray(np.uint8(values)).convert("RGB").resize((64, 64))

        init_image = as_pil(pixels)
        mask_image = as_pil(pixels + 4)

        # "mps" devices seed the global generator; every other device gets its own Generator.
        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }

    def test_stable_diffusion_inpaint(self):
        """The bottom-right 3x3 patch of the last channel must match the stored reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipeline = StableDiffusionInpaintPipeline(**self.get_dummy_components())
        pipeline = pipeline.to(device)
        pipeline.set_progress_bar_config(disable=None)

        images = pipeline(**self.get_dummy_inputs(device)).images
        corner_patch = images[0, -3:, -3:, -1]

        assert images.shape == (1, 64, 64, 3)
        reference = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])

        assert np.abs(corner_patch.flatten() - reference).max() < 1e-2

    def test_stable_diffusion_inpaint_image_tensor(self):
        """PIL inputs and the same inputs converted to tensors must produce near-identical images."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipeline = StableDiffusionInpaintPipeline(**self.get_dummy_components())
        pipeline = pipeline.to(device)
        pipeline.set_progress_bar_config(disable=None)

        # Reference run with the PIL image and mask.
        pil_kwargs = self.get_dummy_inputs(device)
        out_pil = pipeline(**pil_kwargs).images

        # Second run: image rescaled with / 127.5 - 1, mask with / 255 and reduced to its first channel.
        tensor_kwargs = self.get_dummy_inputs(device)
        image_array = np.array(tensor_kwargs["image"])
        mask_array = np.array(tensor_kwargs["mask_image"])
        tensor_kwargs["image"] = torch.tensor(image_array / 127.5 - 1).permute(2, 0, 1).unsqueeze(0)
        tensor_kwargs["mask_image"] = torch.tensor(mask_array / 255).permute(2, 0, 1)[:1].unsqueeze(0)
        out_tensor = pipeline(**tensor_kwargs).images

        assert out_pil.shape == (1, 64, 64, 3)
        assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

    def test_stable_diffusion_inpaint_lora(self):
        """LoRA layers at scale 0.0 must leave the output unchanged; at scale 0.5 they must change it."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        # NOTE(review): the pipeline is moved to `torch_device` while the inputs are built for
        # `device` ("cpu") - confirm this asymmetry is intended.
        pipeline = StableDiffusionInpaintPipeline(**self.get_dummy_components())
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        def corner_patch(**extra_kwargs):
            # Fresh inputs for every forward pass so each run starts from the same seed.
            images = pipeline(**self.get_dummy_inputs(device), **extra_kwargs).images
            return images[0, -3:, -3:, -1]

        # forward 1
        baseline = corner_patch()

        # set lora layers
        lora_attn_procs = create_lora_layers(pipeline.unet)
        pipeline.unet.set_attn_processor(lora_attn_procs)
        pipeline = pipeline.to(torch_device)

        # forward 2
        lora_disabled = corner_patch(cross_attention_kwargs={"scale": 0.0})

        # forward 3
        lora_half = corner_patch(cross_attention_kwargs={"scale": 0.5})

        assert np.abs(baseline - lora_disabled).max() < 1e-2
        assert np.abs(baseline - lora_half).max() > 1e-2

    def test_inference_batch_single_identical(self):
        """Run the inherited batch-vs-single check with a maximum difference of 3e-3."""
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

234
235

@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
    """GPU tests that run the pretrained "runwayml/stable-diffusion-inpainting" checkpoint for 3 steps."""

    def setUp(self):
        super().setUp()

    def tearDown(self):
        super().tearDown()
        # Collect garbage first, then release cached CUDA memory.
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return the pipeline call kwargs built from the benchmark image and mask.

        `device` and `dtype` are accepted but not read by this helper; only `generator_device`
        and `seed` influence the result.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        bench_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        bench_mask = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        return {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": bench_image,
            "mask_image": bench_mask,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }

    def _load_pipeline(self, scheduler_cls=None, **from_pretrained_kwargs):
        """Load the checkpoint without safety checker, optionally swap the scheduler, move it to `torch_device`."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, **from_pretrained_kwargs
        )
        if scheduler_cls is not None:
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def test_stable_diffusion_inpaint_ddim(self):
        """Checkpoint's own scheduler, float32: centre patch must match the reference within 6e-4."""
        pipe = self._load_pipeline()
        pipe.enable_attention_slicing()

        images = pipe(**self.get_inputs(torch_device)).images
        center_patch = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794])

        assert np.abs(reference - center_patch).max() < 6e-4

    def test_stable_diffusion_inpaint_fp16(self):
        """float16 weights: centre patch must match the reference within 5e-2."""
        pipe = self._load_pipeline(torch_dtype=torch.float16)
        pipe.enable_attention_slicing()

        images = pipe(**self.get_inputs(torch_device, dtype=torch.float16)).images
        center_patch = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.1350, 0.1123, 0.1350, 0.1641, 0.1328, 0.1230, 0.1289, 0.1531, 0.1687])

        assert np.abs(reference - center_patch).max() < 5e-2

    def test_stable_diffusion_inpaint_pndm(self):
        """PNDM scheduler: centre patch must match the reference within 5e-3."""
        pipe = self._load_pipeline(PNDMScheduler)
        pipe.enable_attention_slicing()

        images = pipe(**self.get_inputs(torch_device)).images
        center_patch = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(reference - center_patch).max() < 5e-3

    def test_stable_diffusion_inpaint_k_lms(self):
        """LMS discrete scheduler: centre patch must match the reference within 6e-3."""
        pipe = self._load_pipeline(LMSDiscreteScheduler)
        pipe.enable_attention_slicing()

        images = pipe(**self.get_inputs(torch_device)).images
        center_patch = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633])

        assert np.abs(reference - center_patch).max() < 6e-3

    def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
        """With sequential CPU offload and attention slicing, peak CUDA memory must stay below 2.2 GB."""
        # Clear cached memory and the peak counters so the measurement only covers this test.
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = self._load_pipeline(torch_dtype=torch.float16)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        _ = pipe(**self.get_inputs(torch_device, dtype=torch.float16))

        peak_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert peak_bytes < 2.2 * 10**9

    @require_torch_2
    def test_inpaint_compile(self):
        """Run `_test_inpaint_compile` in a subprocess with these inputs."""
        seed = 0
        payload = self.get_inputs(torch_device, seed=seed)
        # Can't pickle a Generator object
        del payload["generator"]
        payload["torch_device"] = torch_device
        payload["seed"] = seed
        run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=payload)

    def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
        """127x127 PIL inputs with height/width set to 128 must yield a 128x128 output."""
        pipe = self._load_pipeline(LMSDiscreteScheduler)
        pipe.enable_attention_slicing()

        call_kwargs = self.get_inputs(torch_device)
        # change input image to a random size (one that would cause a tensor mismatch error)
        odd_size = (127, 127)
        call_kwargs["image"] = call_kwargs["image"].resize(odd_size)
        call_kwargs["mask_image"] = call_kwargs["mask_image"].resize(odd_size)
        call_kwargs["height"] = 128
        call_kwargs["width"] = 128
        images = pipe(**call_kwargs).images
        # verify that the returned image has the same height and width as the input height and width
        assert images.shape == (1, call_kwargs["height"], call_kwargs["width"], 3)

    def test_stable_diffusion_inpaint_strength_test(self):
        """strength=0.75 with the LMS scheduler: centre patch must match the reference within 3e-3."""
        pipe = self._load_pipeline(LMSDiscreteScheduler)
        pipe.enable_attention_slicing()

        call_kwargs = self.get_inputs(torch_device)
        # change input strength
        call_kwargs["strength"] = 0.75
        images = pipe(**call_kwargs).images
        # verify that the returned image has the same height and width as the input height and width
        assert images.shape == (1, 512, 512, 3)

        center_patch = images[0, 253:256, 253:256, -1].flatten()
        reference = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943])
        assert np.abs(reference - center_patch).max() < 3e-3

406

407
408
409
410
411
412
413
@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
    """Nightly GPU tests comparing full pipeline outputs against stored reference arrays."""

    def tearDown(self):
        super().tearDown()
        # Collect garbage first, then release cached CUDA memory.
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return the pipeline call kwargs (50 inference steps) built from the benchmark image and mask.

        `device` and `dtype` are accepted but not read by this helper.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        bench_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        bench_mask = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        return {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": bench_image,
            "mask_image": bench_mask,
            "generator": generator,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }

    def _assert_matches_reference(self, sd_pipe, call_kwargs, reference_url):
        """Run `sd_pipe` on `torch_device` and require a max abs difference below 1e-3 to the reference array."""
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        image = sd_pipe(**call_kwargs).images[0]

        expected_image = load_numpy(reference_url)
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_ddim(self):
        """Scheduler as shipped with the checkpoint, 50 steps."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        self._assert_matches_reference(
            sd_pipe,
            self.get_inputs(torch_device),
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_ddim.npy",
        )

    def test_inpaint_pndm(self):
        """PNDM scheduler, 50 steps."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
        self._assert_matches_reference(
            sd_pipe,
            self.get_inputs(torch_device),
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_pndm.npy",
        )

    def test_inpaint_lms(self):
        """LMS discrete scheduler, 50 steps."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        self._assert_matches_reference(
            sd_pipe,
            self.get_inputs(torch_device),
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_lms.npy",
        )

    def test_inpaint_dpm(self):
        """DPM-Solver multistep scheduler, with the step count lowered to 30."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        call_kwargs = self.get_inputs(torch_device)
        call_kwargs["num_inference_steps"] = 30
        self._assert_matches_reference(
            sd_pipe,
            call_kwargs,
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_dpm_multi.npy",
        )
499

Patrick von Platen's avatar
Patrick von Platen committed
500

501
502
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    def test_pil_inputs(self):
        """Check type, rank, shape, dtype and value range of the tensors produced from PIL inputs.

        The mask is expected within [0, 1]; the masked image and the image within [-1, 1].
        """
        height, width = 32, 32
        im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)

        self.assertIsInstance(t_mask, torch.Tensor)
        self.assertIsInstance(t_masked, torch.Tensor)
        self.assertIsInstance(t_image, torch.Tensor)

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)
        self.assertEqual(t_image.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, height, width))
        self.assertEqual(t_masked.shape, (1, 3, height, width))
        self.assertEqual(t_image.shape, (1, 3, height, width))

        self.assertEqual(t_mask.dtype, torch.float32)
        self.assertEqual(t_masked.dtype, torch.float32)
        self.assertEqual(t_image.dtype, torch.float32)

        # Lower bounds use min(), upper bounds use max(). The upper bounds were previously
        # written against min(), so values above 1.0 could not be detected.
        self.assertGreaterEqual(t_mask.min().item(), 0.0)
        self.assertLessEqual(t_mask.max().item(), 1.0)
        self.assertGreaterEqual(t_masked.min().item(), -1.0)
        self.assertLessEqual(t_masked.max().item(), 1.0)
        self.assertGreaterEqual(t_image.min().item(), -1.0)
        self.assertLessEqual(t_image.max().item(), 1.0)

        # The random mask must select at least one pixel.
        self.assertGreater(t_mask.sum().item(), 0.0)

    def test_np_inputs(self):
        """NumPy image and mask inputs must give exactly the same tensors as the equivalent PIL inputs."""
        height, width = 32, 32

        rgb_array = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
        rgb_pil = Image.fromarray(rgb_array)
        bool_mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((bool_mask * 255).astype(np.uint8))

        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            rgb_array, bool_mask, height, width, return_image=True
        )
        pil_mask, pil_masked, pil_image = prepare_mask_and_masked_image(
            rgb_pil, mask_pil, height, width, return_image=True
        )

        # Compare mask, masked image and image pairwise.
        for from_np, from_pil in ((np_mask, pil_mask), (np_masked, pil_masked), (np_image, pil_image)):
            self.assertTrue((from_np == from_pil).all())
565
566

    def test_torch_3D_2D_inputs(self):
        """A (3, H, W) image tensor with an (H, W) mask tensor must match the equivalent NumPy inputs."""
        height, width = 32, 32

        image_u8 = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
        mask_bool = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

        # NumPy counterparts: channels-last uint8 image and the same boolean mask.
        image_hwc = image_u8.numpy().transpose(1, 2, 0)
        mask_hw = mask_bool.numpy()

        # The tensor image is rescaled with / 127.5 - 1 before it is passed in.
        pt_mask, pt_masked, pt_image = prepare_mask_and_masked_image(
            image_u8 / 127.5 - 1, mask_bool, height, width, return_image=True
        )
        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            image_hwc, mask_hw, height, width, return_image=True
        )

        for from_tensor, from_np in ((pt_mask, np_mask), (pt_masked, np_masked), (pt_image, np_image)):
            self.assertTrue((from_tensor == from_np).all())
604
605

    def test_torch_3D_3D_inputs(self):
        """A (3, H, W) image tensor with a (1, H, W) mask tensor must match the equivalent NumPy inputs."""
        height, width = 32, 32

        image_u8 = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
        mask_bool = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

        # NumPy counterparts: channels-last uint8 image and the mask without its leading axis.
        image_hwc = image_u8.numpy().transpose(1, 2, 0)
        mask_hw = mask_bool.numpy()[0]

        # The tensor image is rescaled with / 127.5 - 1 before it is passed in.
        pt_mask, pt_masked, pt_image = prepare_mask_and_masked_image(
            image_u8 / 127.5 - 1, mask_bool, height, width, return_image=True
        )
        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            image_hwc, mask_hw, height, width, return_image=True
        )

        for from_tensor, from_np in ((pt_mask, np_mask), (pt_masked, np_masked), (pt_image, np_image)):
            self.assertTrue((from_tensor == from_np).all())
644
645

    def test_torch_4D_2D_inputs(self):
        """A (1, 3, H, W) image tensor with an (H, W) mask tensor must match the equivalent NumPy inputs."""
        height, width = 32, 32

        image_u8 = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_bool = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

        # NumPy counterparts: first batch item as channels-last uint8 image, and the same mask.
        image_hwc = image_u8.numpy()[0].transpose(1, 2, 0)
        mask_hw = mask_bool.numpy()

        # The tensor image is rescaled with / 127.5 - 1 before it is passed in.
        pt_mask, pt_masked, pt_image = prepare_mask_and_masked_image(
            image_u8 / 127.5 - 1, mask_bool, height, width, return_image=True
        )
        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            image_hwc, mask_hw, height, width, return_image=True
        )

        for from_tensor, from_np in ((pt_mask, np_mask), (pt_masked, np_masked), (pt_image, np_image)):
            self.assertTrue((from_tensor == from_np).all())
684
685

    def test_torch_4D_3D_inputs(self):
        """A (1, 3, H, W) image tensor with a (1, H, W) mask tensor must match the equivalent NumPy inputs."""
        height, width = 32, 32

        image_u8 = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_bool = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

        # NumPy counterparts: first batch item as channels-last uint8 image, mask without its leading axis.
        image_hwc = image_u8.numpy()[0].transpose(1, 2, 0)
        mask_hw = mask_bool.numpy()[0]

        # The tensor image is rescaled with / 127.5 - 1 before it is passed in.
        pt_mask, pt_masked, pt_image = prepare_mask_and_masked_image(
            image_u8 / 127.5 - 1, mask_bool, height, width, return_image=True
        )
        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            image_hwc, mask_hw, height, width, return_image=True
        )

        for from_tensor, from_np in ((pt_mask, np_mask), (pt_masked, np_masked), (pt_image, np_image)):
            self.assertTrue((from_tensor == from_np).all())
725
726

    def test_torch_4D_4D_inputs(self):
        """A (1, 3, H, W) image tensor with a (1, 1, H, W) mask tensor must match the equivalent NumPy inputs."""
        height, width = 32, 32

        image_u8 = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_bool = torch.randint(0, 255, (1, 1, height, width), dtype=torch.uint8) > 127.5

        # NumPy counterparts: first batch item as channels-last uint8 image, mask without its two leading axes.
        image_hwc = image_u8.numpy()[0].transpose(1, 2, 0)
        mask_hw = mask_bool.numpy()[0][0]

        # The tensor image is rescaled with / 127.5 - 1 before it is passed in.
        pt_mask, pt_masked, pt_image = prepare_mask_and_masked_image(
            image_u8 / 127.5 - 1, mask_bool, height, width, return_image=True
        )
        np_mask, np_masked, np_image = prepare_mask_and_masked_image(
            image_hwc, mask_hw, height, width, return_image=True
        )

        for from_tensor, from_np in ((pt_mask, np_mask), (pt_masked, np_masked), (pt_image, np_image)):
            self.assertTrue((from_tensor == from_np).all())
767
768

    def test_torch_batch_4D_3D(self):
        """Compare a batch of two images with a 3D mask batch against per-sample numpy inputs.

        The tensor call receives the whole batch at once. The numpy reference calls
        `prepare_mask_and_masked_image` once per sample and concatenates the results, which must
        be element-wise equal to the batched tensor outputs.
        """
        height, width = 32, 32

        # NOTE: the upper bound of `torch.randint` is exclusive, so raw values lie in [0, 254].
        image_pt = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
        mask_pt = torch.randint(0, 255, (2, height, width), dtype=torch.uint8) > 127.5

        # One (height-width-channel image, 2D mask) numpy pair per batch entry.
        samples_np = [(img.numpy().transpose(1, 2, 0), msk.numpy()) for img, msk in zip(image_pt, mask_pt)]

        mask_a, masked_a, image_a = prepare_mask_and_masked_image(
            image_pt / 127.5 - 1, mask_pt, height, width, return_image=True
        )

        per_sample = [
            prepare_mask_and_masked_image(img, msk, height, width, return_image=True) for img, msk in samples_np
        ]
        # Regroup the per-sample results by output position and concatenate each group.
        mask_b, masked_b, image_b = (torch.cat([out[idx] for out in per_sample]) for idx in range(3))

        for from_pt, from_np in ((mask_a, mask_b), (masked_a, masked_b), (image_a, image_b)):
            self.assertTrue((from_pt == from_np).all())

    def test_torch_batch_4D_4D(self):
        """Compare a batch of two images with a 4D mask batch against per-sample numpy inputs.

        The tensor call receives the whole batch at once. The numpy reference calls
        `prepare_mask_and_masked_image` once per sample (with the mask's channel dim removed) and
        concatenates the results, which must be element-wise equal to the batched tensor outputs.
        """
        height, width = 32, 32

        # NOTE: the upper bound of `torch.randint` is exclusive, so raw values lie in [0, 254].
        image_pt = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
        mask_pt = torch.randint(0, 255, (2, 1, height, width), dtype=torch.uint8) > 127.5

        # One (height-width-channel image, 2D mask) numpy pair per batch entry.
        samples_np = [(img.numpy().transpose(1, 2, 0), msk.numpy()[0]) for img, msk in zip(image_pt, mask_pt)]

        mask_a, masked_a, image_a = prepare_mask_and_masked_image(
            image_pt / 127.5 - 1, mask_pt, height, width, return_image=True
        )

        per_sample = [
            prepare_mask_and_masked_image(img, msk, height, width, return_image=True) for img, msk in samples_np
        ]
        # Regroup the per-sample results by output position and concatenate each group.
        mask_b, masked_b, image_b = (torch.cat([out[idx] for out in per_sample]) for idx in range(3))

        for from_pt, from_np in ((mask_a, mask_b), (masked_a, masked_b), (image_a, image_b)):
            self.assertTrue((from_pt == from_np).all())

    def test_shape_mismatch(self):
        """Image/mask tensor pairs with incompatible shapes are expected to raise `AssertionError`."""
        height, width = 32, 32

        mismatched_pairs = [
            # test height and width
            (torch.randn(3, height, width), torch.randn(64, 64)),
            # test batch dim (3D mask)
            # NOTE(review): the masks in the two batch cases are 64x64 while the images are 32x32,
            # so the error may come from a spatial check rather than a batch check -- confirm.
            (torch.randn(2, 3, height, width), torch.randn(4, 64, 64)),
            # test batch dim (4D mask)
            (torch.randn(2, 3, height, width), torch.randn(4, 1, 64, 64)),
        ]

        for image, mask in mismatched_pairs:
            with self.assertRaises(AssertionError):
                prepare_mask_and_masked_image(image, mask, height, width, return_image=True)

    def test_type_mismatch(self):
        """Mixing a torch tensor with a numpy array is expected to raise `TypeError`."""
        height, width = 32, 32

        mixed_type_pairs = [
            # test tensors-only: tensor image with numpy mask
            (torch.rand(3, height, width), torch.rand(3, height, width).numpy()),
            # test tensors-only: numpy image with tensor mask
            (torch.rand(3, height, width).numpy(), torch.rand(3, height, width)),
        ]

        for image, mask in mixed_type_pairs:
            with self.assertRaises(TypeError):
                prepare_mask_and_masked_image(image, mask, height, width, return_image=True)

    def test_channels_first(self):
        """A 3D image tensor in height-width-channel layout is expected to raise `AssertionError`."""
        height, width = 32, 32

        # test channels first for 3D tensors: the image below has its channel dim last
        channels_last_image = torch.rand(height, width, 3)
        mask = torch.rand(3, height, width)

        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(channels_last_image, mask, height, width, return_image=True)

    def test_tensor_range(self):
        """Tensor inputs holding out-of-range values are expected to raise `ValueError`.

        Each case keeps one input random (`torch.rand`, i.e. values in [0, 1)) and fills the
        other with a constant outside the bound named in the case comment.
        """
        height, width = 32, 32

        out_of_range_pairs = [
            # test im <= 1
            (torch.ones(3, height, width) * 2, torch.rand(height, width)),
            # test im >= -1
            (torch.ones(3, height, width) * (-2), torch.rand(height, width)),
            # test mask <= 1
            (torch.rand(3, height, width), torch.ones(height, width) * 2),
            # test mask >= 0
            (torch.rand(3, height, width), torch.ones(height, width) * -1),
        ]

        for image, mask in out_of_range_pairs:
            with self.assertRaises(ValueError):
                prepare_mask_and_masked_image(image, mask, height, width, return_image=True)