Unverified Commit 3e296069 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

checkpoint if validation gets better result (#55)

parent f50cc0b4
......@@ -48,7 +48,7 @@ device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
True, device=device)
training, validation, testing = torchani.data.load_or_create(
parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
......@@ -72,6 +72,11 @@ def hartree2kcal(x):
return 627.509 * x
@trainer.on(ignite.engine.Events.STARTED)
def initialize(trainer):
trainer.state.best_validation_rmse = math.inf
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
trainer.state.tqdm = tqdm.tqdm(total=len(training), desc='epoch')
......@@ -88,11 +93,14 @@ def finalize_tqdm(trainer):
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_validation_results(trainer):
def validation_and_checkpoint(trainer):
evaluator.run(validation)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
if rmse < trainer.state.best_validation_rmse:
trainer.state.best_validation_rmse = rmse
torch.save(nnp.state_dict(), parser.model_checkpoint)
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment