Commit f4a82243, authored by MultiK, committed by Francisco Massa
Browse files

Fix a small bug when resuming training from a checkpoint (#1628)

* fix a little bug about resume

When resuming, training must continue from the epoch after the one stored in the checkpoint (`checkpoint['epoch'] + 1`), not restart from epoch 0.

* the second way for resuming

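A second way to resume: a `--start_epoch` argument (default 0) is added, and the training loop now runs over `range(args.start_epoch, args.epochs)`.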
parent 10f34160
...@@ -114,6 +114,7 @@ def main(args): ...@@ -114,6 +114,7 @@ def main(args):
model_without_ddp.load_state_dict(checkpoint['model']) model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only: if args.test_only:
evaluate(model, data_loader_test, device=device) evaluate(model, data_loader_test, device=device)
...@@ -121,7 +122,7 @@ def main(args): ...@@ -121,7 +122,7 @@ def main(args):
print("Start training") print("Start training")
start_time = time.time() start_time = time.time()
for epoch in range(args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
...@@ -131,7 +132,8 @@ def main(args): ...@@ -131,7 +132,8 @@ def main(args):
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(),
'args': args}, 'args': args,
'epoch': epoch},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
# evaluate after every epoch # evaluate after every epoch
...@@ -171,6 +173,7 @@ if __name__ == "__main__": ...@@ -171,6 +173,7 @@ if __name__ == "__main__":
parser.add_argument('--print-freq', default=20, type=int, help='print frequency') parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
parser.add_argument( parser.add_argument(
"--test-only", "--test-only",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment