Unverified Commit a971c598 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix auto_pipeline: pass kwargs to load_config (#4793)



* fix

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 934d439a
......@@ -17,6 +17,7 @@ import inspect
from collections import OrderedDict
from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_CACHE
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
......@@ -295,7 +296,29 @@ class AutoPipelineForText2Image(ConfigMixin):
>>> image = pipeline(prompt).images[0]
```
"""
config = cls.load_config(pretrained_model_or_path)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "controlnet" in kwargs:
......@@ -303,6 +326,7 @@ class AutoPipelineForText2Image(ConfigMixin):
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod
......@@ -535,7 +559,29 @@ class AutoPipelineForImage2Image(ConfigMixin):
>>> image = pipeline(prompt, image).images[0]
```
"""
config = cls.load_config(pretrained_model_or_path)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "controlnet" in kwargs:
......@@ -543,6 +589,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod
......@@ -776,7 +823,29 @@ class AutoPipelineForInpainting(ConfigMixin):
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
config = cls.load_config(pretrained_model_or_path)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "controlnet" in kwargs:
......@@ -784,6 +853,7 @@ class AutoPipelineForInpainting(ConfigMixin):
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod
......
......@@ -14,8 +14,11 @@
# limitations under the License.
import gc
import os
import shutil
import unittest
from collections import OrderedDict
from pathlib import Path
import torch
......@@ -24,6 +27,7 @@ from diffusers import (
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ControlNetModel,
DiffusionPipeline,
)
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
......@@ -81,6 +85,29 @@ class AutoPipelineFastTest(unittest.TestCase):
assert dict(pipe.config) == original_config
def test_kwargs_local_files_only(self):
repo = "hf-internal-testing/tiny-stable-diffusion-torch"
tmpdirname = DiffusionPipeline.download(repo)
tmpdirname = Path(tmpdirname)
# edit commit_id to so that it's not the latest commit
commit_id = tmpdirname.name
new_commit_id = commit_id + "hug"
ref_dir = tmpdirname.parent.parent / "refs/main"
with open(ref_dir, "w") as f:
f.write(new_commit_id)
new_tmpdirname = tmpdirname.parent / new_commit_id
os.rename(tmpdirname, new_tmpdirname)
try:
AutoPipelineForText2Image.from_pretrained(repo, local_files_only=True)
except OSError:
assert False, "not able to load local files"
shutil.rmtree(tmpdirname.parent.parent)
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment