Unverified Commit 245614f7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Fix NeuroChem trainer pretraining criterion, MAE vs MaxAE (#225)

parent 4dcd6ab0
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import torch
from . import utils
from torch.nn.modules.loss import _Loss
from ignite.metrics import Metric, RootMeanSquaredError
from ignite.metrics import Metric, RootMeanSquaredError, MeanAbsoluteError
from ignite.contrib.metrics.regression import MaximumAbsoluteError
......@@ -116,5 +116,10 @@ def MaxAEMetric(key):
return DictMetric(key, MaximumAbsoluteError())
def MAEMetric(key):
    """Create a mean absolute error (MAE) metric on `key`.

    Arguments:
        key (str): the key in the output dict to evaluate
            (e.g. ``'energies'``).

    Returns:
        DictMetric: wraps ignite's ``MeanAbsoluteError`` so it reads
        predictions/targets from the given dict key.
    """
    # NOTE: docstring previously said "max absolute error" — a copy-paste
    # error from MaxAEMetric; this function computes the *mean* absolute error.
    return DictMetric(key, MeanAbsoluteError())
__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
'MaxAEMetric']
......@@ -399,7 +399,7 @@ if sys.version_info[0] > 2:
checkpoint_name='model.pt'):
try:
import ignite
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MaxAEMetric
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
except ImportError:
......@@ -418,6 +418,7 @@ if sys.version_info[0] > 2:
self.imports.TransformedLoss = TransformedLoss
self.imports.RMSEMetric = RMSEMetric
self.imports.MaxAEMetric = MaxAEMetric
self.imports.MAEMetric = MAEMetric
self.imports.BatchedANIDataset = BatchedANIDataset
self.imports.AEVCacheLoader = AEVCacheLoader
......@@ -681,12 +682,13 @@ if sys.version_info[0] > 2:
self.container,
metrics={
'RMSE': self.imports.RMSEMetric('energies'),
'MAE': self.imports.MAEMetric('energies'),
'MaxAE': self.imports.MaxAEMetric('energies'),
}
)
evaluator.run(dataset)
metrics = evaluator.state.metrics
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MaxAE'])
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']), hartree2kcal(metrics['MaxAE'])
def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file.
......@@ -746,7 +748,7 @@ if sys.version_info[0] > 2:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
trainer.state.rmse, trainer.state.mae = \
trainer.state.rmse, trainer.state.mae, trainer.state.maxae = \
self.evaluate(self.validation_set)
if trainer.state.rmse < self.best_validation_rmse:
trainer.state.no_improve_count = 0
......@@ -772,6 +774,8 @@ if sys.version_info[0] > 2:
trainer.state.rmse, epoch)
self.tensorboard.add_scalar('validation_mae_vs_epoch',
trainer.state.mae, epoch)
self.tensorboard.add_scalar('validation_maxae_vs_epoch',
trainer.state.maxae, epoch)
self.tensorboard.add_scalar(
'best_validation_rmse_vs_epoch',
self.best_validation_rmse, epoch)
......@@ -779,14 +783,16 @@ if sys.version_info[0] > 2:
'no_improve_count_vs_epoch',
trainer.state.no_improve_count, epoch)
# compute training RMSE and MAE
# compute training RMSE, MAE and MaxAE
if epoch % self.training_eval_every == 1:
training_rmse, training_mae = \
training_rmse, training_mae, training_maxae = \
self.evaluate(self.training_set)
self.tensorboard.add_scalar(
'training_rmse_vs_epoch', training_rmse, epoch)
self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_mae, epoch)
self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_maxae, epoch)
@trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
......@@ -806,13 +812,13 @@ if sys.version_info[0] > 2:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def terminate_if_smaller_enough(trainer):
if trainer.state.mae < 1.0:
if trainer.state.rmse < 10.0:
trainer.terminate()
trainer.run(self.training_set, max_epochs=math.inf)
while lr > self.min_lr:
optimizer = AdamW(self.model.parameters(), lr=lr)
optimizer = AdamW(self.parameters, lr=lr)
trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.exp_loss)
decorate(trainer)
......
......@@ -105,9 +105,9 @@ class AdamW(Optimizer):
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data)
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment