Commit e3dfaf82 authored by anton-l's avatar anton-l
Browse files

save local pipeline modules

parent 99540747
...@@ -40,6 +40,7 @@ LOADABLE_CLASSES = { ...@@ -40,6 +40,7 @@ LOADABLE_CLASSES = {
}, },
"transformers": { "transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
}, },
} }
...@@ -82,24 +83,25 @@ class DiffusionPipeline(ConfigMixin): ...@@ -82,24 +83,25 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module") model_index_dict.pop("_module")
for name, (library_name, class_name) in model_index_dict.items(): for pipeline_component_name in model_index_dict.keys():
importable_classes = LOADABLE_CLASSES[library_name] sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
# TODO: Suraj
if library_name == self.__module__:
library_name = self
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
save_method_name = None save_method_name = None
for class_name, class_candidate in class_candidates.items(): # search for the model's base class in LOADABLE_CLASSES
if issubclass(class_obj, class_candidate): for library_name, library_classes in LOADABLE_CLASSES.items():
save_method_name = importable_classes[class_name][0] library = importlib.import_module(library_name)
for base_class, save_load_methods in library_classes.items():
save_method = getattr(getattr(self, name), save_method_name) class_candidate = getattr(library, base_class)
save_method(os.path.join(save_directory, name)) if issubclass(model_cls, class_candidate):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
save_method = getattr(sub_model, save_method_name)
save_method(os.path.join(save_directory, pipeline_component_name))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment