Unverified Commit eb96ff0d authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Safetensor loading in AnimateDiff conversion scripts (#7764)

* update

* update
parent a38dd795
import argparse import argparse
import torch import torch
from safetensors.torch import save_file from safetensors.torch import load_file, save_file
def convert_motion_module(original_state_dict): def convert_motion_module(original_state_dict):
...@@ -34,6 +34,9 @@ def get_args(): ...@@ -34,6 +34,9 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
if args.ckpt_path.endswith(".safetensors"):
state_dict = load_file(args.ckpt_path)
else:
state_dict = torch.load(args.ckpt_path, map_location="cpu") state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys(): if "state_dict" in state_dict.keys():
......
import argparse import argparse
import torch import torch
from safetensors.torch import load_file
from diffusers import MotionAdapter from diffusers import MotionAdapter
...@@ -38,7 +39,11 @@ def get_args(): ...@@ -38,7 +39,11 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
if args.ckpt_path.endswith(".safetensors"):
state_dict = load_file(args.ckpt_path)
else:
state_dict = torch.load(args.ckpt_path, map_location="cpu") state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys(): if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"] state_dict = state_dict["state_dict"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment