Unverified Commit 462956be authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

small tweaks for parsing thibaudz controlnet checkpoints (#3657)

parent 59900147
...@@ -75,6 +75,22 @@ if __name__ == "__main__": ...@@ -75,6 +75,22 @@ if __name__ == "__main__":
) )
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
# small workaround to get argparser to parse a boolean input as either true _or_ false
def parse_bool(string):
if string == "True":
return True
elif string == "False":
return False
else:
raise ValueError(f"could not parse string as bool {string}")
parser.add_argument(
"--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool
)
parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int)
args = parser.parse_args() args = parser.parse_args()
controlnet = download_controlnet_from_original_ckpt( controlnet = download_controlnet_from_original_ckpt(
...@@ -86,6 +102,8 @@ if __name__ == "__main__": ...@@ -86,6 +102,8 @@ if __name__ == "__main__":
upcast_attention=args.upcast_attention, upcast_attention=args.upcast_attention,
from_safetensors=args.from_safetensors, from_safetensors=args.from_safetensors,
device=args.device, device=args.device,
use_linear_projection=args.use_linear_projection,
cross_attention_dim=args.cross_attention_dim,
) )
controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
...@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config): ...@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
return config return config
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
""" """
Takes a state dict and a config, and returns a converted checkpoint. Takes a state dict and a config, and returns a converted checkpoint.
""" """
# extract state_dict for UNet if skip_extract_state_dict:
unet_state_dict = {} unet_state_dict = checkpoint
keys = list(checkpoint.keys())
if controlnet:
unet_key = "control_model."
else: else:
unet_key = "model.diffusion_model." # extract state_dict for UNet
unet_state_dict = {}
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA keys = list(checkpoint.keys())
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
print(f"Checkpoint {path} has both EMA and non-EMA weights.") if controlnet:
print( unet_key = "control_model."
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" else:
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." unet_key = "model.diffusion_model."
)
for key in keys: # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if key.startswith("model.diffusion_model"): if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) print(f"Checkpoint {path} has both EMA and non-EMA weights.")
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
print( print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag." " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
) )
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys: for key in keys:
if key.startswith(unet_key): if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {} new_checkpoint = {}
...@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components( ...@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
def convert_controlnet_checkpoint( def convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema checkpoint,
original_config,
checkpoint_path,
image_size,
upcast_attention,
extract_ema,
use_linear_projection=None,
cross_attention_dim=None,
): ):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size") ctrlnet_config.pop("sample_size")
if use_linear_projection is not None:
ctrlnet_config["use_linear_projection"] = use_linear_projection
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
controlnet_model = ControlNetModel(**ctrlnet_config) controlnet_model = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
if "time_embed.0.weight" in checkpoint:
skip_extract_state_dict = True
else:
skip_extract_state_dict = False
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True checkpoint,
ctrlnet_config,
path=checkpoint_path,
extract_ema=extract_ema,
controlnet=True,
skip_extract_state_dict=skip_extract_state_dict,
) )
controlnet_model.load_state_dict(converted_ctrl_checkpoint) controlnet_model.load_state_dict(converted_ctrl_checkpoint)
...@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt( ...@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
device: str = None, device: str = None,
from_safetensors: bool = False, from_safetensors: bool = False,
use_linear_projection: Optional[bool] = None,
cross_attention_dim: Optional[bool] = None,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
if not is_omegaconf_available(): if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
...@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt( ...@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
raise ValueError("`control_stage_config` not present in original config") raise ValueError("`control_stage_config` not present in original config")
controlnet_model = convert_controlnet_checkpoint( controlnet_model = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema checkpoint,
original_config,
checkpoint_path,
image_size,
upcast_attention,
extract_ema,
use_linear_projection=use_linear_projection,
cross_attention_dim=cross_attention_dim,
) )
return controlnet_model return controlnet_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment