Commit a2e7dabb authored by Zhang690683220's avatar Zhang690683220
Browse files

fix lr resume for non-deepspeed ckpts

parent 69a8a287
......@@ -260,7 +260,11 @@ def main(args):
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment