Unverified Commit 585f621a authored by Stax124's avatar Stax124 Committed by GitHub
Browse files

[Stable Diffusion] Allow users to disable Safety checker if loading model from checkpoint (#2768)



* Allow user to disable SafetyChecker and enable dtypes if loading models from .ckpt or .safetensors

* Fix Import sorting (Ruff error)

* Get rid of the dtype convert method as it was implemented all along

* Fix the docstring

* Fix ruff formatting

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent c0afca2d
......@@ -989,6 +989,7 @@ def download_from_original_stable_diffusion_ckpt(
stable_unclip_prior: Optional[str] = None,
clip_stats_path: Optional[str] = None,
controlnet: Optional[bool] = None,
load_safety_checker: bool = True,
) -> StableDiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
......@@ -1028,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
"""
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
......@@ -1270,8 +1273,13 @@ def download_from_original_stable_diffusion_ckpt(
elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
if load_safety_checker:
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
else:
safety_checker = None
feature_extractor = None
if controlnet:
pipe = StableDiffusionControlNetPipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment