Commit 25dd28bb authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Adapt neurochem trainer to NC setup (#275)

parent 347e845c
@@ -17,7 +17,7 @@ class TestNeuroChem(unittest.TestCase):
         # test if loader construct correct model
         self.assertEqual(trainer.aev_computer.aev_length, 384)
-        m = trainer.model
+        m = trainer.nn
         H, C, N, O = m  # noqa: E741
         self.assertIsInstance(H[0], torch.nn.Linear)
         self.assertListEqual(list(H[0].weight.shape), [160, 384])
...
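After this commit, `trainer.nn` is the bare `ANIModel` holding one network per species, while `trainer.model` names the full pipeline that feeds AEVs into those networks, which is why the test unpacks `trainer.nn` directly. A minimal sketch of the distinction, assuming a NeuroChem input file at the hypothetical path `input.ipt`:

```python
import torch
import torchani

# 'input.ipt' is a made-up path to a NeuroChem training setup file.
trainer = torchani.neurochem.Trainer('input.ipt', device=torch.device('cpu'))

nets = trainer.nn      # ANIModel: per-species torch.nn.Sequential networks
H, C, N, O = nets      # same unpacking as the test above
full = trainer.model   # Sequential(AEVComputer, ANIModel), unless AEVs are cached
```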
@@ -16,8 +16,6 @@ from ..nn import ANIModel, Ensemble, Gaussian
 from ..utils import EnergyShifter, ChemicalSymbolsToInts
 from ..aev import AEVComputer
 from ..optim import AdamW
-import warnings
-import textwrap


 class Constants(collections.abc.Mapping):
@@ -284,38 +282,24 @@ if sys.version_info[0] > 2:
         def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
                      tensorboard=None, aev_caching=False,
                      checkpoint_name='model.pt'):
-            try:
-                import ignite
-                from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
-                from ..data import load_ani_dataset  # noqa: E402
-                from ..data import AEVCacheLoader  # noqa: E402
-            except ImportError:
-                raise RuntimeError(
-                    'NeuroChem Trainer requires ignite,'
-                    'please install pytorch-ignite-nightly from PYPI')
-            self.ignite = ignite
+            from ..data import load_ani_dataset  # noqa: E402
+            from ..data import AEVCacheLoader  # noqa: E402

             class dummy:
                 pass
             self.imports = dummy()
-            self.imports.Container = Container
-            self.imports.MSELoss = MSELoss
-            self.imports.TransformedLoss = TransformedLoss
-            self.imports.RMSEMetric = RMSEMetric
-            self.imports.MaxAEMetric = MaxAEMetric
-            self.imports.MAEMetric = MAEMetric
             self.imports.load_ani_dataset = load_ani_dataset
             self.imports.AEVCacheLoader = AEVCacheLoader

-            self.warned = False
             self.filename = filename
             self.device = device
             self.aev_caching = aev_caching
             self.checkpoint_name = checkpoint_name
-            self.parameters = []
+            self.weights = []
+            self.biases = []

             if tqdm:
                 import tqdm
                 self.tqdm = tqdm.tqdm
@@ -325,6 +309,7 @@ if sys.version_info[0] > 2:
                 import torch.utils.tensorboard
                 self.tensorboard = torch.utils.tensorboard.SummaryWriter(
                     log_dir=tensorboard)
                 self.training_eval_every = 20
             else:
                 self.tensorboard = None
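The writer set up here is plain `torch.utils.tensorboard.SummaryWriter`, which records named scalars against a global step; the new training loop below uses it for validation RMSE, learning rate, and batch loss. A minimal sketch with a made-up log directory:

```python
import torch.utils.tensorboard

writer = torch.utils.tensorboard.SummaryWriter(log_dir='runs/example')
writer.add_scalar('validation_rmse', 1.23, 0)  # tag, scalar value, global step
writer.close()
```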
@@ -455,6 +440,13 @@ if sys.version_info[0] > 2:
                     raise NotImplementedError(key + ' not supported yet')
                 del params[key]

+            # weights and biases initialization
+            def init_params(m):
+                if isinstance(m, torch.nn.Linear):
+                    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
+                    torch.nn.init.zeros_(m.bias)
+
             del_if_exists('gpuid')
             del_if_exists('nkde')
             del_if_exists('fmult')
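`init_params` is later applied to the assembled `ANIModel` via `Module.apply`, which visits every submodule recursively: Linear layers get Kaiming-normal weights (with negative-slope parameter `a=1.0`) and zero biases. A standalone sketch; the 384→160 layer and CELU activation are assumptions chosen to match the test above:

```python
import torch

def init_params(m):
    # Only affine layers carry parameters; activations are left untouched.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)

net = torch.nn.Sequential(torch.nn.Linear(384, 160), torch.nn.CELU(0.1))
net.apply(init_params)  # recurses into each submodule, then the container
```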
@@ -496,7 +488,7 @@ if sys.version_info[0] > 2:
             del params['tbtchsz']
             self.validation_batch_size = params['vbtchsz']
             del params['vbtchsz']
-            self.nmax = params['nmax']
+            self.nmax = math.inf if params['nmax'] == 0 else params['nmax']
             del params['nmax']

             # construct networks
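Treating `nmax == 0` as `math.inf` turns the input file's "no epoch limit" sentinel into a number that the loop's stopping test (`scheduler.last_epoch > self.nmax`) can compare against without a special case. A two-line illustration with a made-up `params` dict:

```python
import math

params = {'nmax': 0}  # 0 in the NeuroChem input file means "no epoch limit"
nmax = math.inf if params['nmax'] == 0 else params['nmax']
assert 1000000 < nmax  # any finite epoch counter stays below the limit
```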
@@ -521,71 +513,48 @@ if sys.version_info[0] > 2:
                         modules.append(activation)
                     del layer['activation']
                     if 'l2norm' in layer:
-                        if not self.warned:
-                            warnings.warn(textwrap.dedent("""
-                            Currently TorchANI training with weight decay can not reproduce the training
-                            result of NeuroChem with the same training setup. If you really want to use
-                            weight decay, consider smaller rates and and make sure you do enough validation
-                            to check if you get expected result."""))
-                            self.warned = True
                         if layer['l2norm'] == 1:
-                            self.parameters.append({
+                            self.weights.append({
                                 'params': [module.weight],
                                 'weight_decay': layer['l2valu'],
                             })
-                            self.parameters.append({
-                                'params': [module.bias],
-                            })
                         else:
-                            self.parameters.append({
-                                'params': module.parameters(),
+                            self.weights.append({
+                                'params': [module.weight],
                             })
                         del layer['l2norm']
                         del layer['l2valu']
                     else:
-                        self.parameters.append({
-                            'params': module.parameters(),
+                        self.weights.append({
+                            'params': [module.weight],
                         })
+                    self.biases.append({
+                        'params': [module.bias],
+                    })
                     if layer:
                         raise ValueError(
                             'unrecognized parameter in layer setup')
                     i = o
                 atomic_nets[atom_type] = torch.nn.Sequential(*modules)
-            self.model = ANIModel([atomic_nets[s]
-                                   for s in self.consts.species])
+            self.nn = ANIModel([atomic_nets[s] for s in self.consts.species])
+
+            # initialize weights and biases
+            self.nn.apply(init_params)
             if self.aev_caching:
-                self.nnp = self.model
+                self.model = self.nn.to(self.device)
             else:
-                self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
-            self.container = self.imports.Container({'energies': self.nnp}).to(self.device)
+                self.model = torch.nn.Sequential(self.aev_computer, self.nn).to(self.device)

-            # losses
-            self.mse_loss = self.imports.MSELoss('energies')
-            self.exp_loss = self.imports.TransformedLoss(
-                self.imports.MSELoss('energies'),
-                lambda x: 0.5 * (torch.exp(2 * x) - 1))
+            # loss functions
+            self.mse_se = torch.nn.MSELoss(reduction='none')
+            self.mse_sum = torch.nn.MSELoss(reduction='sum')

             if params:
                 raise ValueError('unrecognized parameter')
-            self.global_epoch = 0
-            self.global_iteration = 0
             self.best_validation_rmse = math.inf

-        def evaluate(self, dataset):
-            """Evaluate on given dataset to compute RMSE and MaxAE."""
-            evaluator = self.ignite.engine.create_supervised_evaluator(
-                self.container,
-                metrics={
-                    'RMSE': self.imports.RMSEMetric('energies'),
-                    'MAE': self.imports.MAEMetric('energies'),
-                    'MaxAE': self.imports.MaxAEMetric('energies'),
-                }
-            )
-            evaluator.run(dataset)
-            metrics = evaluator.state.metrics
-            return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']), hartree2kcal(metrics['MaxAE'])

         def load_data(self, training_path, validation_path):
             """Load training and validation dataset from file.
@@ -605,121 +574,88 @@ if sys.version_info[0] > 2:
                 self.validation_batch_size, device=self.device,
                 transform=[self.shift_energy.subtract_from_dataset])

+        def evaluate(self, dataset):
+            """Run the evaluation"""
+            total_mse = 0.0
+            count = 0
+            for batch_x, batch_y in dataset:
+                true_energies = batch_y['energies']
+                predicted_energies = []
+                for chunk_species, chunk_coordinates in batch_x:
+                    _, chunk_energies = self.model((chunk_species, chunk_coordinates))
+                    predicted_energies.append(chunk_energies)
+                predicted_energies = torch.cat(predicted_energies)
+                total_mse += self.mse_sum(predicted_energies, true_energies).item()
+                count += predicted_energies.shape[0]
+            return hartree2kcal(math.sqrt(total_mse / count))
+
         def run(self):
             """Run the training"""
             start = timeit.default_timer()
+            no_improve_count = 0

-            def decorate(trainer):
-
-                @trainer.on(self.ignite.engine.Events.STARTED)
-                def initialize(trainer):
-                    trainer.state.no_improve_count = 0
-                    trainer.state.epoch += self.global_epoch
-                    trainer.state.iteration += self.global_iteration
-
-                @trainer.on(self.ignite.engine.Events.COMPLETED)
-                def finalize(trainer):
-                    self.global_epoch = trainer.state.epoch
-                    self.global_iteration = trainer.state.iteration
-
-                if self.nmax > 0:
-                    @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
-                    def terminate_when_nmax_reaches(trainer):
-                        if trainer.state.epoch >= self.nmax:
-                            trainer.terminate()
-
-                if self.tqdm is not None:
-                    @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
-                    def init_tqdm(trainer):
-                        trainer.state.tqdm = self.tqdm(
-                            total=len(self.training_set), desc='epoch')
-
-                    @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
-                    def update_tqdm(trainer):
-                        trainer.state.tqdm.update(1)
-
-                    @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
-                    def finalize_tqdm(trainer):
-                        trainer.state.tqdm.close()
-
-                @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
-                def validation_and_checkpoint(trainer):
-                    trainer.state.rmse, trainer.state.mae, trainer.state.maxae = \
-                        self.evaluate(self.validation_set)
-                    if trainer.state.rmse < self.best_validation_rmse:
-                        trainer.state.no_improve_count = 0
-                        self.best_validation_rmse = trainer.state.rmse
-                        torch.save(self.model.state_dict(),
-                                   self.model_checkpoint)
-                    else:
-                        trainer.state.no_improve_count += 1
-
-                    if trainer.state.no_improve_count > self.max_nonimprove:
-                        trainer.terminate()
-
-                if self.tensorboard is not None:
-                    @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
-                    def log_per_epoch(trainer):
-                        elapsed = round(timeit.default_timer() - start, 2)
-                        epoch = trainer.state.epoch
-                        self.tensorboard.add_scalar('time_vs_epoch', elapsed,
-                                                    epoch)
-                        self.tensorboard.add_scalar('learning_rate_vs_epoch',
-                                                    lr, epoch)
-                        self.tensorboard.add_scalar('validation_rmse_vs_epoch',
-                                                    trainer.state.rmse, epoch)
-                        self.tensorboard.add_scalar('validation_mae_vs_epoch',
-                                                    trainer.state.mae, epoch)
-                        self.tensorboard.add_scalar('validation_maxae_vs_epoch',
-                                                    trainer.state.maxae, epoch)
-                        self.tensorboard.add_scalar(
-                            'best_validation_rmse_vs_epoch',
-                            self.best_validation_rmse, epoch)
-                        self.tensorboard.add_scalar(
-                            'no_improve_count_vs_epoch',
-                            trainer.state.no_improve_count, epoch)
-                        # compute training RMSE, MAE and MaxAE
-                        if epoch % self.training_eval_every == 1:
-                            training_rmse, training_mae, training_maxae = \
-                                self.evaluate(self.training_set)
-                            self.tensorboard.add_scalar(
-                                'training_rmse_vs_epoch', training_rmse, epoch)
-                            self.tensorboard.add_scalar(
-                                'training_mae_vs_epoch', training_mae, epoch)
-                            self.tensorboard.add_scalar(
-                                'training_mae_vs_epoch', training_maxae, epoch)
-
-                    @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
-                    def log_loss(trainer):
-                        iteration = trainer.state.iteration
-                        loss = trainer.state.output
-                        self.tensorboard.add_scalar('loss_vs_iteration',
-                                                    loss, iteration)
-
-            lr = self.init_lr
-
-            # training using mse loss first until the validation MAE decrease
-            # to < 1 Hartree
-            optimizer = AdamW(self.parameters, lr=lr)
-            trainer = self.ignite.engine.create_supervised_trainer(
-                self.container, optimizer, self.mse_loss)
-            decorate(trainer)
-
-            @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
-            def terminate_if_smaller_enough(trainer):
-                if trainer.state.rmse < 10.0:
-                    trainer.terminate()
-
-            trainer.run(self.training_set, max_epochs=math.inf)
-
-            while lr > self.min_lr:
-                optimizer = AdamW(self.parameters, lr=lr)
-                trainer = self.ignite.engine.create_supervised_trainer(
-                    self.container, optimizer, self.exp_loss)
-                decorate(trainer)
-                trainer.run(self.training_set, max_epochs=math.inf)
-                lr *= self.lr_decay
+            AdamW_optim = AdamW(self.weights, lr=self.init_lr)
+            SGD_optim = torch.optim.SGD(self.biases, lr=self.init_lr)
+
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                AdamW_optim,
+                factor=0.5,
+                patience=100,
+                threshold=0)
+
+            while True:
+                rmse = self.evaluate(self.validation_set)
+                learning_rate = AdamW_optim.param_groups[0]['lr']
+
+                if learning_rate < self.min_lr or scheduler.last_epoch > self.nmax:
+                    break
+
+                # checkpoint
+                if scheduler.is_better(rmse, scheduler.best):
+                    no_improve_count = 0
+                    torch.save(self.nn.state_dict(), self.model_checkpoint)
+                else:
+                    no_improve_count += 1
+
+                if no_improve_count > self.max_nonimprove:
+                    break
+
+                scheduler.step(rmse)
+
+                if self.tensorboard is not None:
+                    self.tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
+                    self.tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
+                    self.tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
+                    self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, scheduler.last_epoch)
+
+                for i, (batch_x, batch_y) in self.tqdm(
+                    enumerate(self.training_set),
+                    total=len(self.training_set),
+                    desc='epoch {}'.format(scheduler.last_epoch)
+                ):
+                    true_energies = batch_y['energies']
+                    predicted_energies = []
+                    num_atoms = []
+                    for chunk_species, chunk_coordinates in batch_x:
+                        num_atoms.append((chunk_species >= 0).sum(dim=1))
+                        _, chunk_energies = self.model((chunk_species, chunk_coordinates))
+                        predicted_energies.append(chunk_energies)
+                    num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
+                    predicted_energies = torch.cat(predicted_energies)
+                    loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
+                    AdamW_optim.zero_grad()
+                    SGD_optim.zero_grad()
+                    loss.backward()
+                    AdamW_optim.step()
+                    SGD_optim.step()
+
+                    # write current batch loss to TensorBoard
+                    if self.tensorboard is not None:
+                        self.tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(self.training_set) + i)
+
+                # log elapsed time
+                elapsed = round(timeit.default_timer() - start, 2)
+                if self.tensorboard is not None:
+                    self.tensorboard.add_scalar('time_vs_epoch', elapsed, scheduler.last_epoch)


 __all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble', 'Trainer']
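One detail of the new loop worth calling out: each molecule's squared energy error is divided by the square root of its atom count before taking the batch mean, so large molecules do not dominate the loss. A self-contained sketch of that reduction with made-up energies:

```python
import torch

mse_se = torch.nn.MSELoss(reduction='none')  # keep per-molecule squared errors

predicted = torch.tensor([-40.1, -76.3])  # made-up total energies (Hartree)
true = torch.tensor([-40.0, -76.4])
num_atoms = torch.tensor([5.0, 3.0])      # e.g. CH4 and H2O

# Same reduction as the new run(): scale each squared error by 1/sqrt(N),
# then average over the batch.
loss = (mse_se(predicted, true) / num_atoms.sqrt()).mean()
```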