Unverified Commit 8e4fd686 authored by Jonatan Kłosko, committed by GitHub

Move safety detection to model call in Flax safety checker (#1023)

* Move safety detection to model call in Flax safety checker

* Update src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
parent 95414bd6
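
In short, this commit moves the NSFW thresholding from a host-side post-processing step into the safety checker's model call: the Flax module now returns a per-image boolean array, and the pipeline blacks out the flagged images afterwards. A minimal sketch of the new flow, assuming `safety_checker` stands in for the Flax safety checker and `features` for the CLIP pixel values produced by the feature extractor (these names are illustrative, not the exact pipeline attributes):

import numpy as np

def run_safety_check(images: np.ndarray, features, safety_checker, params):
    # The model call now returns one boolean per image instead of raw
    # cosine-distance scores that had to be post-processed on the host.
    has_nsfw_concepts = safety_checker(features, params)

    images = images.copy()
    for idx, flagged in enumerate(has_nsfw_concepts):
        if flagged:
            # Replace flagged outputs with a black image of the same shape.
            images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)
    return images, has_nsfw_concepts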
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

+import warnings
 from functools import partial
 from typing import Dict, List, Optional, Union

@@ -97,9 +98,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         )
         return text_input.input_ids

-    def _get_safety_scores(self, features, params):
-        special_cos_dist, cos_dist = self.safety_checker(features, params)
-        return (special_cos_dist, cos_dist)
+    def _get_has_nsfw_concepts(self, features, params):
+        has_nsfw_concepts = self.safety_checker(features, params)
+        return has_nsfw_concepts

     def _run_safety_checker(self, images, safety_model_params, jit=False):
         # safety_model_params should already be replicated when jit is True
@@ -108,20 +109,28 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         if jit:
             features = shard(features)
-            special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
-            special_cos_dist = unshard(special_cos_dist)
-            cos_dist = unshard(cos_dist)
+            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+            has_nsfw_concepts = unshard(has_nsfw_concepts)
             safety_model_params = unreplicate(safety_model_params)
         else:
-            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
+            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

-        images, has_nsfw = self.safety_checker.filtered_with_scores(
-            special_cos_dist,
-            cos_dist,
-            images,
-            safety_model_params,
-        )
-        return images, has_nsfw
+        images_was_copied = False
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if not images_was_copied:
+                    images_was_copied = True
+                    images = images.copy()
+
+                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
+
+        if any(has_nsfw_concepts):
+            warnings.warn(
+                "Potential NSFW content was detected in one or more images. A black image will be returned"
+                " instead. Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts

     def _generate(
         self,
@@ -310,8 +319,8 @@ def _p_generate(

 @partial(jax.pmap, static_broadcasted_argnums=(0,))
-def _p_get_safety_scores(pipe, features, params):
-    return pipe._get_safety_scores(features, params)
+def _p_get_has_nsfw_concepts(pipe, features, params):
+    return pipe._get_has_nsfw_concepts(features, params)


 def unshard(x: jnp.ndarray):
 ...
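
The body of `unshard` is collapsed above; presumably it folds the leading device axis that `jax.pmap` adds back into the batch axis, so the per-image booleans line up with the generated images. A sketch under that assumption:

import jax.numpy as jnp

def unshard(x: jnp.ndarray):
    # Assumed behavior: merge (num_devices, per_device_batch, ...) pmap output
    # back into a single (batch, ...) array, e.g. (8, 2) -> (16,).
    num_devices, batch_size = x.shape[:2]
    return x.reshape(num_devices * batch_size, *x.shape[2:])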
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py

-import warnings
 from typing import Optional, Tuple

-import numpy as np
 import jax
 import jax.numpy as jnp
 from flax import linen as nn

@@ -39,56 +36,22 @@ class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
         special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
         cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)

-        return special_cos_dist, cos_dist
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
-        batch_size = special_cos_dist.shape[0]
-        special_cos_dist = np.asarray(special_cos_dist)
-        cos_dist = np.asarray(cos_dist)
-
-        result = []
-        for i in range(batch_size):
-            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
-
-            # increase this value to create a stronger `nfsw` filter
-            # at the cost of increasing the possibility of filtering benign image inputs
-            adjustment = 0.0
-
-            for concept_idx in range(len(special_cos_dist[0])):
-                concept_cos = special_cos_dist[i][concept_idx]
-                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
-                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["special_scores"][concept_idx] > 0:
-                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
-                    adjustment = 0.01
-
-            for concept_idx in range(len(cos_dist[0])):
-                concept_cos = cos_dist[i][concept_idx]
-                concept_threshold = self.concept_embeds_weights[concept_idx].item()
-                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["concept_scores"][concept_idx] > 0:
-                    result_img["bad_concepts"].append(concept_idx)
-
-            result.append(result_img)
-
-        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
-
-        images_was_copied = False
-        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
-            if has_nsfw_concept:
-                if not images_was_copied:
-                    images_was_copied = True
-                    images = images.copy()
-
-                images[idx] = np.zeros(images[idx].shape)  # black image
-
-        if any(has_nsfw_concepts):
-            warnings.warn(
-                "Potential NSFW content was detected in one or more images. A black image will be returned"
-                " instead. Try again with a different prompt and/or seed."
-            )
-
-        return images, has_nsfw_concepts
+        # increase this value to create a stronger `nfsw` filter
+        # at the cost of increasing the possibility of filtering benign image inputs
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+        special_scores = jnp.round(special_scores, 3)
+        is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+        # Use a lower threshold if an image has any special care concept
+        special_adjustment = is_special_care * 0.01
+
+        concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+        concept_scores = jnp.round(concept_scores, 3)
+        has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+        return has_nsfw_concepts
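
The new `__call__` replaces the per-image Python loops with vectorized `jnp` operations, so the whole check can run inside the pmapped model call. A small self-contained sketch of that thresholding logic with made-up distances and thresholds (two images, three special-care concepts, four regular concepts; only the logic mirrors the module above):

import jax.numpy as jnp

# Made-up cosine distances and per-concept thresholds.
special_cos_dist = jnp.array([[0.10, 0.20, 0.05], [0.30, 0.10, 0.40]])
cos_dist = jnp.array([[0.10, 0.15, 0.20, 0.05], [0.25, 0.10, 0.30, 0.20]])
special_care_embeds_weights = jnp.array([0.15, 0.25, 0.30])
concept_embeds_weights = jnp.array([0.20, 0.20, 0.25, 0.25])

adjustment = 0.0
special_scores = special_cos_dist - special_care_embeds_weights[None, :] + adjustment
is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)  # second image trips a special-care concept
special_adjustment = is_special_care * 0.01  # lowers the regular thresholds for that image

concept_scores = cos_dist - concept_embeds_weights[None, :] + special_adjustment
has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
print(has_nsfw_concepts)  # [False  True]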
@@ -133,15 +96,3 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
             jnp.array(clip_input, dtype=jnp.float32),
             rngs={},
         )
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
-        def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
-            return module.filtered_with_scores(special_cos_dist, cos_dist, images)
-
-        return self.module.apply(
-            {"params": params or self.params},
-            special_cos_dist,
-            cos_dist,
-            images,
-            method=_filtered_with_scores,
-        )