Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ab1f01e6
Commit
ab1f01e6
authored
Nov 20, 2022
by
Patrick von Platen
Browse files
make style
parent
2b31740d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
20 deletions
+14
-20
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+11
-16
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+3
-4
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...
@@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def
prepare_mask_and_masked_image
(
image
,
mask
):
def
prepare_mask_and_masked_image
(
image
,
mask
):
"""
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
This means that those inputs will be converted to ``torch.Tensor`` with
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
``image`` and ``1`` for the ``mask``.
the ``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
``torch.float32`` too.
Args:
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
or a ``channels x height x width`` ``torch.Tensor`` or a
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
a ``1 x height x width`` ``torch.Tensor`` or a
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
(ot the other way around).
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
ab1f01e6
...
@@ -29,10 +29,8 @@ from diffusers import (
...
@@ -29,10 +29,8 @@ from diffusers import (
UNet2DModel
,
UNet2DModel
,
VQModel
,
VQModel
,
)
)
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint
import
prepare_mask_and_masked_image
from
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint
import
prepare_mask_and_masked_image
from
diffusers.utils
import
floats_tensor
,
load_image
,
load_numpy
,
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
from
diffusers.utils.testing_utils
import
require_torch_gpu
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
CLIPTokenizer
...
@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
...
@@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated
# make sure that less than 2.2 GB is allocated
assert
mem_bytes
<
2.2
*
10
**
9
assert
mem_bytes
<
2.2
*
10
**
9
class
StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
(
unittest
.
TestCase
):
class
StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
(
unittest
.
TestCase
):
def
test_pil_inputs
(
self
):
def
test_pil_inputs
(
self
):
im
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
im
=
np
.
random
.
randint
(
0
,
255
,
(
32
,
32
,
3
),
dtype
=
np
.
uint8
)
...
@@ -676,4 +675,4 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
...
@@ -676,4 +675,4 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
prepare_mask_and_masked_image
(
torch
.
rand
(
3
,
32
,
32
),
torch
.
ones
(
32
,
32
)
*
2
)
prepare_mask_and_masked_image
(
torch
.
rand
(
3
,
32
,
32
),
torch
.
ones
(
32
,
32
)
*
2
)
# test mask >= 0
# test mask >= 0
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
prepare_mask_and_masked_image
(
torch
.
rand
(
3
,
32
,
32
),
torch
.
ones
(
32
,
32
)
*
-
1
)
prepare_mask_and_masked_image
(
torch
.
rand
(
3
,
32
,
32
),
torch
.
ones
(
32
,
32
)
*
-
1
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment