"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "c186feed7fb7604db59377e74d48bcc61053832e"
Commit 397b31c8 authored by patil-suraj's avatar patil-suraj
Browse files

allow loading modules from hub

parent 46dae846
...@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin):
for name, module in kwargs.items(): for name, module in kwargs.items():
# retrive library # retrive library
library = module.__module__.split(".")[0] library = module.__module__.split(".")[0]
# if library is not in LOADABLE_CLASSES, then it is a custom module
if library not in LOADABLE_CLASSES:
library = module.__module__.split(".")[-1]
# retrive class_name # retrive class_name
class_name = module.__class__.__name__ class_name = module.__class__.__name__
...@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.get_config_dict(cached_folder)
module_candidate = config_dict["_module"] module_candidate = config_dict["_module"]
module_candidate_name = module_candidate.replace(".py", "")
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if cls != DiffusionPipeline: if cls != DiffusionPipeline:
...@@ -120,12 +125,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -120,12 +125,14 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {} init_kwargs = {}
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module_candidate: # if the model is not in diffusers or transformers, we need to load it from the hub
# TODO(Suraj) # assumes that it's a subclass of ModelMixin
# for vq if library_name == module_candidate_name:
pass class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
load_method_name = "from_pretrained"
else:
importable_classes = LOADABLE_CLASSES[library_name]
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment