Unverified commit a0422ed0, authored by Patrick von Platen, committed by GitHub

[From Single File] Allow vae to be loaded (#4242)

* Allow vae to be loaded

* up
parent 3dd33937
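
In practice, the new keyword lets a caller hand an already-instantiated `AutoencoderKL` to `from_single_file` instead of having the VAE converted out of the single-file checkpoint. A minimal usage sketch (the checkpoint path and the `stabilityai/sd-vae-ft-mse` repo below are illustrative assumptions, not part of this diff):

```python
import torch

from diffusers import AutoencoderKL, StableDiffusionPipeline

# Illustrative VAE and checkpoint; substitute whatever you actually use.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

pipe = StableDiffusionPipeline.from_single_file(
    "./path/to/model.safetensors",
    vae=vae,  # new in this PR: reuse the supplied VAE instead of converting one from the checkpoint
    torch_dtype=torch.float16,
)
```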
@@ -1410,6 +1410,9 @@ class FromSingleFileMixin:
                 An instance of `CLIPTextModel` to use, specifically the
                 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
                 parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
+            vae (`AutoencoderKL`, *optional*, defaults to `None`):
+                Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+                this parameter is `None`, the function will load a new instance of `AutoencoderKL` by itself, if needed.
             tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
                 An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
                 of `CLIPTokenizer` by itself if needed.
@@ -1458,6 +1461,7 @@ class FromSingleFileMixin:
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
         text_encoder = kwargs.pop("text_encoder", None)
+        vae = kwargs.pop("vae", None)
         controlnet = kwargs.pop("controlnet", None)
         tokenizer = kwargs.pop("tokenizer", None)
@@ -1548,6 +1552,7 @@ class FromSingleFileMixin:
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
             text_encoder=text_encoder,
+            vae=vae,
             tokenizer=tokenizer,
         )
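
The mixin itself only pops the new keyword and forwards it; the actual handling lives in `download_from_original_stable_diffusion_ckpt`, whose changes follow below. If you call that converter directly, the argument works the same way. A sketch, assuming the import path as of this diffusers version; the checkpoint path is a placeholder:

```python
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# The checkpoint path is a placeholder for any original-format Stable Diffusion checkpoint.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.ckpt",
    vae=vae,  # skips the built-in VAE conversion and uses this instance instead
)
```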
@@ -1107,6 +1107,7 @@ def download_from_original_stable_diffusion_ckpt(
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
     vae_path=None,
+    vae=None,
     text_encoder=None,
     tokenizer=None,
 ) -> DiffusionPipeline:
@@ -1156,6 +1157,9 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        vae (`AutoencoderKL`, *optional*, defaults to `None`):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+            this parameter is `None`, the function will load a new instance of `AutoencoderKL` by itself, if needed.
         text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
             An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
             to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
@@ -1361,7 +1365,7 @@ def download_from_original_stable_diffusion_ckpt(
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    if vae_path is None:
+    if vae_path is None and vae is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -1385,7 +1389,7 @@ def download_from_original_stable_diffusion_ckpt(
                 set_module_tensor_to_device(vae, param_name, "cpu", value=param)
         else:
             vae.load_state_dict(converted_vae_checkpoint)
-    else:
+    elif vae is None:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
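
Taken together, the two edited branches give the VAE source a clear precedence: an explicitly passed `vae` wins, then `vae_path`, and only when both are `None` are the VAE weights converted out of the checkpoint. A simplified standalone sketch of that selection order (the real function also handles `accelerate` meta-device loading and the scaling factor, omitted here):

```python
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_vae_checkpoint,
    create_vae_diffusers_config,
)


def resolve_vae(checkpoint, original_config, vae=None, vae_path=None, image_size=512):
    """Simplified illustration of the branching above, not the literal library code."""
    if vae is not None:
        # New with this PR: a user-supplied AutoencoderKL is used as-is.
        return vae
    if vae_path is not None:
        # Pre-existing escape hatch: load an already-converted VAE from a diffusers repo or folder.
        return AutoencoderKL.from_pretrained(vae_path)
    # Default: build a diffusers VAE config and convert the weights embedded in the checkpoint.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_state_dict = convert_ldm_vae_checkpoint(checkpoint, vae_config)
    new_vae = AutoencoderKL(**vae_config)
    new_vae.load_state_dict(converted_state_dict)
    return new_vae
```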