Commit 1a6196e8 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add more logic for dynamic loading

parent 40dc888f
...@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline): ...@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
modeling_file = "modeling_ddpm.py" modeling_file = "modeling_ddpm.py"
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler, vqvae):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
......
...@@ -71,6 +71,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -71,6 +71,10 @@ class DiffusionPipeline(ConfigMixin):
for name, (library_name, class_name) in self._dict_to_save.items(): for name, (library_name, class_name) in self._dict_to_save.items():
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
# TODO: Suraj
if library_name == self.__module__:
library_name = self
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
...@@ -91,12 +95,18 @@ class DiffusionPipeline(ConfigMixin): ...@@ -91,12 +95,18 @@ class DiffusionPipeline(ConfigMixin):
module = pipeline_kwargs["_module"] module = pipeline_kwargs["_module"]
# TODO(Suraj) - make from hub import work # TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_kwargs = {} init_kwargs = {}
for name, (library_name, class_name) in config_dict.items(): for name, (library_name, class_name) in config_dict.items():
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
if library_name == module:
# TODO(Suraj)
pass
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
...@@ -110,7 +120,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -110,7 +120,7 @@ class DiffusionPipeline(ConfigMixin):
loaded_sub_model = load_method(os.path.join(cached_folder, name)) loaded_sub_model = load_method(os.path.join(cached_folder, name))
init_kwargs[name] = loaded_sub_model init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
model = cls(**init_kwargs) model = cls(**init_kwargs)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment