Unverified commit cb432c4e authored by cmdr2, committed by GitHub

Allow passing a checkpoint state_dict to convert_from_ckpt (instead of just a string path) (#4653)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b7b1a30b
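
A minimal sketch of the call pattern this change enables: a caller that already holds the weights in memory can pass the state dict directly instead of writing it to disk first. The file names below are hypothetical, and the import path assumes the `diffusers` module this diff touches.

```python
import torch

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Load (and optionally patch) the checkpoint yourself before converting.
state_dict = torch.load("sd-v1-5.ckpt", map_location="cpu")  # hypothetical file

# Previously only a string path was accepted; a loaded dict now works too.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict=state_dict,
    original_config_file="v1-inference.yaml",  # hypothetical; inferred when omitted
)
pipe.save_pretrained("converted-sd15")
```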
@@ -17,7 +17,7 @@
 import re
 from contextlib import nullcontext
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Union, Dict
 import requests
 import torch
@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
-    checkpoint_path: str,
+    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
     original_config_file: str = None,
     image_size: Optional[int] = None,
     prediction_type: str = None,
@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
     Args:
-        checkpoint_path (`str`): Path to `.ckpt` file.
+        checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
         original_config_file (`str`):
             Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
             inferred by looking for a key that only exists in SD2.0 models.
@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
     from omegaconf import OmegaConf
 
+    if isinstance(checkpoint_path_or_dict, str):
         if from_safetensors:
             from safetensors.torch import load_file as safe_load
 
-            checkpoint = safe_load(checkpoint_path, device="cpu")
+            checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
         else:
             if device is None:
                 device = "cuda" if torch.cuda.is_available() else "cpu"
-                checkpoint = torch.load(checkpoint_path, map_location=device)
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
             else:
-                checkpoint = torch.load(checkpoint_path, map_location=device)
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+    elif isinstance(checkpoint_path_or_dict, dict):
+        checkpoint = checkpoint_path_or_dict
 
     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
         image_size = 512
 
     if controlnet is None and "control_stage_config" in original_config.model.params:
+        path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
         controlnet = convert_controlnet_checkpoint(
-            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+            checkpoint, original_config, path, image_size, upcast_attention, extract_ema
         )
 
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
+    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+        checkpoint, unet_config, path=path, extract_ema=extract_ema
     )
 
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
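
A note on the `path = ... else ""` guard repeated in the two hunks above: `convert_controlnet_checkpoint` and `convert_ldm_unet_checkpoint` still expect a path argument, and when the input is a dict there is no filename to forward, so an empty string stands in. Judging from the surrounding source, the helpers use the path only in informational messages (e.g., when reporting EMA versus non-EMA weights), so the placeholder should not affect the conversion itself.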