"...text-generation-inference.git" did not exist on "15b3e9ffb0ce1506c455f060d56f1f196baf49f3"
Unverified Commit 614d0c64 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

remove the deprecated prepare_mask_and_masked_image function (#8512)



remove prepare mask fn
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent b1a2c0d5
...@@ -118,129 +118,6 @@ def retrieve_latents( ...@@ -118,129 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output") raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
class StableDiffusionControlNetInpaintPipeline( class StableDiffusionControlNetInpaintPipeline(
DiffusionPipeline, DiffusionPipeline,
StableDiffusionMixin, StableDiffusionMixin,
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from packaging import version from packaging import version
...@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker ...@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents( def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
......
...@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width): ...@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width):
return mask return mask
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
# checkpoint. TOD(Yiyi) - need to clean this up later
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
mask = mask_pil_to_torch(mask, height, width)
if image.ndim == 3:
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
# assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
# if image.min() < -1 or image.max() > 1:
# raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = mask_pil_to_torch(mask, height, width)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
if image.shape[1] == 4:
# images are in latent space and thus can't
# be masked set masked_image to None
# we assume that the checkpoint is not an inpainting
# checkpoint. TOD(Yiyi) - need to clean this up later
masked_image = None
else:
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents( def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
......
...@@ -36,7 +36,6 @@ from diffusers import ( ...@@ -36,7 +36,6 @@ from diffusers import (
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
floats_tensor, floats_tensor,
...@@ -1105,530 +1104,3 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase): ...@@ -1105,530 +1104,3 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
) )
max_diff = np.abs(expected_image - image).max() max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3 assert max_diff < 1e-3
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self):
height, width = 32, 32
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im = Image.fromarray(im)
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
mask = Image.fromarray((mask * 255).astype(np.uint8))
t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
self.assertTrue(isinstance(t_mask, torch.Tensor))
self.assertTrue(isinstance(t_masked, torch.Tensor))
self.assertTrue(isinstance(t_image, torch.Tensor))
self.assertEqual(t_mask.ndim, 4)
self.assertEqual(t_masked.ndim, 4)
self.assertEqual(t_image.ndim, 4)
self.assertEqual(t_mask.shape, (1, 1, height, width))
self.assertEqual(t_masked.shape, (1, 3, height, width))
self.assertEqual(t_image.shape, (1, 3, height, width))
self.assertTrue(t_mask.dtype == torch.float32)
self.assertTrue(t_masked.dtype == torch.float32)
self.assertTrue(t_image.dtype == torch.float32)
self.assertTrue(t_mask.min() >= 0.0)
self.assertTrue(t_mask.max() <= 1.0)
self.assertTrue(t_masked.min() >= -1.0)
self.assertTrue(t_masked.min() <= 1.0)
self.assertTrue(t_image.min() >= -1.0)
self.assertTrue(t_image.min() >= -1.0)
self.assertTrue(t_mask.sum() > 0.0)
def test_np_inputs(self):
height, width = 32, 32
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np)
mask_np = (
np.random.randint(
0,
255,
(
height,
width,
),
dtype=np.uint8,
)
> 127.5
)
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
im_pil, mask_pil, height, width, return_image=True
)
self.assertTrue((t_mask_np == t_mask_pil).all())
self.assertTrue((t_masked_np == t_masked_pil).all())
self.assertTrue((t_image_np == t_image_pil).all())
def test_torch_3D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_3D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_4D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_batch_4D_3D(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
t_image_np = torch.cat([n[2] for n in nps])
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_batch_4D_4D(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
t_image_np = torch.cat([n[2] for n in nps])
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_shape_mismatch(self):
height, width = 32, 32
# test height and width
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
3,
height,
width,
),
torch.randn(64, 64),
height,
width,
return_image=True,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 64, 64),
height,
width,
return_image=True,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 1, 64, 64),
height,
width,
return_image=True,
)
def test_type_mismatch(self):
height, width = 32, 32
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.rand(
3,
height,
width,
).numpy(),
height,
width,
return_image=True,
)
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
).numpy(),
torch.rand(
3,
height,
width,
),
height,
width,
return_image=True,
)
def test_channels_first(self):
height, width = 32, 32
# test channels first for 3D tensors
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.rand(height, width, 3),
torch.rand(
3,
height,
width,
),
height,
width,
return_image=True,
)
def test_tensor_range(self):
height, width = 32, 32
# test im <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* 2,
torch.rand(
height,
width,
),
height,
width,
return_image=True,
)
# test im >= -1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* (-2),
torch.rand(
height,
width,
),
height,
width,
return_image=True,
)
# test mask <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* 2,
height,
width,
return_image=True,
)
# test mask >= 0
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* -1,
height,
width,
return_image=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment