Unverified Commit a971c598 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix auto_pipeline: pass kwargs to load_config (#4793)



* fix

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 934d439a
...@@ -17,6 +17,7 @@ import inspect ...@@ -17,6 +17,7 @@ import inspect
from collections import OrderedDict from collections import OrderedDict
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_CACHE
from .controlnet import ( from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetInpaintPipeline,
...@@ -295,7 +296,29 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -295,7 +296,29 @@ class AutoPipelineForText2Image(ConfigMixin):
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt).images[0]
``` ```
""" """
config = cls.load_config(pretrained_model_or_path) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
if "controlnet" in kwargs: if "controlnet" in kwargs:
...@@ -303,6 +326,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -303,6 +326,7 @@ class AutoPipelineForText2Image(ConfigMixin):
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod @classmethod
...@@ -535,7 +559,29 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -535,7 +559,29 @@ class AutoPipelineForImage2Image(ConfigMixin):
>>> image = pipeline(prompt, image).images[0] >>> image = pipeline(prompt, image).images[0]
``` ```
""" """
config = cls.load_config(pretrained_model_or_path) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
if "controlnet" in kwargs: if "controlnet" in kwargs:
...@@ -543,6 +589,7 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -543,6 +589,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod @classmethod
...@@ -776,7 +823,29 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -776,7 +823,29 @@ class AutoPipelineForInpainting(ConfigMixin):
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0] >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
``` ```
""" """
config = cls.load_config(pretrained_model_or_path) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
"user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"] orig_class_name = config["_class_name"]
if "controlnet" in kwargs: if "controlnet" in kwargs:
...@@ -784,6 +853,7 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -784,6 +853,7 @@ class AutoPipelineForInpainting(ConfigMixin):
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs) return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
@classmethod @classmethod
......
...@@ -14,8 +14,11 @@ ...@@ -14,8 +14,11 @@
# limitations under the License. # limitations under the License.
import gc import gc
import os
import shutil
import unittest import unittest
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
import torch import torch
...@@ -24,6 +27,7 @@ from diffusers import ( ...@@ -24,6 +27,7 @@ from diffusers import (
AutoPipelineForInpainting, AutoPipelineForInpainting,
AutoPipelineForText2Image, AutoPipelineForText2Image,
ControlNetModel, ControlNetModel,
DiffusionPipeline,
) )
from diffusers.pipelines.auto_pipeline import ( from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
...@@ -81,6 +85,29 @@ class AutoPipelineFastTest(unittest.TestCase): ...@@ -81,6 +85,29 @@ class AutoPipelineFastTest(unittest.TestCase):
assert dict(pipe.config) == original_config assert dict(pipe.config) == original_config
def test_kwargs_local_files_only(self):
repo = "hf-internal-testing/tiny-stable-diffusion-torch"
tmpdirname = DiffusionPipeline.download(repo)
tmpdirname = Path(tmpdirname)
# edit commit_id to so that it's not the latest commit
commit_id = tmpdirname.name
new_commit_id = commit_id + "hug"
ref_dir = tmpdirname.parent.parent / "refs/main"
with open(ref_dir, "w") as f:
f.write(new_commit_id)
new_tmpdirname = tmpdirname.parent / new_commit_id
os.rename(tmpdirname, new_tmpdirname)
try:
AutoPipelineForText2Image.from_pretrained(repo, local_files_only=True)
except OSError:
assert False, "not able to load local files"
shutil.rmtree(tmpdirname.parent.parent)
@slow @slow
class AutoPipelineIntegrationTest(unittest.TestCase): class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment