Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
513fc681
Unverified
Commit
513fc681
authored
Dec 05, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 05, 2022
Browse files
[Stable Diffusion Inpaint] Allow tensor as input image & mask (#1527)
up
parent
cc22bda5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
5 deletions
+59
-5
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-5
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+58
-0
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
513fc681
...
...
@@ -630,11 +630,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
)
# 4. Preprocess mask and image
if
isinstance
(
image
,
PIL
.
Image
.
Image
)
and
isinstance
(
mask_image
,
PIL
.
Image
.
Image
):
mask
,
masked_image
=
prepare_mask_and_masked_image
(
image
,
mask_image
)
else
:
mask
=
mask_image
masked_image
=
image
*
(
mask
<
0.5
)
mask
,
masked_image
=
prepare_mask_and_masked_image
(
image
,
mask_image
)
# 5. set timesteps
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
device
=
device
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
513fc681
...
...
@@ -218,6 +218,64 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_inpaint_image_tensor
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet_inpaint
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
image
=
self
.
dummy_image
.
repeat
(
1
,
1
,
2
,
2
)
mask_image
=
image
/
2
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionInpaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
None
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
[
prompt
],
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
image
=
image
,
mask_image
=
mask_image
[:,
0
],
)
out_1
=
output
.
images
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)[
0
]
mask_image
=
mask_image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)[
0
]
image
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
"RGB"
)
mask_image
=
Image
.
fromarray
(
np
.
uint8
(
mask_image
)).
convert
(
"RGB"
)
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
[
prompt
],
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
image
=
image
,
mask_image
=
mask_image
,
)
out_2
=
output
.
images
assert
out_1
.
shape
==
(
1
,
64
,
64
,
3
)
assert
np
.
abs
(
out_1
.
flatten
()
-
out_2
.
flatten
()).
max
()
<
5e-2
def
test_stable_diffusion_inpaint_with_num_images_per_prompt
(
self
):
device
=
"cpu"
unet
=
self
.
dummy_cond_unet_inpaint
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment