Commit dd4cd081 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix naming

parent ab8e5364
...@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline): ...@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
modeling_file = "modeling_ddpm.py" modeling_file = "modeling_ddpm.py"
def __init__(self, unet, noise_scheduler, vqvae): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
......
...@@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
cached_folder = snapshot_download(pretrained_model_name_or_path) if not os.path.isdir(pretrained_model_name_or_path):
cached_folder = snapshot_download(pretrained_model_name_or_path)
else:
cached_folder = pretrained_model_name_or_path
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
module = pipeline_kwargs["_module"] module = pipeline_kwargs.pop("_module", None)
# TODO(Suraj) - make from hub import work # TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers # Add Sylvains code from transformers
...@@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin):
load_method = getattr(class_obj, load_method_name) load_method = getattr(class_obj, load_method_name)
if os.path.dir(os.path.join(cached_folder, name)): if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name)) loaded_sub_model = load_method(os.path.join(cached_folder, name))
else: else:
loaded_sub_model = load_method(cached_folder) loaded_sub_model = load_method(cached_folder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment