Unverified commit 9329257b authored by D. Khuê Lê-Huu, committed by GitHub

Fix training resuming in references/segmentation (#2142)

parent 1affa2e8
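
Before this change, the resume branch loaded only the model weights, and did so before the optimizer and lr_scheduler were even constructed, so a resumed run restarted at epoch 0 with fresh optimizer and scheduler state. The diff below moves the restore to after both are built and reloads the model, the optimizer, the scheduler, and the starting epoch from the checkpoint; it also adds the scheduler state to what gets saved. The following is a minimal, self-contained sketch of that save/resume pattern, not the script itself: the toy model, optimizer settings, epoch count, and checkpoint path are illustrative placeholders.

import os
import torch

# Toy setup; the real script builds a segmentation model, an SGD optimizer,
# and a LambdaLR polynomial-decay scheduler.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: 0.9 ** x)

start_epoch = 0
resume_path = 'checkpoint.pth'  # hypothetical path

# Resume: restore *all* training state, not just the model weights, and only
# after the optimizer and scheduler have been constructed.
if os.path.exists(resume_path):
    checkpoint = torch.load(resume_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    start_epoch = checkpoint['epoch'] + 1  # continue after the saved epoch

for epoch in range(start_epoch, 3):
    # ... one epoch of training would go here ...
    lr_scheduler.step()
    # Save everything needed to resume, including the scheduler state.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }, resume_path)
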
@@ -128,10 +128,6 @@ def main(args):
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    if args.resume:
-        checkpoint = torch.load(args.resume, map_location='cpu')
-        model.load_state_dict(checkpoint['model'])
-
     model_without_ddp = model
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
@@ -157,8 +153,15 @@ def main(args):
         optimizer,
         lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
 
+    if args.resume:
+        checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        args.start_epoch = checkpoint['epoch'] + 1
+
     start_time = time.time()
-    for epoch in range(args.epochs):
+    for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
@@ -168,6 +171,7 @@ def main(args):
             {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args
             },
@@ -201,6 +205,8 @@ def parse_args():
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
     parser.add_argument(
         "--test-only",
         dest="test_only",
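
To pick up an interrupted run, one would pass a checkpoint previously written by the training loop back in via --resume; --start-epoch normally stays at its default of 0, since resuming overwrites it with the epoch saved in the checkpoint plus one. A hypothetical invocation, assuming a checkpoint file named checkpoint.pth (the filename and any other flags depend on the setup):

python train.py --resume checkpoint.pth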