Unverified Commit 948022e1 authored by Justin Merrell, committed by GitHub

fix: flagged_images implementation (#1947)



Flagged images were being set to the blank image instead of the original image that contained the NSFW concept, which is meant to be available for optional viewing.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2f9a70aa
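The commit message describes an ordering problem: by the time the old code read `image[idx]` into `flagged_images`, the safety checker appears to have already blanked the flagged entries in the array it returned, so the "original" that was saved was itself black. The diff therefore copies `image` into `images` before the checker runs and saves `images[idx]` instead. The sketch below only illustrates that ordering; `fake_checker`, the array shape, and the flags are illustrative assumptions, not the diffusers API.

```python
import numpy as np

def fake_checker(images, flags):
    # Stand-in for the safety checker: blank every flagged image in the
    # array it is given, mirroring the behavior the commit works around.
    for idx, flagged in enumerate(flags):
        if flagged:
            images[idx] = np.zeros_like(images[idx])
    return images

batch = np.random.rand(2, 64, 64, 3).astype(np.float32)
flags = [True, False]

# Old behavior: the "original" is read from the checker's output,
# which is already blank, so the saved flagged image is blank too.
checked = fake_checker(batch.copy(), flags)
flagged_old = checked[0]    # all zeros

# Fixed behavior: keep a copy of the originals *before* the checker runs.
originals = batch.copy()
checked = fake_checker(batch, flags)
flagged_new = originals[0]  # still the original image

print(flagged_old.any(), flagged_new.any())  # False True
```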
@@ -341,21 +341,20 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
         if self.safety_checker is not None:
+            images = image.copy()
             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-            flagged_images = None
+            flagged_images = np.zeros((2, *image.shape[1:]))
             if any(has_nsfw_concept):
                 logger.warning(
-                    "Potential NSFW content was detected in one or more images. A black image will be returned"
-                    " instead."
-                    f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
+                    "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                    f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
                 )
-                flagged_images = np.zeros((2, *image.shape[1:]))
                 for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
                     if has_nsfw_concept:
-                        flagged_images[idx] = image[idx]
+                        flagged_images[idx] = images[idx]
                         image[idx] = np.zeros(image[idx].shape)  # black image
         else:
             has_nsfw_concept = None
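After this change, the images preserved by `run_safety_checker` are the original flagged outputs rather than blanks, and the warning above points users to the `unsafe_images` field of the pipeline output. A hedged usage sketch follows; it assumes `StableDiffusionPipelineSafe` is importable from the top-level `diffusers` namespace, uses a placeholder checkpoint id, and otherwise follows ordinary diffusers conventions that this commit does not itself specify.

```python
import torch
from diffusers import StableDiffusionPipelineSafe  # assumption: exported at the top level

# Placeholder checkpoint id; substitute a real Stable Diffusion checkpoint.
pipe = StableDiffusionPipelineSafe.from_pretrained(
    "path/to/stable-diffusion-checkpoint", torch_dtype=torch.float16
).to("cuda")

result = pipe("a photo of an astronaut riding a horse")
images = result.images          # flagged entries come back as black images
flagged = result.unsafe_images  # with this fix: the original flagged images,
                                # viewable at your own discretion when safety
                                # guidance is enabled
```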