Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
beb848e2
Unverified
Commit
beb848e2
authored
Apr 17, 2023
by
Patrick von Platen
Committed by
GitHub
Apr 17, 2023
Browse files
[Bug fix] Fix img2img processor with safety checker (#3127)
Fix img2img processor with safety checker
parent
cfc99adf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
1 deletion
+18
-1
src/diffusers/pipelines/stable_diffusion/safety_checker.py
src/diffusers/pipelines/stable_diffusion/safety_checker.py
+4
-1
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
...pelines/stable_diffusion/test_stable_diffusion_img2img.py
+14
-0
No files found.
src/diffusers/pipelines/stable_diffusion/safety_checker.py
View file @
beb848e2
...
...
@@ -85,7 +85,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
for
idx
,
has_nsfw_concept
in
enumerate
(
has_nsfw_concepts
):
if
has_nsfw_concept
:
images
[
idx
]
=
np
.
zeros
(
images
[
idx
].
shape
)
# black image
if
torch
.
is_tensor
(
images
)
or
torch
.
is_tensor
(
images
[
0
]):
images
[
idx
]
=
torch
.
zeros_like
(
images
[
idx
])
# black image
else
:
images
[
idx
]
=
np
.
zeros
(
images
[
idx
].
shape
)
# black image
if
any
(
has_nsfw_concepts
):
logger
.
warning
(
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
View file @
beb848e2
...
...
@@ -453,6 +453,20 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
5e-3
def
test_img2img_safety_checker_works
(
self
):
sd_pipe
=
StableDiffusionImg2ImgPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
)
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
inputs
=
self
.
get_inputs
(
torch_device
)
inputs
[
"num_inference_steps"
]
=
20
# make sure the safety checker is activated
inputs
[
"prompt"
]
=
"naked, sex, porn"
out
=
sd_pipe
(
**
inputs
)
assert
out
.
nsfw_content_detected
[
0
],
f
"Safety checker should work for prompt:
{
inputs
[
'prompt'
]
}
"
assert
np
.
abs
(
out
.
images
[
0
]).
sum
()
<
1e-5
# should be all zeros
@
nightly
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment