Commit 397b31c8 authored by patil-suraj's avatar patil-suraj
Browse files

allow loading modules from hub

parent 46dae846
...@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin):
for name, module in kwargs.items(): for name, module in kwargs.items():
# retrive library # retrive library
library = module.__module__.split(".")[0] library = module.__module__.split(".")[0]
# if library is not in LOADABLE_CLASSES, then it is a custom module
if library not in LOADABLE_CLASSES:
library = module.__module__.split(".")[-1]
# retrive class_name # retrive class_name
class_name = module.__class__.__name__ class_name = module.__class__.__name__
...@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.get_config_dict(cached_folder)
module_candidate = config_dict["_module"] module_candidate = config_dict["_module"]
module_candidate_name = module_candidate.replace(".py", "")
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if cls != DiffusionPipeline: if cls != DiffusionPipeline:
...@@ -120,21 +125,23 @@ class DiffusionPipeline(ConfigMixin): ...@@ -120,21 +125,23 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {} init_kwargs = {}
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
if library_name == module_candidate_name:
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
load_method_name = "from_pretrained"
else:
importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module_candidate: library = importlib.import_module(library_name)
# TODO(Suraj) class_obj = getattr(library, class_name)
# for vq class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
pass
library = importlib.import_module(library_name) load_method_name = None
class_obj = getattr(library, class_name) for class_name, class_candidate in class_candidates.items():
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name) load_method = getattr(class_obj, load_method_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment