Commit ab1f01e6 authored by Patrick von Platen
Browse files

make style

parent 2b31740d
...@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask): def prepare_mask_and_masked_image(image, mask):
""" """
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
This means that those inputs will be converted to ``torch.Tensor`` with converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for ``image`` and ``1`` for the ``mask``.
the ``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
``torch.float32`` too.
Args: Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
or a ``channels x height x width`` ``torch.Tensor`` or a ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint. mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
a ``1 x height x width`` ``torch.Tensor`` or a ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
``batch x 1 x height x width`` ``torch.Tensor``.
Raises: Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(or the other way around). (or the other way around).
......
...@@ -29,10 +29,8 @@ from diffusers import ( ...@@ -29,10 +29,8 @@ from diffusers import (
UNet2DModel, UNet2DModel,
VQModel, VQModel,
) )
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
...@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ...@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated # make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9 assert mem_bytes < 2.2 * 10**9
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self): def test_pil_inputs(self):
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
...@@ -676,4 +675,4 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase) ...@@ -676,4 +675,4 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
# test mask >= 0 # test mask >= 0
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment