Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a8d09777
Unverified
Commit
a8d09777
authored
Nov 14, 2022
by
Suraj Patil
Committed by
GitHub
Nov 14, 2022
Browse files
[StableDiffusionInpaintPipeline] fix batch_size for mask and masked latents (#1279)
fix bs for mask and masked latents
parent
c9b34637
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
1 deletion
+42
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-1
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
...pelines/stable_diffusion/test_stable_diffusion_inpaint.py
+41
-0
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
a8d09777
...
@@ -536,7 +536,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -536,7 +536,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
mask
,
masked_image_latents
=
self
.
prepare_mask_latents
(
mask
,
masked_image_latents
=
self
.
prepare_mask_latents
(
mask
,
mask
,
masked_image
,
masked_image
,
batch_size
,
batch_size
*
num_images_per_prompt
,
height
,
height
,
width
,
width
,
text_embeddings
.
dtype
,
text_embeddings
.
dtype
,
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
View file @
a8d09777
...
@@ -215,6 +215,47 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
...
@@ -215,6 +215,47 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_inpaint_with_num_images_per_prompt
(
self
):
device
=
"cpu"
unet
=
self
.
dummy_cond_unet_inpaint
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
image
=
self
.
dummy_image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)[
0
]
init_image
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
"RGB"
).
resize
((
128
,
128
))
mask_image
=
Image
.
fromarray
(
np
.
uint8
(
image
+
4
)).
convert
(
"RGB"
).
resize
((
128
,
128
))
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionInpaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
None
,
feature_extractor
=
None
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
images
=
sd_pipe
(
[
prompt
],
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
image
=
init_image
,
mask_image
=
mask_image
,
num_images_per_prompt
=
2
,
).
images
# check if the output is a list of 2 images
assert
len
(
images
)
==
2
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
"This test requires a GPU"
)
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
"This test requires a GPU"
)
def
test_stable_diffusion_inpaint_fp16
(
self
):
def
test_stable_diffusion_inpaint_fp16
(
self
):
"""Test that stable diffusion inpaint_legacy works with fp16"""
"""Test that stable diffusion inpaint_legacy works with fp16"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment