Commit e3dfaf82 authored by anton-l's avatar anton-l
Browse files

save local pipeline modules

parent 99540747
......@@ -40,6 +40,7 @@ LOADABLE_CLASSES = {
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
},
}
......@@ -82,24 +83,25 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module")
for name, (library_name, class_name) in model_index_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
# TODO: Suraj
if library_name == self.__module__:
library_name = self
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
save_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
save_method_name = importable_classes[class_name][0]
save_method = getattr(getattr(self, name), save_method_name)
save_method(os.path.join(save_directory, name))
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class)
if issubclass(model_cls, class_candidate):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
save_method = getattr(sub_model, save_method_name)
save_method(os.path.join(save_directory, pipeline_component_name))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment