Unverified Commit 7386e773 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Show error when loading safety_checker `from_flax` (#2187)

* Show error when loading safety_checker `from_flax`

* fix style
parent 154a7865
...@@ -595,6 +595,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -595,6 +595,14 @@ class DiffusionPipeline(ConfigMixin):
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
"The safety checker cannot be automatically loaded when loading weights `from_flax`."
" Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
" separately if you need it."
)
if len(unused_kwargs) > 0: if len(unused_kwargs) > 0:
logger.warning( logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment