Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ab1f01e6
Commit
ab1f01e6
authored
Nov 20, 2022
by
Patrick von Platen
Browse files
make style
parent
2b31740d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
20 deletions
+14
-20
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+11
-16
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+3
-4
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
...
@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def
prepare_mask_and_masked_image
(
image
,
mask
):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
This means that those inputs will be converted to ``torch.Tensor`` with
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
the ``image`` and ``1`` for the ``mask``.
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
``torch.float32`` too.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
or a ``channels x height x width`` ``torch.Tensor`` or a
``batch x channels x height x width`` ``torch.Tensor``.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
a ``1 x height x width`` ``torch.Tensor`` or a
``batch x 1 x height x width`` ``torch.Tensor``.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
...
@@ -29,10 +29,8 @@ from diffusers import (
UNet2DModel
,
VQModel
,
)
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint
import
prepare_mask_and_masked_image
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
from
PIL
import
Image
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
...
...
@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated
assert
mem_bytes
<
2.2
*
10
**
9
class
StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
(
unittest
.
TestCase
):
def
test_pil_inputs
(
self
):
im
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment