Unverified commit d1d0b8af authored by Amiha, committed by GitHub

Don't use bare prints in a library (#3991)

parent 04ddad48
@@ -391,8 +391,8 @@ def convert_ldm_unet_checkpoint(
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
+        logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        logger.warning(
             "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
             " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
         )
@@ -402,7 +402,7 @@ def convert_ldm_unet_checkpoint(
                 unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
     else:
         if sum(k.startswith("model_ema") for k in keys) > 100:
-            print(
+            logger.warning(
                 "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                 " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
             )
@@ -1183,7 +1183,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        print("global_step key not found in model")
+        logger.warning("global_step key not found in model")
         global_step = None
     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
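For context, below is a minimal sketch of the pattern this diff moves toward: a module-level logger from diffusers' logging utilities instead of bare print() calls. It assumes the standard `diffusers.utils.logging` helper; the `warn_about_ema_weights` function and its arguments are illustrative only, not code from the file being changed.

# Sketch only: module-level logger in place of bare print().
from diffusers.utils import logging

logger = logging.get_logger(__name__)  # per-module logger that honors diffusers' verbosity settings


def warn_about_ema_weights(keys, path, extract_ema):
    # Illustrative helper mirroring the warnings in the diff above.
    has_ema = sum(k.startswith("model_ema") for k in keys) > 100
    if has_ema and extract_ema:
        logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
        logger.warning(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
    elif has_ema:
        logger.warning(
            "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
            " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
        )

Unlike print(), these warnings can be silenced or redirected by library users, for example via diffusers' verbosity controls such as `logging.set_verbosity_error()`.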