Unverified Commit 65ea7d6b authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

Add safety module (#213)

* add SafetyChecker

* better name, fix checker

* add checker in main init

* remove from main init

* update logic to detect pipeline module

* style

* handle all safety logic in safety checker

* draw text

* can't draw

* small fixes

* treat special care as nsfw

* remove commented lines

* update safety checker
parent e30e1b89
...@@ -42,6 +42,7 @@ LOADABLE_CLASSES = { ...@@ -42,6 +42,7 @@ LOADABLE_CLASSES = {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"], "PreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
}, },
} }
...@@ -63,9 +64,9 @@ class DiffusionPipeline(ConfigMixin): ...@@ -63,9 +64,9 @@ class DiffusionPipeline(ConfigMixin):
library = module.__module__.split(".")[0] library = module.__module__.split(".")[0]
# check if the module is a pipeline module # check if the module is a pipeline module
pipeline_file = module.__module__.split(".")[-1]
pipeline_dir = module.__module__.split(".")[-2] pipeline_dir = module.__module__.split(".")[-2]
is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir) path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module. # if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline # Or if it's a pipeline module, then the module is inside the pipeline
......
...@@ -3,4 +3,4 @@ from ...utils import is_transformers_available ...@@ -3,4 +3,4 @@ from ...utils import is_transformers_available
if is_transformers_available(): if is_transformers_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
...@@ -4,11 +4,12 @@ from typing import List, Optional, Union ...@@ -4,11 +4,12 @@ from typing import List, Optional, Union
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from .safety_checker import StableDiffusionSafetyChecker
class StableDiffusionPipeline(DiffusionPipeline): class StableDiffusionPipeline(DiffusionPipeline):
...@@ -19,10 +20,20 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -19,10 +20,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -53,6 +64,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -53,6 +64,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
self.vae.to(torch_device) self.vae.to(torch_device)
self.text_encoder.to(torch_device) self.text_encoder.to(torch_device)
self.safety_checker.to(torch_device)
# get prompt text embeddings # get prompt text embeddings
text_input = self.tokenizer( text_input = self.tokenizer(
...@@ -136,7 +148,12 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -136,7 +148,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from ...utils import logging
logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.register_buffer("concept_embeds_weights", torch.ones(17))
self.register_buffer("special_care_embeds_weights", torch.ones(3))
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
adjustment = 0.05
for concet_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concet_idx] > 0:
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
adjustment = 0.01
for concet_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item()
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concet_idx] > 0:
result_img["bad_concepts"].append(concet_idx)
result.append(result_img)
has_nsfw_concepts = [len(result[i]["bad_concepts"]) > 0 or i in range(len(result))]
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
images[idx] = np.zeros(images[idx].shape) # black image
if any(has_nsfw_concepts):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
" Try again with a different prompt and/or seed."
)
return images, has_nsfw_concepts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment