Unverified Commit a0422ed0 authored by Patrick von Platen, committed by GitHub

[From Single File] Allow vae to be loaded (#4242)

* Allow vae to be loaded

* up
parent 3dd33937
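
In effect, a caller can now hand an already-instantiated `AutoencoderKL` to `from_single_file`, and the converter will reuse it instead of converting the VAE weights found in the checkpoint. A minimal usage sketch (not part of this commit; the checkpoint path and the VAE repo id below are placeholders):

    from diffusers import AutoencoderKL, StableDiffusionPipeline

    # Load a standalone VAE to reuse; "stabilityai/sd-vae-ft-mse" is an example repo id.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

    # Pass it through the new `vae` keyword so the converter skips its own VAE conversion.
    pipe = StableDiffusionPipeline.from_single_file(
        "path/to/checkpoint.safetensors",  # placeholder single-file checkpoint
        vae=vae,
    )
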
@@ -1410,6 +1410,9 @@ class FromSingleFileMixin:
                 An instance of `CLIPTextModel` to use, specifically the
                 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
                 parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
+            vae (`AutoencoderKL`, *optional*, defaults to `None`):
+                Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+                If this parameter is `None`, the function loads a new instance of `AutoencoderKL` by itself if needed.
             tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
                 An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
                 of `CLIPTokenizer` by itself if needed.
@@ -1458,6 +1461,7 @@ class FromSingleFileMixin:
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
         text_encoder = kwargs.pop("text_encoder", None)
+        vae = kwargs.pop("vae", None)
         controlnet = kwargs.pop("controlnet", None)
         tokenizer = kwargs.pop("tokenizer", None)
@@ -1548,6 +1552,7 @@ class FromSingleFileMixin:
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
             text_encoder=text_encoder,
+            vae=vae,
             tokenizer=tokenizer,
         )
@@ -1107,6 +1107,7 @@ def download_from_original_stable_diffusion_ckpt(
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
     vae_path=None,
+    vae=None,
     text_encoder=None,
     tokenizer=None,
 ) -> DiffusionPipeline:
@@ -1156,6 +1157,9 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        vae (`AutoencoderKL`, *optional*, defaults to `None`):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+            If this parameter is `None`, the function loads a new instance of `AutoencoderKL` by itself if needed.
         text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
             An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
             to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
@@ -1361,7 +1365,7 @@ def download_from_original_stable_diffusion_ckpt(
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    if vae_path is None:
+    if vae_path is None and vae is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -1385,7 +1389,7 @@ def download_from_original_stable_diffusion_ckpt(
                 set_module_tensor_to_device(vae, param_name, "cpu", value=param)
         else:
             vae.load_state_dict(converted_vae_checkpoint)
-    else:
+    elif vae is None:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
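
Taken together, the changes give the VAE three possible sources with a clear precedence. A simplified sketch of the resulting selection logic (illustrative only, not the literal library code; it assumes the module-level helpers `create_vae_diffusers_config` and `convert_ldm_vae_checkpoint` from the same `convert_from_ckpt.py` file are in scope):

    from diffusers import AutoencoderKL

    def resolve_vae(vae=None, vae_path=None, checkpoint=None, original_config=None, image_size=512):
        # 1) An explicitly passed `vae` wins and is reused as-is.
        if vae is not None:
            return vae
        # 2) Otherwise, `vae_path` is loaded as a pretrained AutoencoderKL.
        if vae_path is not None:
            return AutoencoderKL.from_pretrained(vae_path)
        # 3) Otherwise, the VAE weights embedded in the single-file checkpoint are converted.
        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_checkpoint)
        return vae
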