Unverified Commit 8a3f0c1f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Conversion] Improve safetensors (#1989)

parent f6a5c359
...@@ -20,6 +20,8 @@ import re ...@@ -20,6 +20,8 @@ import re
import torch import torch
from safetensors import safe_open
try: try:
from omegaconf import OmegaConf from omegaconf import OmegaConf
...@@ -839,6 +841,11 @@ if __name__ == "__main__": ...@@ -839,6 +841,11 @@ if __name__ == "__main__":
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
), ),
) )
parser.add_argument(
"--from_safetensors",
action="store_true",
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
)
parser.add_argument( parser.add_argument(
"--upcast_attention", "--upcast_attention",
default=False, default=False,
...@@ -855,6 +862,12 @@ if __name__ == "__main__": ...@@ -855,6 +862,12 @@ if __name__ == "__main__":
image_size = args.image_size image_size = args.image_size
prediction_type = args.prediction_type prediction_type = args.prediction_type
if args.from_safetensors:
checkpoint = {}
with safe_open(args.checkpoint_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
if args.device is None: if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(args.checkpoint_path, map_location=device) checkpoint = torch.load(args.checkpoint_path, map_location=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment