"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "88833e6f16c86e3ab77399ad7e7b650bad3c460c"
Unverified Commit 8f2d13c6 authored by Dhruv Nair, committed by GitHub

Fix setting fp16 dtype in AnimateDiff convert script. (#7127)

* update

* update
parent fcfa270f
@@ -30,6 +30,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True)
     parser.add_argument("--use_motion_mid_block", action="store_true")
     parser.add_argument("--motion_max_seq_length", type=int, default=32)
+    parser.add_argument("--save_fp16", action="store_true")
     return parser.parse_args()
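The new --save_fp16 switch is an argparse store_true flag, so it defaults to False and only becomes True when passed on the command line. A minimal standalone sketch of that behavior (not part of the converter script itself):

import argparse

# Reproduce just the new flag to show its default and toggled values.
parser = argparse.ArgumentParser()
parser.add_argument("--save_fp16", action="store_true")

print(parser.parse_args([]).save_fp16)               # False (flag omitted)
print(parser.parse_args(["--save_fp16"]).save_fp16)  # True  (flag passed)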
@@ -48,4 +49,6 @@ if __name__ == "__main__":
     # skip loading position embeddings
     adapter.load_state_dict(conv_state_dict, strict=False)
     adapter.save_pretrained(args.output_path)
-    adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
+
+    if args.save_fp16:
+        adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
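The removed line relied on save_pretrained to apply the fp16 dtype, which did not take effect (hence this fix); the new code casts the module with .to(torch.float16) before saving and gates the extra save behind --save_fp16. A minimal sketch of loading the resulting fp16 variant back with diffusers ("path/to/motion_adapter" is a placeholder standing in for args.output_path, not part of this commit):

import torch
from diffusers import MotionAdapter

# Load the half-precision weights written when --save_fp16 was used.
adapter = MotionAdapter.from_pretrained(
    "path/to/motion_adapter",
    variant="fp16",
    torch_dtype=torch.float16,
)
print(next(adapter.parameters()).dtype)  # torch.float16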