Unverified Commit 245614f7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Fix NeuroChem trainer pretraining criterion, MAE vs MaxAE (#225)

parent 4dcd6ab0
...@@ -4,7 +4,7 @@ from __future__ import absolute_import ...@@ -4,7 +4,7 @@ from __future__ import absolute_import
import torch import torch
from . import utils from . import utils
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ignite.metrics import Metric, RootMeanSquaredError from ignite.metrics import Metric, RootMeanSquaredError, MeanAbsoluteError
from ignite.contrib.metrics.regression import MaximumAbsoluteError from ignite.contrib.metrics.regression import MaximumAbsoluteError
...@@ -116,5 +116,10 @@ def MaxAEMetric(key): ...@@ -116,5 +116,10 @@ def MaxAEMetric(key):
return DictMetric(key, MaximumAbsoluteError()) return DictMetric(key, MaximumAbsoluteError())
def MAEMetric(key):
    """Create a mean absolute error (MAE) metric on key.

    Fix: the docstring previously said "max absolute error", but the metric
    constructed here is ignite's MeanAbsoluteError (MaxAEMetric above is the
    max-absolute-error counterpart).
    """
    return DictMetric(key, MeanAbsoluteError())
__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric', __all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
'MaxAEMetric'] 'MaxAEMetric']
...@@ -399,7 +399,7 @@ if sys.version_info[0] > 2: ...@@ -399,7 +399,7 @@ if sys.version_info[0] > 2:
checkpoint_name='model.pt'): checkpoint_name='model.pt'):
try: try:
import ignite import ignite
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MaxAEMetric from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
from ..data import BatchedANIDataset # noqa: E402 from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402 from ..data import AEVCacheLoader # noqa: E402
except ImportError: except ImportError:
...@@ -418,6 +418,7 @@ if sys.version_info[0] > 2: ...@@ -418,6 +418,7 @@ if sys.version_info[0] > 2:
self.imports.TransformedLoss = TransformedLoss self.imports.TransformedLoss = TransformedLoss
self.imports.RMSEMetric = RMSEMetric self.imports.RMSEMetric = RMSEMetric
self.imports.MaxAEMetric = MaxAEMetric self.imports.MaxAEMetric = MaxAEMetric
self.imports.MAEMetric = MAEMetric
self.imports.BatchedANIDataset = BatchedANIDataset self.imports.BatchedANIDataset = BatchedANIDataset
self.imports.AEVCacheLoader = AEVCacheLoader self.imports.AEVCacheLoader = AEVCacheLoader
...@@ -681,12 +682,13 @@ if sys.version_info[0] > 2: ...@@ -681,12 +682,13 @@ if sys.version_info[0] > 2:
self.container, self.container,
metrics={ metrics={
'RMSE': self.imports.RMSEMetric('energies'), 'RMSE': self.imports.RMSEMetric('energies'),
'MAE': self.imports.MAEMetric('energies'),
'MaxAE': self.imports.MaxAEMetric('energies'), 'MaxAE': self.imports.MaxAEMetric('energies'),
} }
) )
evaluator.run(dataset) evaluator.run(dataset)
metrics = evaluator.state.metrics metrics = evaluator.state.metrics
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MaxAE']) return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']), hartree2kcal(metrics['MaxAE'])
def load_data(self, training_path, validation_path): def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file. """Load training and validation dataset from file.
...@@ -746,7 +748,7 @@ if sys.version_info[0] > 2: ...@@ -746,7 +748,7 @@ if sys.version_info[0] > 2:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer): def validation_and_checkpoint(trainer):
trainer.state.rmse, trainer.state.mae = \ trainer.state.rmse, trainer.state.mae, trainer.state.maxae = \
self.evaluate(self.validation_set) self.evaluate(self.validation_set)
if trainer.state.rmse < self.best_validation_rmse: if trainer.state.rmse < self.best_validation_rmse:
trainer.state.no_improve_count = 0 trainer.state.no_improve_count = 0
...@@ -772,6 +774,8 @@ if sys.version_info[0] > 2: ...@@ -772,6 +774,8 @@ if sys.version_info[0] > 2:
trainer.state.rmse, epoch) trainer.state.rmse, epoch)
self.tensorboard.add_scalar('validation_mae_vs_epoch', self.tensorboard.add_scalar('validation_mae_vs_epoch',
trainer.state.mae, epoch) trainer.state.mae, epoch)
self.tensorboard.add_scalar('validation_maxae_vs_epoch',
trainer.state.maxae, epoch)
self.tensorboard.add_scalar( self.tensorboard.add_scalar(
'best_validation_rmse_vs_epoch', 'best_validation_rmse_vs_epoch',
self.best_validation_rmse, epoch) self.best_validation_rmse, epoch)
...@@ -779,14 +783,16 @@ if sys.version_info[0] > 2: ...@@ -779,14 +783,16 @@ if sys.version_info[0] > 2:
'no_improve_count_vs_epoch', 'no_improve_count_vs_epoch',
trainer.state.no_improve_count, epoch) trainer.state.no_improve_count, epoch)
# compute training RMSE and MAE # compute training RMSE, MAE and MaxAE
if epoch % self.training_eval_every == 1: if epoch % self.training_eval_every == 1:
training_rmse, training_mae = \ training_rmse, training_mae, training_maxae = \
self.evaluate(self.training_set) self.evaluate(self.training_set)
self.tensorboard.add_scalar( self.tensorboard.add_scalar(
'training_rmse_vs_epoch', training_rmse, epoch) 'training_rmse_vs_epoch', training_rmse, epoch)
self.tensorboard.add_scalar( self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_mae, epoch) 'training_mae_vs_epoch', training_mae, epoch)
self.tensorboard.add_scalar(
'training_maxae_vs_epoch', training_maxae, epoch)
@trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED) @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer): def log_loss(trainer):
...@@ -806,13 +812,13 @@ if sys.version_info[0] > 2: ...@@ -806,13 +812,13 @@ if sys.version_info[0] > 2:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def terminate_if_smaller_enough(trainer): def terminate_if_smaller_enough(trainer):
if trainer.state.mae < 1.0: if trainer.state.rmse < 10.0:
trainer.terminate() trainer.terminate()
trainer.run(self.training_set, max_epochs=math.inf) trainer.run(self.training_set, max_epochs=math.inf)
while lr > self.min_lr: while lr > self.min_lr:
optimizer = AdamW(self.model.parameters(), lr=lr) optimizer = AdamW(self.parameters, lr=lr)
trainer = self.ignite.engine.create_supervised_trainer( trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.exp_loss) self.container, optimizer, self.exp_loss)
decorate(trainer) decorate(trainer)
......
...@@ -105,9 +105,9 @@ class AdamW(Optimizer): ...@@ -105,9 +105,9 @@ class AdamW(Optimizer):
bias_correction2 = 1 - beta2 ** state['step'] bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data) p.data.add_(-group['weight_decay'], p.data)
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment