Unverified Commit 5df2acf7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Conversion] Small fixes (#3848)

* [Conversion] Small fixes

* Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
parent 88d26946
...@@ -129,11 +129,19 @@ def vae_pt_to_vae_diffuser( ...@@ -129,11 +129,19 @@ def vae_pt_to_vae_diffuser(
original_config = OmegaConf.load(io_obj) original_config = OmegaConf.load(io_obj)
image_size = 512 image_size = 512
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device) if checkpoint_path.endswith("safetensors"):
from safetensors import safe_open
checkpoint = {}
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
# Convert the VAE model. # Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size) vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config) converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
......
...@@ -286,10 +286,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -286,10 +286,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"use_linear_projection": use_linear_projection, "use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type, "class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"conditioning_channels": unet_params.hint_channels,
} }
if not controlnet: if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types) config["up_block_types"] = tuple(up_block_types)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment