Commit 9329257b (unverified), authored by D. Khuê Lê-Huu, committed by GitHub
Browse files

Fix training resuming in references/segmentation (#2142)

parent 1affa2e8
...@@ -128,10 +128,6 @@ def main(args): ...@@ -128,10 +128,6 @@ def main(args):
if args.distributed: if args.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model_without_ddp = model model_without_ddp = model
if args.distributed: if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
...@@ -157,8 +153,15 @@ def main(args): ...@@ -157,8 +153,15 @@ def main(args):
optimizer, optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
start_time = time.time() start_time = time.time()
for epoch in range(args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
...@@ -168,6 +171,7 @@ def main(args): ...@@ -168,6 +171,7 @@ def main(args):
{ {
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch, 'epoch': epoch,
'args': args 'args': args
}, },
...@@ -201,6 +205,8 @@ def parse_args(): ...@@ -201,6 +205,8 @@ def parse_args():
parser.add_argument('--print-freq', default=10, type=int, help='print frequency') parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument( parser.add_argument(
"--test-only", "--test-only",
dest="test_only", dest="test_only",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment