Unverified Commit 74d902eb authored by zuojianghua's avatar zuojianghua Committed by GitHub
Browse files

add config_file to from_single_file (#4614)



* Update loaders.py

add config_file to from_single_file, 
when the download_from_original_stable_diffusion_ckpt use

* Update loaders.py

add config_file to from_single_file,
when the download_from_original_stable_diffusion_ckpt use

* change config_file to original_config_file

* make style && make quality

---------
Co-authored-by: default avatarjianghua.zuo <jianghua.zuo@weimob.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent d7c4ae61
...@@ -1790,6 +1790,9 @@ class FromSingleFileMixin: ...@@ -1790,6 +1790,9 @@ class FromSingleFileMixin:
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`): tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
of `CLIPTokenizer` by itself if needed. of `CLIPTokenizer` by itself if needed.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
automatically inferred by looking for a key that only exists in SD2.0 models.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
...@@ -1820,6 +1823,7 @@ class FromSingleFileMixin: ...@@ -1820,6 +1823,7 @@ class FromSingleFileMixin:
# import here to avoid circular dependency # import here to avoid circular dependency
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -1936,6 +1940,7 @@ class FromSingleFileMixin: ...@@ -1936,6 +1940,7 @@ class FromSingleFileMixin:
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
tokenizer=tokenizer, tokenizer=tokenizer,
original_config_file=original_config_file,
) )
if torch_dtype is not None: if torch_dtype is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment