Unverified Commit 7e395166 authored by w4ffl35's avatar w4ffl35 Committed by GitHub
Browse files

Allow more arguments to be passed to convert_from_ckpt (#7222)



Allow safety and feature extractor arguments to be passed to convert_from_ckpt

Allows management of safety checker and feature extractor
from outside of the convert ckpt class.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 56a76082
...@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
controlnet: Optional[bool] = None, controlnet: Optional[bool] = None,
adapter: Optional[bool] = None, adapter: Optional[bool] = None,
load_safety_checker: bool = True, load_safety_checker: bool = True,
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
feature_extractor: Optional[AutoFeatureExtractor] = None,
pipeline_class: DiffusionPipeline = None, pipeline_class: DiffusionPipeline = None,
local_files_only=False, local_files_only=False,
vae_path=None, vae_path=None,
...@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
load_safety_checker (`bool`, *optional*, defaults to `True`): load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`. Whether to load the safety checker or not. Defaults to `True`.
safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
Safety checker to use. If this parameter is `None`, the function will load a new instance of
[StableDiffusionSafetyChecker] by itself, if needed.
feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
Feature extractor to use. If this parameter is `None`, the function will load a new instance of
[AutoFeatureExtractor] by itself, if needed.
pipeline_class (`str`, *optional*, defaults to `None`): pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically. The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
...@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
controlnet=controlnet, controlnet=controlnet,
safety_checker=None, safety_checker=safety_checker,
feature_extractor=None, feature_extractor=feature_extractor,
) )
if hasattr(pipe, "requires_safety_checker"): if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False pipe.requires_safety_checker = False
...@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
low_res_scheduler=low_res_scheduler, low_res_scheduler=low_res_scheduler,
safety_checker=None, safety_checker=safety_checker,
feature_extractor=None, feature_extractor=feature_extractor,
) )
else: else:
...@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
safety_checker=None, safety_checker=safety_checker,
feature_extractor=None, feature_extractor=feature_extractor,
) )
if hasattr(pipe, "requires_safety_checker"): if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False pipe.requires_safety_checker = False
...@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
) )
else:
safety_checker = None
feature_extractor = None
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment