Unverified Commit a78d0d83 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add checkpoints used for preemption. (#3789)

parent c2ab0c59
......@@ -188,13 +188,19 @@ def main(args):
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
if args.output_dir:
utils.save_on_master({
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'args': args,
'epoch': epoch},
'epoch': epoch
}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
# evaluate after every epoch
evaluate(model, data_loader_test, device=device)
......
......@@ -157,15 +157,19 @@ def main(args):
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
utils.save_on_master(
{
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args
},
}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment