Unverified Commit 651c5adf authored by 蓝色的秋风's avatar 蓝色的秋风 Committed by GitHub
Browse files

[Conversion] Support convert diffusers to safetensors (#1996)

fix: support diffusers to safetensors
parent cc2cc00d
...@@ -8,6 +8,8 @@ import re ...@@ -8,6 +8,8 @@ import re
import torch import torch
from safetensors.torch import save_file
# =================# # =================#
# UNet Conversion # # UNet Conversion #
...@@ -266,6 +268,9 @@ if __name__ == "__main__": ...@@ -266,6 +268,9 @@ if __name__ == "__main__":
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.") parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
parser.add_argument(
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
)
args = parser.parse_args() args = parser.parse_args()
...@@ -306,5 +311,9 @@ if __name__ == "__main__": ...@@ -306,5 +311,9 @@ if __name__ == "__main__":
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
if args.half: if args.half:
state_dict = {k: v.half() for k, v in state_dict.items()} state_dict = {k: v.half() for k, v in state_dict.items()}
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path) if args.use_safetensors:
save_file(state_dict, args.checkpoint_path)
else:
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment