Unverified Commit cbfed0c2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Config] Add optional arguments (#1395)



* Optional Components

* uP

* finish

* finish

* finish

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* up

* Update src/diffusers/pipeline_utils.py

* improve
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent e0e86b74
...@@ -129,10 +129,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -129,10 +129,13 @@ class DiffusionPipeline(ConfigMixin):
Class attributes: Class attributes:
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all - **config_name** (`str`) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline. components of the diffusion pipeline.
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
passed for the pipeline to function (should be overridden by subclasses).
""" """
config_name = "model_index.json" config_name = "model_index.json"
_optional_components = []
def register_modules(self, **kwargs): def register_modules(self, **kwargs):
# import it here to avoid circular import # import it here to avoid circular import
...@@ -184,12 +187,19 @@ class DiffusionPipeline(ConfigMixin): ...@@ -184,12 +187,19 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None) model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value):
if name not in expected_modules:
return False
if name in self._optional_components and value[0] is None:
return False
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys(): for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name) sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__ model_cls = sub_model.__class__
save_method_name = None save_method_name = None
...@@ -523,26 +533,27 @@ class DiffusionPipeline(ConfigMixin): ...@@ -523,26 +533,27 @@ class DiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init # some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs` # in this case they are already instantiated in `kwargs`
# extract them here # extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
if len(unused_kwargs) > 0: if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {}
# import it here to avoid circular import # import it here to avoid circular import
from diffusers import pipelines from diffusers import pipelines
# 3. Load each module in the pipeline # 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"): if class_name.startswith("Flax"):
class_name = class_name[4:] class_name = class_name[4:]
...@@ -570,7 +581,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -570,7 +581,7 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}" f" {expected_class_obj}"
) )
elif passed_class_obj[name] is None: elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
logger.warning( logger.warning(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended." f" that this might lead to problems when using {pipeline_class} and is not recommended."
...@@ -651,11 +662,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -651,11 +662,13 @@ class DiffusionPipeline(ConfigMixin):
# 4. Potentially add passed objects if expected # 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys()) missing_modules = set(expected_modules) - set(init_kwargs.keys())
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()): passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules: for module in missing_modules:
init_kwargs[module] = passed_class_obj[module] init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0: elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError( raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
) )
...@@ -664,6 +677,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -664,6 +677,14 @@ class DiffusionPipeline(ConfigMixin):
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return model return model
@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default is not True}
optional_parameters = set({k for k, v in parameters.items() if v.default is True})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters
@property @property
def components(self) -> Dict[str, Any]: def components(self) -> Dict[str, Any]:
r""" r"""
...@@ -688,8 +709,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -688,8 +709,10 @@ class DiffusionPipeline(ConfigMixin):
Returns: Returns:
A dictionaly containing all the modules needed to initialize the pipeline. A dictionaly containing all the modules needed to initialize the pipeline.
""" """
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")} expected_modules, optional_parameters = self._get_signature_keys(self)
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"]) components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
if set(components.keys()) != expected_modules: if set(components.keys()) != expected_modules:
raise ValueError( raise ValueError(
......
...@@ -67,6 +67,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -67,6 +67,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -84,6 +85,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -84,6 +85,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -114,7 +116,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -114,7 +116,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
...@@ -124,6 +126,12 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -124,6 +126,12 @@ class AltDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -133,6 +141,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -133,6 +141,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
r""" r"""
......
...@@ -80,6 +80,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -80,6 +80,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -97,6 +98,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -97,6 +98,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -127,7 +129,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -127,7 +129,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
...@@ -137,6 +139,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -137,6 +139,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -146,6 +154,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -146,6 +154,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r""" r"""
......
...@@ -132,6 +132,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -132,6 +132,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -142,6 +143,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -142,6 +143,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
scheduler: DDIMScheduler, scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -159,7 +161,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -159,7 +161,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -169,6 +171,12 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -169,6 +171,12 @@ class CycleDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -178,6 +186,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): ...@@ -178,6 +186,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
......
...@@ -51,6 +51,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -51,6 +51,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -81,6 +82,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -81,6 +82,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
...@@ -91,6 +108,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -91,6 +108,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r""" r"""
......
...@@ -87,6 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -87,6 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -117,7 +118,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -117,7 +118,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -127,6 +128,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -127,6 +128,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
...@@ -137,6 +144,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -137,6 +144,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
......
...@@ -100,6 +100,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -100,6 +100,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
...@@ -131,7 +132,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -131,7 +132,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -141,6 +142,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -141,6 +142,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
...@@ -151,6 +158,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -151,6 +158,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
......
...@@ -86,6 +86,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -86,6 +86,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -116,7 +117,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -116,7 +117,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -126,6 +127,12 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -126,6 +127,12 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
...@@ -136,6 +143,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -136,6 +143,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
......
...@@ -66,6 +66,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -66,6 +66,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -83,6 +84,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -83,6 +84,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -113,7 +115,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -113,7 +115,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -123,6 +125,12 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -123,6 +125,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -132,6 +140,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -132,6 +140,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
r""" r"""
......
...@@ -63,6 +63,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -63,6 +63,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -79,10 +80,11 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -79,10 +80,11 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warn( logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -92,6 +94,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -92,6 +94,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
image_encoder=image_encoder, image_encoder=image_encoder,
...@@ -100,6 +108,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -100,6 +108,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
......
...@@ -78,6 +78,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -78,6 +78,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__( def __init__(
...@@ -96,6 +97,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -96,6 +97,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -126,7 +128,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -126,7 +128,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -136,6 +138,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -136,6 +138,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -145,6 +153,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -145,6 +153,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
......
...@@ -150,6 +150,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -150,6 +150,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
...@@ -160,6 +161,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -160,6 +161,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -191,7 +193,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -191,7 +193,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["skip_prk_steps"] = True new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -201,6 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -201,6 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -210,6 +218,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -210,6 +218,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
......
...@@ -91,6 +91,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -91,6 +91,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__( def __init__(
...@@ -109,6 +110,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -109,6 +110,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
], ],
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -139,7 +141,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -139,7 +141,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -149,6 +151,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -149,6 +151,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -158,6 +166,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -158,6 +166,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
......
...@@ -56,6 +56,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -56,6 +56,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
vae: AutoencoderKL, vae: AutoencoderKL,
...@@ -72,6 +74,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -72,6 +74,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
], ],
safety_checker: SafeStableDiffusionSafetyChecker, safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
safety_concept: Optional[str] = ( safety_concept: Optional[str] = (
...@@ -107,7 +110,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -107,7 +110,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
new_config["clip_sample"] = False new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
...@@ -117,6 +120,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -117,6 +120,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
) )
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
...@@ -127,6 +136,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -127,6 +136,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self._safety_text_concept = safety_concept self._safety_text_concept = safety_concept
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property @property
def safety_concept(self): def safety_concept(self):
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
# limitations under the License. # limitations under the License.
import gc import gc
import json
import os import os
import random import random
import shutil
import tempfile import tempfile
import unittest import unittest
from functools import partial from functools import partial
...@@ -40,7 +42,6 @@ from diffusers import ( ...@@ -40,7 +42,6 @@ from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
UNet2DModel, UNet2DModel,
VQModel,
logging, logging,
) )
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
...@@ -284,32 +285,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -284,32 +285,7 @@ class PipelineFastTests(unittest.TestCase):
) )
return model return model
def dummy_cond_unet_inpaint(self, sample_size=32): @property
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=sample_size,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
)
return model
def dummy_vae(self): def dummy_vae(self):
torch.manual_seed(0) torch.manual_seed(0)
model = AutoencoderKL( model = AutoencoderKL(
...@@ -322,6 +298,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -322,6 +298,7 @@ class PipelineFastTests(unittest.TestCase):
) )
return model return model
@property
def dummy_text_encoder(self): def dummy_text_encoder(self):
torch.manual_seed(0) torch.manual_seed(0)
config = CLIPTextConfig( config = CLIPTextConfig(
...@@ -337,6 +314,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -337,6 +314,7 @@ class PipelineFastTests(unittest.TestCase):
) )
return CLIPTextModel(config) return CLIPTextModel(config)
@property
def dummy_extractor(self): def dummy_extractor(self):
def extract(*args, **kwargs): def extract(*args, **kwargs):
class Out: class Out:
...@@ -383,8 +361,8 @@ class PipelineFastTests(unittest.TestCase): ...@@ -383,8 +361,8 @@ class PipelineFastTests(unittest.TestCase):
"""Test that components property works correctly""" """Test that components property works correctly"""
unet = self.dummy_cond_unet() unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True) scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae() vae = self.dummy_vae
bert = self.dummy_text_encoder() bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0] image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
...@@ -399,7 +377,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -399,7 +377,7 @@ class PipelineFastTests(unittest.TestCase):
text_encoder=bert, text_encoder=bert,
tokenizer=tokenizer, tokenizer=tokenizer,
safety_checker=None, safety_checker=None,
feature_extractor=self.dummy_extractor(), feature_extractor=self.dummy_extractor,
).to(torch_device) ).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device) img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device) text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
...@@ -439,7 +417,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -439,7 +417,7 @@ class PipelineFastTests(unittest.TestCase):
assert image_text2img.shape == (1, 64, 64, 3) assert image_text2img.shape == (1, 64, 64, 3)
def test_set_scheduler(self): def test_set_scheduler(self):
unet = self.dummy_cond_unet unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True) scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae vae = self.dummy_vae
bert = self.dummy_text_encoder bert = self.dummy_text_encoder
...@@ -471,7 +449,7 @@ class PipelineFastTests(unittest.TestCase): ...@@ -471,7 +449,7 @@ class PipelineFastTests(unittest.TestCase):
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
def test_set_scheduler_consistency(self): def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
vae = self.dummy_vae vae = self.dummy_vae
...@@ -514,6 +492,110 @@ class PipelineFastTests(unittest.TestCase): ...@@ -514,6 +492,110 @@ class PipelineFastTests(unittest.TestCase):
assert dict(ddim_config) == dict(ddim_config_2) assert dict(ddim_config) == dict(ddim_config_2)
def test_optional_components(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
orig_sd = StableDiffusionPipeline(
unet=unet,
scheduler=pndm,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=unet,
feature_extractor=self.dummy_extractor,
)
sd = orig_sd
assert sd.config.requires_safety_checker is True
with tempfile.TemporaryDirectory() as tmpdirname:
sd.save_pretrained(tmpdirname)
# Test that passing None works
sd = StableDiffusionPipeline.from_pretrained(
tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False
)
assert sd.config.requires_safety_checker is False
assert sd.config.safety_checker == (None, None)
assert sd.config.feature_extractor == (None, None)
with tempfile.TemporaryDirectory() as tmpdirname:
sd.save_pretrained(tmpdirname)
# Test that loading previous None works
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
assert sd.config.requires_safety_checker is False
assert sd.config.safety_checker == (None, None)
assert sd.config.feature_extractor == (None, None)
orig_sd.save_pretrained(tmpdirname)
# Test that loading without any directory works
shutil.rmtree(os.path.join(tmpdirname, "safety_checker"))
with open(os.path.join(tmpdirname, sd.config_name)) as f:
config = json.load(f)
config["safety_checker"] = [None, None]
with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
json.dump(config, f)
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False)
sd.save_pretrained(tmpdirname)
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
assert sd.config.requires_safety_checker is False
assert sd.config.safety_checker == (None, None)
assert sd.config.feature_extractor == (None, None)
# Test that loading from deleted model index works
with open(os.path.join(tmpdirname, sd.config_name)) as f:
config = json.load(f)
del config["safety_checker"]
del config["feature_extractor"]
with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
json.dump(config, f)
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
assert sd.config.requires_safety_checker is False
assert sd.config.safety_checker == (None, None)
assert sd.config.feature_extractor == (None, None)
with tempfile.TemporaryDirectory() as tmpdirname:
sd.save_pretrained(tmpdirname)
# Test that partially loading works
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)
assert sd.config.requires_safety_checker is False
assert sd.config.safety_checker == (None, None)
assert sd.config.feature_extractor != (None, None)
# Test that partially loading works
sd = StableDiffusionPipeline.from_pretrained(
tmpdirname,
feature_extractor=self.dummy_extractor,
safety_checker=unet,
requires_safety_checker=[True, True],
)
assert sd.config.requires_safety_checker == [True, True]
assert sd.config.safety_checker != (None, None)
assert sd.config.feature_extractor != (None, None)
with tempfile.TemporaryDirectory() as tmpdirname:
sd.save_pretrained(tmpdirname)
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)
assert sd.config.requires_safety_checker == [True, True]
assert sd.config.safety_checker != (None, None)
assert sd.config.feature_extractor != (None, None)
@slow @slow
class PipelineSlowTests(unittest.TestCase): class PipelineSlowTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment