Commit 80b86587 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent a9f0785d
...@@ -2,21 +2,11 @@ ...@@ -2,21 +2,11 @@
import tempfile import tempfile
import sys import sys
from diffusers import GaussianDDPMScheduler, UNetModel
from modeling_ddpm import DDPM from modeling_ddpm import DDPM
model_id = sys.argv[1] model_id = sys.argv[1]
folder = sys.argv[2]
save = bool(int(sys.argv[3]))
unet = UNetModel.from_pretrained(model_id)
sampler = GaussianDDPMScheduler.from_config(model_id)
# compose Diffusion Pipeline
if save:
ddpm = DDPM(unet, sampler)
ddpm.save_pretrained(folder)
ddpm = DDPM.from_pretrained(model_id)
image = ddpm() image = ddpm()
import PIL.Image import PIL.Image
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import importlib import importlib
import os import os
from typing import Optional, Union from typing import Optional, Union
from huggingface_hub import snapshot_download
# CHANGE to diffusers.utils # CHANGE to diffusers.utils
from transformers.utils import logging from transformers.utils import logging
...@@ -82,7 +83,8 @@ class DiffusionPipeline(Config): ...@@ -82,7 +83,8 @@ class DiffusionPipeline(Config):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path) cached_folder = snapshot_download(pretrained_model_name_or_path)
config_dict, _ = cls.get_config_dict(cached_folder)
init_kwargs = {} init_kwargs = {}
...@@ -100,7 +102,7 @@ class DiffusionPipeline(Config): ...@@ -100,7 +102,7 @@ class DiffusionPipeline(Config):
load_method = getattr(class_obj, load_method_name) load_method = getattr(class_obj, load_method_name)
loaded_sub_model = load_method(os.path.join(pretrained_model_name_or_path, name)) loaded_sub_model = load_method(os.path.join(cached_folder, name))
init_kwargs[name] = loaded_sub_model init_kwargs[name] = loaded_sub_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment