Unverified commit 22ff44fd, authored by Hu Ye, committed by GitHub
Browse files

save grad_scaler if use amp for better resume (#4923)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 9b034e17
...@@ -325,6 +325,8 @@ def main(args): ...@@ -325,6 +325,8 @@ def main(args):
args.start_epoch = checkpoint["epoch"] + 1 args.start_epoch = checkpoint["epoch"] + 1
if model_ema: if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"]) model_ema.load_state_dict(checkpoint["model_ema"])
if scaler:
scaler.load_state_dict(checkpoint["scaler"])
if args.test_only: if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy # We disable the cudnn benchmarking because it can noticeably affect the accuracy
...@@ -356,6 +358,8 @@ def main(args): ...@@ -356,6 +358,8 @@ def main(args):
} }
if model_ema: if model_ema:
checkpoint["model_ema"] = model_ema.state_dict() checkpoint["model_ema"] = model_ema.state_dict()
if scaler:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment