Unverified Commit 8a3f0c1f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Conversion] Improve safetensors (#1989)

parent f6a5c359
......@@ -20,6 +20,8 @@ import re
import torch
from safetensors import safe_open
try:
from omegaconf import OmegaConf
......@@ -839,6 +841,11 @@ if __name__ == "__main__":
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument(
"--from_safetensors",
action="store_true",
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
)
parser.add_argument(
"--upcast_attention",
default=False,
......@@ -855,11 +862,17 @@ if __name__ == "__main__":
image_size = args.image_size
prediction_type = args.prediction_type
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(args.checkpoint_path, map_location=device)
if args.from_safetensors:
checkpoint = {}
with safe_open(args.checkpoint_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(args.checkpoint_path, map_location=device)
else:
checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment