Unverified Commit 7606fc3b authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add support for early stopping (#59)

parent f93a9ba4
...@@ -44,6 +44,9 @@ parser.add_argument('--optimizer', ...@@ -44,6 +44,9 @@ parser.add_argument('--optimizer',
parser.add_argument('--optim_args', parser.add_argument('--optim_args',
help='Arguments to optimizers, in the format of json', help='Arguments to optimizers, in the format of json',
default='{}') default='{}')
parser.add_argument('--early_stopping',
help='Stop after epoches of no improvements',
default=math.inf, type=int)
parser = parser.parse_args() parser = parser.parse_args()
# set up the training # set up the training
...@@ -78,6 +81,7 @@ def hartree2kcal(x): ...@@ -78,6 +81,7 @@ def hartree2kcal(x):
@trainer.on(ignite.engine.Events.STARTED) @trainer.on(ignite.engine.Events.STARTED)
def initialize(trainer): def initialize(trainer):
trainer.state.best_validation_rmse = math.inf trainer.state.best_validation_rmse = math.inf
trainer.state.no_improve_count = 0
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
...@@ -113,10 +117,16 @@ def validation_and_checkpoint(trainer): ...@@ -113,10 +117,16 @@ def validation_and_checkpoint(trainer):
# handle best validation RMSE # handle best validation RMSE
if rmse < trainer.state.best_validation_rmse: if rmse < trainer.state.best_validation_rmse:
trainer.state.no_improve_count = 0
trainer.state.best_validation_rmse = rmse trainer.state.best_validation_rmse = rmse
writer.add_scalar('best_validation_rmse_vs_epoch', rmse, writer.add_scalar('best_validation_rmse_vs_epoch', rmse,
trainer.state.epoch) trainer.state.epoch)
torch.save(nnp.state_dict(), parser.model_checkpoint) torch.save(nnp.state_dict(), parser.model_checkpoint)
else:
trainer.state.no_improve_count += 1
if trainer.state.no_improve_count > parser.early_stopping:
trainer.terminate()
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment