Unverified Commit edb087a2 authored by Rupert Menneer's avatar Rupert Menneer Committed by GitHub
Browse files

StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322)



* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests

Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests

this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 94a0c644
...@@ -36,7 +36,7 @@ from .safety_checker import StableDiffusionSafetyChecker ...@@ -36,7 +36,7 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask): def prepare_mask_and_masked_image(image, mask, height, width):
""" """
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
...@@ -64,6 +64,13 @@ def prepare_mask_and_masked_image(image, mask): ...@@ -64,6 +64,13 @@ def prepare_mask_and_masked_image(image, mask):
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``. dimensions: ``batch x channels x height x width``.
""" """
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor): if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
...@@ -111,8 +118,9 @@ def prepare_mask_and_masked_image(image, mask): ...@@ -111,8 +118,9 @@ def prepare_mask_and_masked_image(image, mask):
# preprocess image # preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)): if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image] image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0) image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray): elif isinstance(image, list) and isinstance(image[0], np.ndarray):
...@@ -126,6 +134,7 @@ def prepare_mask_and_masked_image(image, mask): ...@@ -126,6 +134,7 @@ def prepare_mask_and_masked_image(image, mask):
mask = [mask] mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0 mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
...@@ -799,12 +808,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -799,12 +808,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt_embeds, negative_prompt_embeds,
) )
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask_image is None:
raise ValueError("`mask_image` input cannot be undefined.")
# 2. Define call parameters # 2. Define call parameters
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -830,8 +833,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -830,8 +833,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
# 4. Preprocess mask and image # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image = prepare_mask_and_masked_image(image, mask_image) mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
# 5. set timesteps # 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
......
...@@ -303,6 +303,25 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): ...@@ -303,6 +303,25 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
assert np.abs(expected_slice - image_slice).max() < 1e-4 assert np.abs(expected_slice - image_slice).max() < 1e-4
assert np.abs(expected_slice - image_slice).max() < 1e-3 assert np.abs(expected_slice - image_slice).max() < 1e-3
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
# change input image to a random size (one that would cause a tensor mismatch error)
inputs['image'] = inputs['image'].resize((127,127))
inputs['mask_image'] = inputs['mask_image'].resize((127,127))
inputs['height'] = 128
inputs['width'] = 128
image = pipe(**inputs).images
# verify that the returned image has the same height and width as the input height and width
assert image.shape == (1, inputs['height'], inputs['width'], 3)
@nightly @nightly
@require_torch_gpu @require_torch_gpu
...@@ -400,12 +419,13 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase): ...@@ -400,12 +419,13 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self): def test_pil_inputs(self):
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) height, width = 32, 32
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im = Image.fromarray(im) im = Image.fromarray(im)
mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
mask = Image.fromarray((mask * 255).astype(np.uint8)) mask = Image.fromarray((mask * 255).astype(np.uint8))
t_mask, t_masked = prepare_mask_and_masked_image(im, mask) t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)
self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_mask, torch.Tensor))
self.assertTrue(isinstance(t_masked, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor))
...@@ -413,8 +433,8 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) ...@@ -413,8 +433,8 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_mask.ndim, 4)
self.assertEqual(t_masked.ndim, 4) self.assertEqual(t_masked.ndim, 4)
self.assertEqual(t_mask.shape, (1, 1, 32, 32)) self.assertEqual(t_mask.shape, (1, 1, height, width))
self.assertEqual(t_masked.shape, (1, 3, 32, 32)) self.assertEqual(t_masked.shape, (1, 3, height, width))
self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_mask.dtype == torch.float32)
self.assertTrue(t_masked.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32)
...@@ -427,86 +447,100 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) ...@@ -427,86 +447,100 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
self.assertTrue(t_mask.sum() > 0.0) self.assertTrue(t_mask.sum() > 0.0)
def test_np_inputs(self): def test_np_inputs(self):
im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) height, width = 32, 32
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np) im_pil = Image.fromarray(im_np)
mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)
self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_mask_np == t_mask_pil).all())
self.assertTrue((t_masked_np == t_masked_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all())
def test_torch_3D_2D_inputs(self): def test_torch_3D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0) im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy() mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_3D_3D_inputs(self): def test_torch_3D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0) im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0] mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_4D_2D_inputs(self): def test_torch_4D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0) im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy() mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_4D_3D_inputs(self): def test_torch_4D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0) im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0] mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_4D_4D_inputs(self): def test_torch_4D_4D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0) im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0] mask_np = mask_tensor.numpy()[0][0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_batch_4D_3D(self): def test_torch_batch_4D_3D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor] mask_nps = [mask.numpy() for mask in mask_tensor]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps]) t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps])
...@@ -514,14 +548,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) ...@@ -514,14 +548,16 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_torch_batch_4D_4D(self): def test_torch_batch_4D_4D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) height, width = 32, 32
mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps]) t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps])
...@@ -529,39 +565,47 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) ...@@ -529,39 +565,47 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
self.assertTrue((t_masked_tensor == t_masked_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all())
def test_shape_mismatch(self): def test_shape_mismatch(self):
height, width = 32, 32
# test height and width # test height and width
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width)
# test batch dim # test batch dim
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width)
# test batch dim # test batch dim
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width)
def test_type_mismatch(self): def test_type_mismatch(self):
height, width = 32, 32
# test tensors-only # test tensors-only
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width)
# test tensors-only # test tensors-only
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width)
def test_channels_first(self): def test_channels_first(self):
height, width = 32, 32
# test channels first for 3D tensors # test channels first for 3D tensors
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width)
def test_tensor_range(self): def test_tensor_range(self):
height, width = 32, 32
# test im <= 1 # test im <= 1
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width)
# test im >= -1 # test im >= -1
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width)
# test mask <= 1 # test mask <= 1
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width)
# test mask >= 0 # test mask >= 0
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment