Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
26c7df5d
Unverified
Commit
26c7df5d
authored
Oct 13, 2022
by
Anton Lozhkov
Committed by
GitHub
Oct 13, 2022
Browse files
Fix type mismatch error, add tests for negative prompts (#823)
parent
e001fede
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
138 additions
and
9 deletions
+138
-9
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+2
-2
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+4
-3
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+2
-2
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
...elines/stable_diffusion/pipeline_stable_diffusion_onnx.py
+2
-2
tests/test_pipelines.py
tests/test_pipelines.py
+128
-0
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
26c7df5d
...
@@ -234,8 +234,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
...
@@ -234,8 +234,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
uncond_tokens
=
[
""
]
uncond_tokens
=
[
""
]
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
raise
TypeError
(
raise
TypeError
(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f
"`negative_prompt` should be the same type to `prompt`, but got
{
type
(
negative_prompt
)
}
!="
" {type(prompt)}."
f
"
{
type
(
prompt
)
}
."
)
)
elif
isinstance
(
negative_prompt
,
str
):
elif
isinstance
(
negative_prompt
,
str
):
uncond_tokens
=
[
negative_prompt
]
uncond_tokens
=
[
negative_prompt
]
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
26c7df5d
...
@@ -195,7 +195,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -195,7 +195,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
"""
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
batch_size
=
1
batch_size
=
1
prompt
=
[
prompt
]
elif
isinstance
(
prompt
,
list
):
elif
isinstance
(
prompt
,
list
):
batch_size
=
len
(
prompt
)
batch_size
=
len
(
prompt
)
else
:
else
:
...
@@ -250,8 +249,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -250,8 +249,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
uncond_tokens
=
[
""
]
uncond_tokens
=
[
""
]
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
raise
TypeError
(
raise
TypeError
(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f
"`negative_prompt` should be the same type to `prompt`, but got
{
type
(
negative_prompt
)
}
!="
" {type(prompt)}."
f
"
{
type
(
prompt
)
}
."
)
)
elif
isinstance
(
negative_prompt
,
str
):
elif
isinstance
(
negative_prompt
,
str
):
uncond_tokens
=
[
negative_prompt
]
uncond_tokens
=
[
negative_prompt
]
...
@@ -285,6 +284,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -285,6 +284,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents
=
init_latent_dist
.
sample
(
generator
=
generator
)
init_latents
=
init_latent_dist
.
sample
(
generator
=
generator
)
init_latents
=
0.18215
*
init_latents
init_latents
=
0.18215
*
init_latents
if
isinstance
(
prompt
,
str
):
prompt
=
[
prompt
]
if
len
(
prompt
)
>
init_latents
.
shape
[
0
]
and
len
(
prompt
)
%
init_latents
.
shape
[
0
]
==
0
:
if
len
(
prompt
)
>
init_latents
.
shape
[
0
]
and
len
(
prompt
)
%
init_latents
.
shape
[
0
]
==
0
:
# expand init_latents for batch_size
# expand init_latents for batch_size
deprecation_message
=
(
deprecation_message
=
(
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
26c7df5d
...
@@ -266,8 +266,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -266,8 +266,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
uncond_tokens
=
[
""
]
uncond_tokens
=
[
""
]
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
raise
TypeError
(
raise
TypeError
(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f
"`negative_prompt` should be the same type to `prompt`, but got
{
type
(
negative_prompt
)
}
!="
" {type(prompt)}."
f
"
{
type
(
prompt
)
}
."
)
)
elif
isinstance
(
negative_prompt
,
str
):
elif
isinstance
(
negative_prompt
,
str
):
uncond_tokens
=
[
negative_prompt
]
uncond_tokens
=
[
negative_prompt
]
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
View file @
26c7df5d
...
@@ -108,8 +108,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
...
@@ -108,8 +108,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
uncond_tokens
=
[
""
]
*
batch_size
uncond_tokens
=
[
""
]
*
batch_size
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
raise
TypeError
(
raise
TypeError
(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f
"`negative_prompt` should be the same type to `prompt`, but got
{
type
(
negative_prompt
)
}
!="
" {type(prompt)}."
f
"
{
type
(
prompt
)
}
."
)
)
elif
isinstance
(
negative_prompt
,
str
):
elif
isinstance
(
negative_prompt
,
str
):
uncond_tokens
=
[
negative_prompt
]
*
batch_size
uncond_tokens
=
[
negative_prompt
]
*
batch_size
...
...
tests/test_pipelines.py
View file @
26c7df5d
...
@@ -575,6 +575,46 @@ class PipelineFastTests(unittest.TestCase):
...
@@ -575,6 +575,46 @@ class PipelineFastTests(unittest.TestCase):
assert
np
.
abs
(
output_2
.
images
.
flatten
()
-
output_1
.
images
.
flatten
()).
max
()
<
1e-4
assert
np
.
abs
(
output_2
.
images
.
flatten
()
-
output_1
.
images
.
flatten
()).
max
()
<
1e-4
def
test_stable_diffusion_negative_prompt
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
self
.
dummy_safety_checker
,
feature_extractor
=
self
.
dummy_extractor
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
negative_prompt
=
"french fries"
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
prompt
,
negative_prompt
=
negative_prompt
,
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
128
,
128
,
3
)
expected_slice
=
np
.
array
([
0.4851
,
0.4617
,
0.4765
,
0.5127
,
0.4845
,
0.5153
,
0.5141
,
0.4886
,
0.4719
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_score_sde_ve_pipeline
(
self
):
def
test_score_sde_ve_pipeline
(
self
):
unet
=
self
.
dummy_uncond_unet
unet
=
self
.
dummy_uncond_unet
scheduler
=
ScoreSdeVeScheduler
()
scheduler
=
ScoreSdeVeScheduler
()
...
@@ -704,6 +744,48 @@ class PipelineFastTests(unittest.TestCase):
...
@@ -704,6 +744,48 @@ class PipelineFastTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_img2img_negative_prompt
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
init_image
=
self
.
dummy_image
.
to
(
device
)
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionImg2ImgPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
self
.
dummy_safety_checker
,
feature_extractor
=
self
.
dummy_extractor
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
negative_prompt
=
"french fries"
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
prompt
,
negative_prompt
=
negative_prompt
,
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
init_image
=
init_image
,
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
np
.
array
([
0.4065
,
0.3783
,
0.4050
,
0.5266
,
0.4781
,
0.4252
,
0.4203
,
0.4692
,
0.4365
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_img2img_multiple_init_images
(
self
):
def
test_stable_diffusion_img2img_multiple_init_images
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet
unet
=
self
.
dummy_cond_unet
...
@@ -861,6 +943,52 @@ class PipelineFastTests(unittest.TestCase):
...
@@ -861,6 +943,52 @@ class PipelineFastTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_inpaint_negative_prompt
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet
scheduler
=
PNDMScheduler
(
skip_prk_steps
=
True
)
vae
=
self
.
dummy_vae
bert
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
image
=
self
.
dummy_image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)[
0
]
init_image
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
"RGB"
)
mask_image
=
Image
.
fromarray
(
np
.
uint8
(
image
+
4
)).
convert
(
"RGB"
).
resize
((
128
,
128
))
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionInpaintPipeline
(
unet
=
unet
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
bert
,
tokenizer
=
tokenizer
,
safety_checker
=
self
.
dummy_safety_checker
,
feature_extractor
=
self
.
dummy_extractor
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
negative_prompt
=
"french fries"
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
prompt
,
negative_prompt
=
negative_prompt
,
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
2
,
output_type
=
"np"
,
init_image
=
init_image
,
mask_image
=
mask_image
,
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
np
.
array
([
0.4765
,
0.5339
,
0.4541
,
0.6240
,
0.5439
,
0.4055
,
0.5503
,
0.5891
,
0.5150
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_num_images_per_prompt
(
self
):
def
test_stable_diffusion_num_images_per_prompt
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet
unet
=
self
.
dummy_cond_unet
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment