Unverified Commit a78d0d83 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add checkpoints used for preemption. (#3789)

parent c2ab0c59
...@@ -188,13 +188,19 @@ def main(args): ...@@ -188,13 +188,19 @@ def main(args):
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step() lr_scheduler.step()
if args.output_dir: if args.output_dir:
utils.save_on_master({ checkpoint = {
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(),
'args': args, 'args': args,
'epoch': epoch}, 'epoch': epoch
}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
# evaluate after every epoch # evaluate after every epoch
evaluate(model, data_loader_test, device=device) evaluate(model, data_loader_test, device=device)
......
...@@ -157,15 +157,19 @@ def main(args): ...@@ -157,15 +157,19 @@ def main(args):
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat) print(confmat)
utils.save_on_master( checkpoint = {
{
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch, 'epoch': epoch,
'args': args 'args': args
}, }
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment