Commit 80b86587 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent a9f0785d
......@@ -2,21 +2,11 @@
import tempfile
import sys
from diffusers import GaussianDDPMScheduler, UNetModel
from modeling_ddpm import DDPM
model_id = sys.argv[1]
folder = sys.argv[2]
save = bool(int(sys.argv[3]))
unet = UNetModel.from_pretrained(model_id)
sampler = GaussianDDPMScheduler.from_config(model_id)
# compose Diffusion Pipeline
if save:
ddpm = DDPM(unet, sampler)
ddpm.save_pretrained(folder)
ddpm = DDPM.from_pretrained(model_id)
image = ddpm()
import PIL.Image
......
......@@ -17,6 +17,7 @@
import importlib
import os
from typing import Optional, Union
from huggingface_hub import snapshot_download
# CHANGE to diffusers.utils
from transformers.utils import logging
......@@ -82,7 +83,8 @@ class DiffusionPipeline(Config):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path)
cached_folder = snapshot_download(pretrained_model_name_or_path)
config_dict, _ = cls.get_config_dict(cached_folder)
init_kwargs = {}
......@@ -100,7 +102,7 @@ class DiffusionPipeline(Config):
load_method = getattr(class_obj, load_method_name)
loaded_sub_model = load_method(os.path.join(pretrained_model_name_or_path, name))
loaded_sub_model = load_method(os.path.join(cached_folder, name))
init_kwargs[name] = loaded_sub_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment