Commit 0fbd69e9 authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Update nnp_training to have identical setup to NC (#280)

* Update nnp_training to have identical setup to NC

* fix

* add LR decay scheduler for both optimizers

* fix
parent d47d2579
@@ -5,8 +5,9 @@
Train Your Own Neural Network Potential
=======================================
This example shows how to use TorchANI to train a neural network potential. We
will use the same configuration as specified as in `inputtrain.ipt`_
This example shows how to use TorchANI to train a neural network potential
with the setup identical to NeuroChem. We will use the same configuration as
specified in `inputtrain.ipt`_
.. _`inputtrain.ipt`:
https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt
@@ -18,6 +19,7 @@ will use the same configuration as specified as in `inputtrain.ipt`_
###############################################################################
# To begin with, let's first import the modules and setup devices we will use:
import torch
import torchani
import os
@@ -45,6 +47,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
# https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
@@ -58,7 +61,6 @@ aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
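###############################################################################
# As a quick illustration (an editor's addition, not part of the original
# example): the index of each element is fixed by its position in the
# ``'HCNO'`` string above, so H -> 0, C -> 1, N -> 2, O -> 3. For a
# methane-like list of symbols this should give ``tensor([1, 0, 0, 0, 0])``:
print(species_to_tensor('CHHHH'))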
###############################################################################
# Now let's setup datasets. These paths assume the user runs this script under
# the ``examples`` directory of TorchANI's repository. If you download this
@@ -83,7 +85,8 @@ training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
print('H,C,N,O self energies: ', energy_shifter.self_energies)
print('Self atomic energies: ', energy_shifter.self_energies)
###############################################################################
# When iterating the dataset, we will get pairs of input and output
# ``(species_coordinates, properties)``, where ``species_coordinates`` is the
@@ -107,8 +110,7 @@ print('H,C,N,O self energies: ', energy_shifter.self_energies)
#
# The output, i.e. ``properties``, is a dictionary holding each property. This
# allows us to extend TorchANI in the future to train on forces and other properties.
#
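###############################################################################
# To make the structure above concrete (an editor's addition, not part of the
# original example), we can peek at the first training batch. ``batch_x`` is a
# list of ``(species, coordinates)`` chunks and ``batch_y`` is the property
# dictionary holding, among others, the target energies:
example_x, example_y = training[0]
example_species, example_coordinates = example_x[0]
print('first chunk species shape:    ', example_species.shape)
print('first chunk coordinates shape:', example_coordinates.shape)
print('batch energies shape:         ', example_y['energies'].shape)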
###############################################################################
# Now let's define atomic neural networks.
@@ -181,9 +183,9 @@ nn.apply(init_params)
model = torch.nn.Sequential(aev_computer, nn).to(device)
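###############################################################################
# Quick sanity check (an editor's addition, not part of the original example):
# the pipelined model maps ``(species, coordinates)`` to ``(species, energies)``,
# just like the chunk evaluations in the training loop below. The weights are
# still the random initialization from ``init_params`` above, so the value
# itself is meaningless; the geometry is an arbitrary random one.
example_species = species_to_tensor('CHHHH').unsqueeze(0).to(device)
example_coordinates = torch.randn(1, 5, 3, device=device)
_, example_energies = model((example_species, example_coordinates))
print('Energy of a random 5-atom geometry:', example_energies.item())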
###############################################################################
# Now let's setup the optimizer. We need to specify different weight decay rate
# for different parameters. Since PyTorch does not have correct implementation
# of weight decay right now, we provide the correct implementation at TorchANI.
# Now let's setup the optimizers. NeuroChem uses Adam with decoupled weight decay
# to update the weights and Stochastic Gradient Descent (SGD) to update the biases.
# Moreover, we need to specify different weight decay rates for different layers.
#
# .. note::
#
@@ -193,53 +195,59 @@ model = torch.nn.Sequential(aev_computer, nn).to(device)
# Also note that the weight decay only applies to weights in the training
# of ANI models, not biases.
#
# .. warning::
#
# Currently TorchANI training with weight decay cannot reproduce the training
# result of NeuroChem with the same training setup. If you really want to use
# weight decay, consider smaller rates and make sure you do enough validation
# to check if you get the expected result.
#
# .. _Decoupled Weight Decay Regularization:
# https://arxiv.org/abs/1711.05101
optimizer = torchani.optim.AdamW([
AdamW = torchani.optim.AdamW([
# H networks
{'params': [H_network[0].weight], 'weight_decay': 0.0001},
{'params': [H_network[0].bias]},
{'params': [H_network[0].weight]},
{'params': [H_network[2].weight], 'weight_decay': 0.00001},
{'params': [H_network[2].bias]},
{'params': [H_network[4].weight], 'weight_decay': 0.000001},
{'params': [H_network[6].weight]},
# C networks
{'params': [C_network[0].weight]},
{'params': [C_network[2].weight], 'weight_decay': 0.00001},
{'params': [C_network[4].weight], 'weight_decay': 0.000001},
{'params': [C_network[6].weight]},
# N networks
{'params': [N_network[0].weight]},
{'params': [N_network[2].weight], 'weight_decay': 0.00001},
{'params': [N_network[4].weight], 'weight_decay': 0.000001},
{'params': [N_network[6].weight]},
# O networks
{'params': [O_network[0].weight]},
{'params': [O_network[2].weight], 'weight_decay': 0.00001},
{'params': [O_network[4].weight], 'weight_decay': 0.000001},
{'params': [O_network[6].weight]},
])
SGD = torch.optim.SGD([
# H networks
{'params': [H_network[0].bias]},
{'params': [H_network[2].bias]},
{'params': [H_network[4].bias]},
{'params': H_network[6].parameters()},
{'params': [H_network[6].bias]},
# C networks
{'params': [C_network[0].weight], 'weight_decay': 0.0001},
{'params': [C_network[0].bias]},
{'params': [C_network[2].weight], 'weight_decay': 0.00001},
{'params': [C_network[2].bias]},
{'params': [C_network[4].weight], 'weight_decay': 0.000001},
{'params': [C_network[4].bias]},
{'params': C_network[6].parameters()},
{'params': [C_network[6].bias]},
# N networks
{'params': [N_network[0].weight], 'weight_decay': 0.0001},
{'params': [N_network[0].bias]},
{'params': [N_network[2].weight], 'weight_decay': 0.00001},
{'params': [N_network[2].bias]},
{'params': [N_network[4].weight], 'weight_decay': 0.000001},
{'params': [N_network[4].bias]},
{'params': N_network[6].parameters()},
{'params': [N_network[6].bias]},
# O networks
{'params': [O_network[0].weight], 'weight_decay': 0.0001},
{'params': [O_network[0].bias]},
{'params': [O_network[2].weight], 'weight_decay': 0.00001},
{'params': [O_network[2].bias]},
{'params': [O_network[4].weight], 'weight_decay': 0.000001},
{'params': [O_network[4].bias]},
{'params': O_network[6].parameters()},
])
{'params': [O_network[6].bias]},
], lr=1e-3)
###############################################################################
# Setting up a learning rate scheduler to do learning rate decay
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100, threshold=0)
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
###############################################################################
# Train the model by minimizing the MSE loss, until validation RMSE no longer
@@ -254,9 +262,11 @@ latest_checkpoint = 'latest.pt'
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    model.load_state_dict(checkpoint['nn'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
###############################################################################
# During training, we need to validate on the validation set, and if the validation error
@@ -297,50 +307,63 @@ tensorboard = torch.utils.tensorboard.SummaryWriter()
# set to a much larger value
mse = torch.nn.MSELoss(reduction='none')
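###############################################################################
# The ``validate`` helper called in the loop below is defined in a part of the
# file that this diff does not show. The following hypothetical
# ``validate_energies_rmse`` (an editor's sketch, not the original code) shows
# what such a helper could look like: run the model over the validation set and
# return the energy RMSE, converted from Hartree to kcal/mol
# (1 Hartree = 627.509 kcal/mol).
def validate_energies_rmse():
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    with torch.no_grad():
        for batch_x, batch_y in validation:
            true_energies = batch_y['energies']
            predicted_energies = []
            for chunk_species, chunk_coordinates in batch_x:
                _, chunk_energies = model((chunk_species, chunk_coordinates))
                predicted_energies.append(chunk_energies)
            predicted_energies = torch.cat(predicted_energies)
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    return 627.509 * (total_mse / count) ** 0.5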
print("training starting from epoch", scheduler.last_epoch + 1)
print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 200
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'
for _ in range(scheduler.last_epoch + 1, max_epochs):
for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch + 1)
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)
    learning_rate = optimizer.param_groups[0]['lr']
    learning_rate = AdamW.param_groups[0]['lr']
    if learning_rate < early_stopping_learning_rate:
        break
    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch + 1)
    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
    # checkpoint
    if scheduler.is_better(rmse, scheduler.best):
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)
    scheduler.step(rmse)
    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)
    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
    for i, (batch_x, batch_y) in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
    for i, (batch_x, batch_y) in tqdm.tqdm(enumerate(training), total=len(training)):
        true_energies = batch_y['energies']
        predicted_energies = []
        num_atoms = []
        for chunk_species, chunk_coordinates in batch_x:
            num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
            _, chunk_energies = model((chunk_species, chunk_coordinates))
            predicted_energies.append(chunk_energies)
        num_atoms = torch.cat(num_atoms)
        predicted_energies = torch.cat(predicted_energies)
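        # squared errors are divided by sqrt(number of atoms) before averaging,
        # so larger molecules are down-weighted in the loss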
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
        optimizer.zero_grad()
        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        optimizer.step()
        AdamW.step()
        SGD.step()
        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(training) + i)
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)
    torch.save({
        'nn': nn.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
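###############################################################################
# Once training has finished (an editor's addition, not part of the original
# example), the best-on-validation weights saved above can be loaded back for
# inference:
nn.load_state_dict(torch.load(best_model_checkpoint))
model.eval()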