Unverified commit cb432c4e, authored by cmdr2 and committed by GitHub
Browse files

Allow passing a checkpoint state_dict to convert_from_ckpt (instead of just a string path) (#4653)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b7b1a30b
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import re import re
from contextlib import nullcontext from contextlib import nullcontext
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional, Union, Dict
import requests import requests
import torch import torch
...@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint( ...@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
def download_from_original_stable_diffusion_ckpt( def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
original_config_file: str = None, original_config_file: str = None,
image_size: Optional[int] = None, image_size: Optional[int] = None,
prediction_type: str = None, prediction_type: str = None,
...@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
recommended that you override the default values and/or supply an `original_config_file` wherever possible. recommended that you override the default values and/or supply an `original_config_file` wherever possible.
Args: Args:
checkpoint_path (`str`): Path to `.ckpt` file. checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
original_config_file (`str`): original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
inferred by looking for a key that only exists in SD2.0 models. inferred by looking for a key that only exists in SD2.0 models.
...@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
from omegaconf import OmegaConf from omegaconf import OmegaConf
if isinstance(checkpoint_path_or_dict, str):
if from_safetensors: if from_safetensors:
from safetensors.torch import load_file as safe_load from safetensors.torch import load_file as safe_load
checkpoint = safe_load(checkpoint_path, device="cpu") checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
else: else:
if device is None: if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device) checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
else: else:
checkpoint = torch.load(checkpoint_path, map_location=device) checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
elif isinstance(checkpoint_path_or_dict, dict):
checkpoint = checkpoint_path_or_dict
# Sometimes models don't have the global_step item # Sometimes models don't have the global_step item
if "global_step" in checkpoint: if "global_step" in checkpoint:
...@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
image_size = 512 image_size = 512
if controlnet is None and "control_stage_config" in original_config.model.params: if controlnet is None and "control_stage_config" in original_config.model.params:
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
controlnet = convert_controlnet_checkpoint( controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema checkpoint, original_config, path, image_size, upcast_attention, extract_ema
) )
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
...@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention unet_config["upcast_attention"] = upcast_attention
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint( converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema checkpoint, unet_config, path=path, extract_ema=extract_ema
) )
ctx = init_empty_weights if is_accelerate_available() else nullcontext ctx = init_empty_weights if is_accelerate_available() else nullcontext
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment