Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ab1f01e6
Commit
ab1f01e6
authored
Nov 20, 2022
by
Patrick von Platen
Browse files
make style
parent
2b31740d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
20 deletions
+14
-20
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+11
-16
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+3
-4
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...
@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def
prepare_mask_and_masked_image
(
image
,
mask
):
def
prepare_mask_and_masked_image
(
image
,
mask
):
"""
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
This means that those inputs will be converted to ``torch.Tensor`` with
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
``image`` and ``1`` for the ``mask``.
the ``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
``torch.float32`` too.
Args:
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
or a ``channels x height x width`` ``torch.Tensor`` or a
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
a ``1 x height x width`` ``torch.Tensor`` or a
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
(ot the other way around).
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
@@ -29,10 +29,8 @@ from diffusers import (
...
@@ -29,10 +29,8 @@ from diffusers import (
UNet2DModel
,
UNet2DModel
,
VQModel
,
VQModel
,
)
)
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint
import
prepare_mask_and_masked_image
from
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint
import
prepare_mask_and_masked_image
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
from
diffusers.utils.testing_utils
import
require_torch_gpu
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
...
@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated
# make sure that less than 2.2 GB is allocated
assert
mem_bytes
<
2.2
*
10
**
9
assert
mem_bytes
<
2.2
*
10
**
9
class
StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
(
unittest
.
TestCase
):
class
StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
(
unittest
.
TestCase
):
def
test_pil_inputs
(
self
):
def
test_pil_inputs
(
self
):
im
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
im
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment