Commit 7f660099 authored by Gao, Xiang's avatar Gao, Xiang
Browse files

log training rmse and best validation rmse

parent 3e296069
...@@ -22,7 +22,10 @@ parser.add_argument('--model_checkpoint', ...@@ -22,7 +22,10 @@ parser.add_argument('--model_checkpoint',
default='model.pt') default='model.pt')
parser.add_argument('-m', '--max_epochs', parser.add_argument('-m', '--max_epochs',
help='Maximum number of epoches', help='Maximum number of epoches',
default=10, type=int) default=100, type=int)
parser.add_argument('--training_rmse_every',
help='Compute training RMSE every epoches',
default=20, type=int)
parser.add_argument('-d', '--device', parser.add_argument('-d', '--device',
help='Device of modules and tensors', help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu')) default=('cuda' if torch.cuda.is_available() else 'cpu'))
...@@ -94,12 +97,25 @@ def finalize_tqdm(trainer): ...@@ -94,12 +97,25 @@ def finalize_tqdm(trainer):
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer): def validation_and_checkpoint(trainer):
# compute validation RMSE
evaluator.run(validation) evaluator.run(validation)
metrics = evaluator.state.metrics metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE']) rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch) writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
# compute training RMSE
if trainer.state.epoch % parser.training_rmse_every == 0:
evaluator.run(training)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar('training_rmse_vs_epoch', rmse,
trainer.state.epoch)
# handle best validation RMSE
if rmse < trainer.state.best_validation_rmse: if rmse < trainer.state.best_validation_rmse:
trainer.state.best_validation_rmse = rmse trainer.state.best_validation_rmse = rmse
writer.add_scalar('best_validation_rmse_vs_epoch', rmse,
trainer.state.epoch)
torch.save(nnp.state_dict(), parser.model_checkpoint) torch.save(nnp.state_dict(), parser.model_checkpoint)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment