Commit 40dc888f authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add first logic for from hub code download

parent e8ad2b75
......@@ -20,6 +20,9 @@ import torch
class DDPM(DiffusionPipeline):
modeling_file = "modeling_ddpm.py"
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
......
......@@ -53,8 +53,11 @@ class DiffusionPipeline(ConfigMixin):
# retrive class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
register_dict["_module"] = self.__module__
# save model index config
self.register(**{name: (library, class_name)})
self.register(**register_dict)
# set models
setattr(self, name, module)
......@@ -84,7 +87,10 @@ class DiffusionPipeline(ConfigMixin):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained
cached_folder = snapshot_download(pretrained_model_name_or_path)
config_dict, _ = cls.get_config_dict(cached_folder)
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
module = pipeline_kwargs["_module"]
# TODO(Suraj) - make from hub import work
init_kwargs = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment