Unverified commit 3e296069, authored by Gao, Xiang; committed by GitHub
Browse files

checkpoint if validation gets better result (#55)

parent f50cc0b4
...@@ -48,7 +48,7 @@ device = torch.device(parser.device) ...@@ -48,7 +48,7 @@ device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log) writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer() start = timeit.default_timer()
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt', nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
True, device=device) True, device=device)
training, validation, testing = torchani.data.load_or_create( training, validation, testing = torchani.data.load_or_create(
parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size, parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
...@@ -72,6 +72,11 @@ def hartree2kcal(x): ...@@ -72,6 +72,11 @@ def hartree2kcal(x):
return 627.509 * x return 627.509 * x
@trainer.on(ignite.engine.Events.STARTED)
def initialize(trainer):
trainer.state.best_validation_rmse = math.inf
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer): def init_tqdm(trainer):
trainer.state.tqdm = tqdm.tqdm(total=len(training), desc='epoch') trainer.state.tqdm = tqdm.tqdm(total=len(training), desc='epoch')
...@@ -88,11 +93,14 @@ def finalize_tqdm(trainer): ...@@ -88,11 +93,14 @@ def finalize_tqdm(trainer):
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_validation_results(trainer): def validation_and_checkpoint(trainer):
evaluator.run(validation) evaluator.run(validation)
metrics = evaluator.state.metrics metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE']) rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch) writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
if rmse < trainer.state.best_validation_rmse:
trainer.state.best_validation_rmse = rmse
torch.save(nnp.state_dict(), parser.model_checkpoint)
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment