Commit 25dd28bb authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Adapt neurochem trainer to NC setup (#275)

parent 347e845c
@@ -17,7 +17,7 @@ class TestNeuroChem(unittest.TestCase):
# test if loader construct correct model
self.assertEqual(trainer.aev_computer.aev_length, 384)
m = trainer.model
m = trainer.nn
H, C, N, O = m # noqa: E741
self.assertIsInstance(H[0], torch.nn.Linear)
self.assertListEqual(list(H[0].weight.shape), [160, 384])
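This rename reflects the new attribute layout: after this commit, trainer.nn holds the bare ANIModel (one sub-network per species), while trainer.model is the trainable pipeline (Sequential(AEVComputer, ANIModel) when AEV caching is off). A minimal sketch of the distinction, assuming a valid NeuroChem input file at a hypothetical path 'input.ipt':

import torch
import torchani

trainer = torchani.neurochem.Trainer('input.ipt', torch.device('cpu'))
subnets = trainer.nn      # ANIModel: per-species atomic networks (H, C, N, O)
pipeline = trainer.model  # AEVComputer followed by the networks, on the device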
@@ -16,8 +16,6 @@ from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..optim import AdamW
import warnings
import textwrap
class Constants(collections.abc.Mapping):
@@ -284,38 +282,24 @@ if sys.version_info[0] > 2:
def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
tensorboard=None, aev_caching=False,
checkpoint_name='model.pt'):
try:
import ignite
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
from ..data import load_ani_dataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
except ImportError:
raise RuntimeError(
'NeuroChem Trainer requires ignite, '
'please install pytorch-ignite-nightly from PyPI')
self.ignite = ignite
class dummy:
pass
self.imports = dummy()
self.imports.Container = Container
self.imports.MSELoss = MSELoss
self.imports.TransformedLoss = TransformedLoss
self.imports.RMSEMetric = RMSEMetric
self.imports.MaxAEMetric = MaxAEMetric
self.imports.MAEMetric = MAEMetric
self.imports.load_ani_dataset = load_ani_dataset
self.imports.AEVCacheLoader = AEVCacheLoader
self.warned = False
self.filename = filename
self.device = device
self.aev_caching = aev_caching
self.checkpoint_name = checkpoint_name
self.parameters = []
self.weights = []
self.biases = []
if tqdm:
import tqdm
self.tqdm = tqdm.tqdm
@@ -325,6 +309,7 @@ if sys.version_info[0] > 2:
import torch.utils.tensorboard
self.tensorboard = torch.utils.tensorboard.SummaryWriter(
log_dir=tensorboard)
self.training_eval_every = 20
else:
self.tensorboard = None
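With a TensorBoard log directory given, the trainer owns a SummaryWriter, and the rewritten loop below calls add_scalar on it directly instead of going through ignite event handlers. A minimal sketch of the same logging calls, assuming an installed tensorboard package and a hypothetical log directory 'runs/ani':

import torch.utils.tensorboard

writer = torch.utils.tensorboard.SummaryWriter(log_dir='runs/ani')
writer.add_scalar('validation_rmse', 1.23, 0)  # tag, scalar value, global step
writer.add_scalar('batch_loss', 0.045, 17)
writer.close()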
@@ -455,6 +440,13 @@ if sys.version_info[0] > 2:
raise NotImplementedError(key + ' not supported yet')
del params[key]
# weights and biases initialization
def init_params(m):
if isinstance(m, torch.nn.Linear):
torch.nn.init.kaiming_normal_(m.weight, a=1.0)
torch.nn.init.zeros_(m.bias)
del_if_exists('gpuid')
del_if_exists('nkde')
del_if_exists('fmult')
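The new init_params hook is run over every submodule via Module.apply. With a=1.0 the Kaiming gain sqrt(2 / (1 + a^2)) equals 1, so each Linear weight is drawn from a normal distribution with standard deviation 1/sqrt(fan_in), and biases start at zero. A self-contained sketch of the same initialization on a stand-in network (the 384-to-160 layer matches the test above; the rest is illustrative):

import torch

def init_params(m):
    # std = gain / sqrt(fan_in), where gain = sqrt(2 / (1 + a**2)) = 1 for a = 1.0
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)

net = torch.nn.Sequential(
    torch.nn.Linear(384, 160), torch.nn.CELU(0.1),
    torch.nn.Linear(160, 1))
net.apply(init_params)  # apply() visits every submodule recursively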
@@ -496,7 +488,7 @@ if sys.version_info[0] > 2:
del params['tbtchsz']
self.validation_batch_size = params['vbtchsz']
del params['vbtchsz']
self.nmax = params['nmax']
self.nmax = math.inf if params['nmax'] == 0 else params['nmax']
del params['nmax']
# construct networks
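Treating nmax = 0 as "no epoch limit" by mapping it to math.inf keeps the stopping test in the new loop (scheduler.last_epoch > self.nmax) a single comparison that can never fire when no cap was requested. A tiny illustration, with raw_nmax as a hypothetical stand-in for params['nmax']:

import math

raw_nmax = 0                                    # hypothetical stand-in for params['nmax']
nmax = math.inf if raw_nmax == 0 else raw_nmax
assert not (10**9 > nmax)                       # an infinite cap never terminates training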
@@ -521,71 +513,48 @@ if sys.version_info[0] > 2:
modules.append(activation)
del layer['activation']
if 'l2norm' in layer:
if not self.warned:
warnings.warn(textwrap.dedent("""
Currently TorchANI training with weight decay cannot reproduce the training
results of NeuroChem with the same training setup. If you really want to use
weight decay, consider smaller rates and make sure you do enough validation
to check that you get the expected results."""))
self.warned = True
if layer['l2norm'] == 1:
self.parameters.append({
self.weights.append({
'params': [module.weight],
'weight_decay': layer['l2valu'],
})
self.parameters.append({
'params': [module.bias],
})
else:
self.parameters.append({
'params': module.parameters(),
self.weights.append({
'params': [module.weight],
})
del layer['l2norm']
del layer['l2valu']
else:
self.parameters.append({
'params': module.parameters(),
self.weights.append({
'params': [module.weight],
})
self.biases.append({
'params': [module.bias],
})
if layer:
raise ValueError(
'unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.model = ANIModel([atomic_nets[s]
for s in self.consts.species])
self.nn = ANIModel([atomic_nets[s] for s in self.consts.species])
# initialize weights and biases
self.nn.apply(init_params)
if self.aev_caching:
self.nnp = self.model
self.model = self.nn.to(self.device)
else:
self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
self.container = self.imports.Container({'energies': self.nnp}).to(self.device)
self.model = torch.nn.Sequential(self.aev_computer, self.nn).to(self.device)
# losses
self.mse_loss = self.imports.MSELoss('energies')
self.exp_loss = self.imports.TransformedLoss(
self.imports.MSELoss('energies'),
lambda x: 0.5 * (torch.exp(2 * x) - 1))
# loss functions
self.mse_se = torch.nn.MSELoss(reduction='none')
self.mse_sum = torch.nn.MSELoss(reduction='sum')
if params:
raise ValueError('unrecognized parameter')
self.global_epoch = 0
self.global_iteration = 0
self.best_validation_rmse = math.inf
def evaluate(self, dataset):
"""Evaluate on given dataset to compute RMSE and MaxAE."""
evaluator = self.ignite.engine.create_supervised_evaluator(
self.container,
metrics={
'RMSE': self.imports.RMSEMetric('energies'),
'MAE': self.imports.MAEMetric('energies'),
'MaxAE': self.imports.MaxAEMetric('energies'),
}
)
evaluator.run(dataset)
metrics = evaluator.state.metrics
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']), hartree2kcal(metrics['MaxAE'])
def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file.
@@ -605,121 +574,88 @@ if sys.version_info[0] > 2:
self.validation_batch_size, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
def evaluate(self, dataset):
"""Run the evaluation"""
total_mse = 0.0
count = 0
for batch_x, batch_y in dataset:
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = self.model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += self.mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
return hartree2kcal(math.sqrt(total_mse / count))
def run(self):
"""Run the training"""
start = timeit.default_timer()
def decorate(trainer):
@trainer.on(self.ignite.engine.Events.STARTED)
def initialize(trainer):
trainer.state.no_improve_count = 0
trainer.state.epoch += self.global_epoch
trainer.state.iteration += self.global_iteration
@trainer.on(self.ignite.engine.Events.COMPLETED)
def finalize(trainer):
self.global_epoch = trainer.state.epoch
self.global_iteration = trainer.state.iteration
if self.nmax > 0:
@trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
def terminate_when_nmax_reaches(trainer):
if trainer.state.epoch >= self.nmax:
trainer.terminate()
if self.tqdm is not None:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
trainer.state.tqdm = self.tqdm(
total=len(self.training_set), desc='epoch')
@trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
trainer.state.tqdm.update(1)
@trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
trainer.state.tqdm.close()
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
trainer.state.rmse, trainer.state.mae, trainer.state.maxae = \
self.evaluate(self.validation_set)
if trainer.state.rmse < self.best_validation_rmse:
trainer.state.no_improve_count = 0
self.best_validation_rmse = trainer.state.rmse
torch.save(self.model.state_dict(),
self.model_checkpoint)
no_improve_count = 0
AdamW_optim = AdamW(self.weights, lr=self.init_lr)
SGD_optim = torch.optim.SGD(self.biases, lr=self.init_lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
AdamW_optim,
factor=0.5,
patience=100,
threshold=0)
while True:
rmse = self.evaluate(self.validation_set)
learning_rate = AdamW_optim.param_groups[0]['lr']
if learning_rate < self.min_lr or scheduler.last_epoch > self.nmax:
break
# checkpoint
if scheduler.is_better(rmse, scheduler.best):
no_improve_count = 0
torch.save(self.nn.state_dict(), self.model_checkpoint)
else:
trainer.state.no_improve_count += 1
no_improve_count += 1
if trainer.state.no_improve_count > self.max_nonimprove:
trainer.terminate()
if no_improve_count > self.max_nonimprove:
break
scheduler.step(rmse)
if self.tensorboard is not None:
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def log_per_epoch(trainer):
self.tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
self.tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
self.tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, scheduler.last_epoch)
for i, (batch_x, batch_y) in self.tqdm(
enumerate(self.training_set),
total=len(self.training_set),
desc='epoch {}'.format(scheduler.last_epoch)
):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).sum(dim=1))
_, chunk_energies = self.model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
AdamW_optim.zero_grad()
SGD_optim.zero_grad()
loss.backward()
AdamW_optim.step()
SGD_optim.step()
# write current batch loss to TensorBoard
if self.tensorboard is not None:
self.tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(self.training_set) + i)
# log elapsed time
elapsed = round(timeit.default_timer() - start, 2)
epoch = trainer.state.epoch
self.tensorboard.add_scalar('time_vs_epoch', elapsed,
epoch)
self.tensorboard.add_scalar('learning_rate_vs_epoch',
lr, epoch)
self.tensorboard.add_scalar('validation_rmse_vs_epoch',
trainer.state.rmse, epoch)
self.tensorboard.add_scalar('validation_mae_vs_epoch',
trainer.state.mae, epoch)
self.tensorboard.add_scalar('validation_maxae_vs_epoch',
trainer.state.maxae, epoch)
self.tensorboard.add_scalar(
'best_validation_rmse_vs_epoch',
self.best_validation_rmse, epoch)
self.tensorboard.add_scalar(
'no_improve_count_vs_epoch',
trainer.state.no_improve_count, epoch)
# compute training RMSE, MAE and MaxAE
if epoch % self.training_eval_every == 1:
training_rmse, training_mae, training_maxae = \
self.evaluate(self.training_set)
self.tensorboard.add_scalar(
'training_rmse_vs_epoch', training_rmse, epoch)
self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_mae, epoch)
self.tensorboard.add_scalar(
'training_maxae_vs_epoch', training_maxae, epoch)
@trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
iteration = trainer.state.iteration
loss = trainer.state.output
self.tensorboard.add_scalar('loss_vs_iteration',
loss, iteration)
lr = self.init_lr
# train using the MSE loss first until the validation MAE decreases
# to < 1 Hartree
optimizer = AdamW(self.parameters, lr=lr)
trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.mse_loss)
decorate(trainer)
@trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def terminate_if_smaller_enough(trainer):
if trainer.state.rmse < 10.0:
trainer.terminate()
trainer.run(self.training_set, max_epochs=math.inf)
while lr > self.min_lr:
optimizer = AdamW(self.parameters, lr=lr)
trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.exp_loss)
decorate(trainer)
trainer.run(self.training_set, max_epochs=math.inf)
lr *= self.lr_decay
if self.tensorboard is not None:
self.tensorboard.add_scalar('time_vs_epoch', elapsed, scheduler.last_epoch)
__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble', 'Trainer']
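Taken together, the rewritten run() replaces the ignite engines with a plain loop: validate, checkpoint when the validation RMSE improves, stop once AdamW's learning rate falls below min_lr, the epoch count passes nmax, or max_nonimprove epochs pass without improvement, then step ReduceLROnPlateau on the RMSE and train one epoch with each molecule's squared error scaled by 1/sqrt(number of atoms). A condensed, self-contained sketch of the inner epoch; the dataset layout and model signature mirror the diff, and the function name is a placeholder:

import torch

mse_none = torch.nn.MSELoss(reduction='none')

def train_one_epoch(model, optimizers, training_set):
    # Each batch: a list of (species, coordinates) chunks plus a target dict,
    # matching the chunked layout produced by torchani's dataset loader.
    for batch_x, batch_y in training_set:
        true_energies = batch_y['energies']
        predicted, num_atoms = [], []
        for species, coordinates in batch_x:
            num_atoms.append((species >= 0).sum(dim=1))  # padding species are negative
            _, energies = model((species, coordinates))
            predicted.append(energies)
        num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
        predicted = torch.cat(predicted)
        # NeuroChem-style loss: per-molecule squared error scaled by 1/sqrt(atoms)
        loss = (mse_none(predicted, true_energies) / num_atoms.sqrt()).mean()
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()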