Unverified Commit 1ebda734 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Load variables when --resume /path/to/checkpoint --test-only (#3285)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 24f5fa57
......@@ -133,11 +133,6 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.test_only:
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return
params_to_optimize = [
{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
......@@ -155,10 +150,16 @@ def main(args):
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
if not args.test_only:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only:
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment