Unverified Commit 46c52f9b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Pipelines] Make sure that None functions are correctly not saved (#3080)

parent d06e0694
...@@ -19,6 +19,7 @@ import importlib ...@@ -19,6 +19,7 @@ import importlib
import inspect import inspect
import os import os
import re import re
import sys
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
...@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin): ...@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin. If specified, weights are saved in the format pytorch_model.<variant>.bin.
""" """
self.save_config(save_directory)
model_index_dict = dict(self.config) model_index_dict = dict(self.config)
model_index_dict.pop("_class_name") model_index_dict.pop("_class_name", None)
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version", None)
model_index_dict.pop("_module", None) model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self) expected_modules, optional_kwargs = self._get_signature_keys(self)
...@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin): ...@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
return True return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys(): for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name) sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__ model_cls = sub_model.__class__
...@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
save_method_name = None save_method_name = None
# search for the model's base class in LOADABLE_CLASSES # search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items(): for library_name, library_classes in LOADABLE_CLASSES.items():
if library_name in sys.modules:
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
)
for base_class, save_load_methods in library_classes.items(): for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None) class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate): if class_candidate is not None and issubclass(model_cls, class_candidate):
...@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin): ...@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
if save_method_name is not None: if save_method_name is not None:
break break
if save_method_name is None:
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
# make sure that unsaveable components are not tried to be loaded afterward
self.register_to_config(**{pipeline_component_name: (None, None)})
continue
save_method = getattr(sub_model, save_method_name) save_method = getattr(sub_model, save_method_name)
# Call the save method with the argument safe_serialization only if it's supported # Call the save method with the argument safe_serialization only if it's supported
...@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin): ...@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
# finally save the config
self.save_config(save_directory)
def to( def to(
self, self,
torch_device: Optional[Union[str, torch.device]] = None, torch_device: Optional[Union[str, torch.device]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment