"example/vscode:/vscode.git/clone" did not exist on "0d26477a864fcaf6f1dfe709c9f8421e5305b933"
test_stable_diffusion_inpaint.py 32.9 KB
Newer Older
1
# coding=utf-8
Patrick von Platen's avatar
Patrick von Platen committed
2
# Copyright 2023 HuggingFace Inc.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
22
from packaging import version
23
24
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
25
26
27

from diffusers import (
    AutoencoderKL,
28
    DPMSolverMultistepScheduler,
29
    LMSDiscreteScheduler,
30
31
32
33
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
34
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
35
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
36
from diffusers.utils.testing_utils import require_torch_gpu
37

38
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
39
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
40

41
42

torch.backends.cuda.matmul.allow_tf32 = False
43
torch.use_deterministic_algorithms(True)
44
45


46
class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for ``StableDiffusionInpaintPipeline`` built from tiny dummy components."""

    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    image_params = frozenset([])
    # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess

    def get_dummy_components(self):
        """Build a dict of small, seeded (deterministic) pipeline components."""
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,  # inpainting UNet: latent + masked-image latent + mask channels
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return standard call kwargs with a seeded 64x64 PIL image/mask pair."""
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        # MPS does not support device-bound torch.Generator objects.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint(self):
        """Dummy pipeline output slice must match the stored reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_image_tensor(self):
        """PIL and tensor image/mask inputs must produce (nearly) the same output."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs)
        out_pil = output.images

        inputs = self.get_dummy_inputs(device)
        # Equivalent tensor inputs: image normalized to [-1, 1], single-channel mask in [0, 1].
        inputs["image"] = torch.tensor(np.array(inputs["image"]) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0)
        inputs["mask_image"] = torch.tensor(np.array(inputs["mask_image"]) / 255).permute(2, 0, 1)[:1].unsqueeze(0)
        output = sd_pipe(**inputs)
        out_tensor = output.images

        assert out_pil.shape == (1, 64, 64, 3)
        assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

    def test_inference_batch_single_identical(self):
        # Run the shared mixin test with a loosened tolerance for this pipeline.
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

161
162

@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
    """Slow GPU integration tests against the full ``runwayml/stable-diffusion-inpainting`` checkpoint."""

    def setUp(self):
        super().setUp()

    def tearDown(self):
        # Release GPU memory between tests.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return standard call kwargs: bench image/mask plus a seeded generator.

        NOTE(review): ``device`` and ``dtype`` are accepted but unused in this body —
        presumably kept for signature symmetry with other pipeline test suites; confirm.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint_ddim(self):
        """Default-scheduler output slice must match the stored reference values."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794])

        assert np.abs(expected_slice - image_slice).max() < 6e-4

    def test_stable_diffusion_inpaint_fp16(self):
        """fp16 weights must reproduce the stored fp16 reference slice (loose tolerance)."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1350, 0.1123, 0.1350, 0.1641, 0.1328, 0.1230, 0.1289, 0.1531, 0.1687])

        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_inpaint_pndm(self):
        """PNDM scheduler output slice must match the stored reference values."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(expected_slice - image_slice).max() < 5e-3

    def test_stable_diffusion_inpaint_k_lms(self):
        """LMS discrete scheduler output slice must match the stored reference values."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633])

        assert np.abs(expected_slice - image_slice).max() < 6e-3

    def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
        """Sequential CPU offload must keep peak GPU memory under ~2.2 GB."""
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9

    def test_inpaint_compile(self):
        """torch.compile-d UNet must reproduce the PNDM reference slice (torch >= 2.0 only)."""
        if version.parse(torch.__version__) < version.parse("2.0"):
            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
            return

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(expected_slice - image_slice).max() < 3e-3

    def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
        """Odd-sized PIL inputs with explicit height/width must yield the requested size."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        # change input image to a random size (one that would cause a tensor mismatch error)
        inputs["image"] = inputs["image"].resize((127, 127))
        inputs["mask_image"] = inputs["mask_image"].resize((127, 127))
        inputs["height"] = 128
        inputs["width"] = 128
        image = pipe(**inputs).images
        # verify that the returned image has the same height and width as the input height and width
        assert image.shape == (1, inputs["height"], inputs["width"], 3)

    def test_stable_diffusion_inpaint_strength_test(self):
        """strength < 1.0 must still yield a full-resolution image matching the reference slice."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        # change input strength
        inputs["strength"] = 0.75
        image = pipe(**inputs).images
        # verify that the returned image has the same height and width as the input height and width
        assert image.shape == (1, 512, 512, 3)

        image_slice = image[0, 253:256, 253:256, -1].flatten()
        expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943])
        assert np.abs(expected_slice - image_slice).max() < 3e-3

347

348
349
350
351
352
353
354
@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
    """Nightly tests comparing whole output images against stored reference arrays."""

    def tearDown(self):
        # Release GPU memory between tests.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return standard call kwargs (50 inference steps for nightly-quality runs).

        NOTE(review): ``device`` and ``dtype`` are accepted but unused in this body —
        presumably kept for signature symmetry with other pipeline test suites; confirm.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_inpaint_ddim(self):
        """Default scheduler output must match the stored DDIM reference image."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_ddim.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_pndm(self):
        """PNDM scheduler output must match the stored reference image."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_pndm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_lms(self):
        """LMS discrete scheduler output must match the stored reference image."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_lms.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_dpm(self):
        """DPM-Solver multistep (30 steps) output must match the stored reference image."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 30
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_dpm_multi.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3
440

Patrick von Platen's avatar
Patrick von Platen committed
441

442
443
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    def test_pil_inputs(self):
        """PIL image + PIL mask are converted to normalized float tensors.

        Checks tensor types, 4D shapes, float32 dtype, and value ranges:
        mask in [0, 1], masked image and image in [-1, 1].
        """
        height, width = 32, 32
        im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)

        self.assertTrue(isinstance(t_mask, torch.Tensor))
        self.assertTrue(isinstance(t_masked, torch.Tensor))
        self.assertTrue(isinstance(t_image, torch.Tensor))

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)
        self.assertEqual(t_image.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, height, width))
        self.assertEqual(t_masked.shape, (1, 3, height, width))
        self.assertEqual(t_image.shape, (1, 3, height, width))

        self.assertTrue(t_mask.dtype == torch.float32)
        self.assertTrue(t_masked.dtype == torch.float32)
        self.assertTrue(t_image.dtype == torch.float32)

        self.assertTrue(t_mask.min() >= 0.0)
        self.assertTrue(t_mask.max() <= 1.0)
        self.assertTrue(t_masked.min() >= -1.0)
        # Fixed: the upper bound must check max(), not min() (copy-paste bug).
        self.assertTrue(t_masked.max() <= 1.0)
        self.assertTrue(t_image.min() >= -1.0)
        # Fixed: was a duplicated min() lower-bound check; the intended assertion
        # is the max() upper bound.
        self.assertTrue(t_image.max() <= 1.0)

        # The random mask should never be entirely empty.
        self.assertTrue(t_mask.sum() > 0.0)

    def test_np_inputs(self):
        """NumPy and PIL inputs must produce identical prepared tensors."""
        height, width = 32, 32

        image_arr = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
        image_pil = Image.fromarray(image_arr)
        mask_arr = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((mask_arr * 255).astype(np.uint8))

        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)
        from_pil = prepare_mask_and_masked_image(image_pil, mask_pil, height, width, return_image=True)

        # Compare the (mask, masked image, image) triples element-wise.
        for tensor_np, tensor_pil in zip(from_np, from_pil):
            self.assertTrue((tensor_np == tensor_pil).all())
506
507

    def test_torch_3D_2D_inputs(self):
        """A 3D image tensor with a 2D mask must match the NumPy pathway."""
        height, width = 32, 32

        image_t = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
        mask_t = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

        # Equivalent NumPy inputs (HWC image, 2D mask).
        image_arr = image_t.numpy().transpose(1, 2, 0)
        mask_arr = mask_t.numpy()

        from_tensor = prepare_mask_and_masked_image(
            image_t / 127.5 - 1, mask_t, height, width, return_image=True
        )
        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)

        for tensor_out, np_out in zip(from_tensor, from_np):
            self.assertTrue((tensor_out == np_out).all())
545
546

    def test_torch_3D_3D_inputs(self):
        """A 3D image tensor with a 3D (1, H, W) mask must match the NumPy pathway."""
        height, width = 32, 32

        image_t = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
        mask_t = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

        # Equivalent NumPy inputs (HWC image, 2D mask with the channel dim dropped).
        image_arr = image_t.numpy().transpose(1, 2, 0)
        mask_arr = mask_t.numpy()[0]

        from_tensor = prepare_mask_and_masked_image(
            image_t / 127.5 - 1, mask_t, height, width, return_image=True
        )
        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)

        for tensor_out, np_out in zip(from_tensor, from_np):
            self.assertTrue((tensor_out == np_out).all())
585
586

    def test_torch_4D_2D_inputs(self):
        """A 4D (1, 3, H, W) image tensor with a 2D mask must match the NumPy pathway."""
        height, width = 32, 32

        image_t = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_t = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5

        # Equivalent NumPy inputs (batch dim dropped, HWC layout).
        image_arr = image_t.numpy()[0].transpose(1, 2, 0)
        mask_arr = mask_t.numpy()

        from_tensor = prepare_mask_and_masked_image(
            image_t / 127.5 - 1, mask_t, height, width, return_image=True
        )
        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)

        for tensor_out, np_out in zip(from_tensor, from_np):
            self.assertTrue((tensor_out == np_out).all())
625
626

    def test_torch_4D_3D_inputs(self):
        """A 4D (1, 3, H, W) image tensor with a 3D (1, H, W) mask must match NumPy."""
        height, width = 32, 32

        image_t = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_t = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5

        # Equivalent NumPy inputs with the leading dims dropped.
        image_arr = image_t.numpy()[0].transpose(1, 2, 0)
        mask_arr = mask_t.numpy()[0]

        from_tensor = prepare_mask_and_masked_image(
            image_t / 127.5 - 1, mask_t, height, width, return_image=True
        )
        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)

        for tensor_out, np_out in zip(from_tensor, from_np):
            self.assertTrue((tensor_out == np_out).all())
666
667

    def test_torch_4D_4D_inputs(self):
        """A 4D image tensor with a 4D (1, 1, H, W) mask must match the NumPy pathway."""
        height, width = 32, 32

        image_t = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
        mask_t = torch.randint(0, 255, (1, 1, height, width), dtype=torch.uint8) > 127.5

        # Equivalent NumPy inputs with batch and channel dims dropped.
        image_arr = image_t.numpy()[0].transpose(1, 2, 0)
        mask_arr = mask_t.numpy()[0][0]

        from_tensor = prepare_mask_and_masked_image(
            image_t / 127.5 - 1, mask_t, height, width, return_image=True
        )
        from_np = prepare_mask_and_masked_image(image_arr, mask_arr, height, width, return_image=True)

        for tensor_out, np_out in zip(from_tensor, from_np):
            self.assertTrue((tensor_out == np_out).all())
708
709

    def test_torch_batch_4D_3D(self):
        """Batched (2, 3, H, W) / (2, H, W) tensors must equal per-sample NumPy results."""
        height, width = 32, 32

        image_batch = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
        mask_batch = torch.randint(0, 255, (2, height, width), dtype=torch.uint8) > 127.5

        # Per-sample NumPy equivalents (HWC images, 2D masks).
        image_arrs = [sample.numpy().transpose(1, 2, 0) for sample in image_batch]
        mask_arrs = [sample.numpy() for sample in mask_batch]

        batched = prepare_mask_and_masked_image(
            image_batch / 127.5 - 1, mask_batch, height, width, return_image=True
        )
        singles = [
            prepare_mask_and_masked_image(i, m, height, width, return_image=True)
            for i, m in zip(image_arrs, mask_arrs)
        ]
        # Re-assemble the per-sample results into batch tensors for comparison.
        stacked = [torch.cat([triple[idx] for triple in singles]) for idx in range(3)]

        for batched_out, np_out in zip(batched, stacked):
            self.assertTrue((batched_out == np_out).all())
751
752

    def test_torch_batch_4D_4D(self):
        """Batched (2, 3, H, W) / (2, 1, H, W) tensors must equal per-sample NumPy results."""
        height, width = 32, 32

        image_batch = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
        mask_batch = torch.randint(0, 255, (2, 1, height, width), dtype=torch.uint8) > 127.5

        # Per-sample NumPy equivalents (HWC images, 2D masks with channel dim dropped).
        image_arrs = [sample.numpy().transpose(1, 2, 0) for sample in image_batch]
        mask_arrs = [sample.numpy()[0] for sample in mask_batch]

        batched = prepare_mask_and_masked_image(
            image_batch / 127.5 - 1, mask_batch, height, width, return_image=True
        )
        singles = [
            prepare_mask_and_masked_image(i, m, height, width, return_image=True)
            for i, m in zip(image_arrs, mask_arrs)
        ]
        # Re-assemble the per-sample results into batch tensors for comparison.
        stacked = [torch.cat([triple[idx] for triple in singles]) for idx in range(3)]

        for batched_out, np_out in zip(batched, stacked):
            self.assertTrue((batched_out == np_out).all())
795
796

    def test_shape_mismatch(self):
        """Image/mask pairs whose spatial sizes or batch sizes disagree must raise
        AssertionError."""
        height, width = 32, 32

        mismatched_pairs = [
            # spatial size mismatch: 32x32 image vs 64x64 mask
            (torch.randn(3, height, width), torch.randn(64, 64)),
            # batch size mismatch with a 3D mask: 2 images vs 4 masks
            (torch.randn(2, 3, height, width), torch.randn(4, 64, 64)),
            # batch size mismatch with a 4D mask: 2 images vs 4 masks
            (torch.randn(2, 3, height, width), torch.randn(4, 1, 64, 64)),
        ]
        for image, mask in mismatched_pairs:
            with self.assertRaises(AssertionError):
                prepare_mask_and_masked_image(image, mask, height, width, return_image=True)

    def test_type_mismatch(self):
        """Mixing a torch tensor with a numpy array (in either position) must raise
        TypeError — both inputs have to be of the same kind."""
        height, width = 32, 32

        # tensor image + numpy mask
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(
                torch.rand(3, height, width),
                torch.rand(3, height, width).numpy(),
                height,
                width,
                return_image=True,
            )
        # numpy image + tensor mask
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(
                torch.rand(3, height, width).numpy(),
                torch.rand(3, height, width),
                height,
                width,
                return_image=True,
            )

    def test_channels_first(self):
        """A 3D image supplied channels-last (H, W, C) must raise AssertionError —
        tensors are expected channels-first."""
        height, width = 32, 32

        channels_last_image = torch.rand(height, width, 3)
        mask = torch.rand(3, height, width)
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(
                channels_last_image, mask, height, width, return_image=True
            )

    def test_tensor_range(self):
        """Out-of-range values must raise ValueError: images must lie in [-1, 1]
        and masks in [0, 1]."""
        height, width = 32, 32

        # image values above 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(
                torch.ones(3, height, width) * 2,
                torch.rand(height, width),
                height,
                width,
                return_image=True,
            )
        # image values below -1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(
                torch.ones(3, height, width) * (-2),
                torch.rand(height, width),
                height,
                width,
                return_image=True,
            )
        # mask values above 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(
                torch.rand(3, height, width),
                torch.ones(height, width) * 2,
                height,
                width,
                return_image=True,
            )
        # mask values below 0
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(
                torch.rand(3, height, width),
                torch.ones(height, width) * -1,
                height,
                width,
                return_image=True,
            )