Commit 0fbd69e9 authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Update nnp_training to have identical setup to NC (#280)

* Update nnp_training to have identical setup to NC

* fix

* add LR decay scheduler for both optimizers

* fix
parent d47d2579
@@ -5,8 +5,9 @@
 Train Your Own Neural Network Potential
 =======================================
-This example shows how to use TorchANI to train a neural network potential. We
-will use the same configuration as specified as in `inputtrain.ipt`_
+This example shows how to use TorchANI to train a neural network potential
+with the setup identical to NeuroChem. We will use the same configuration as
+specified in `inputtrain.ipt`_
 .. _`inputtrain.ipt`:
 https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt
@@ -18,6 +19,7 @@ will use the same configuration as specified as in `inputtrain.ipt`_
 ###############################################################################
 # To begin with, let's first import the modules and setup devices we will use:
 import torch
 import torchani
 import os
@@ -45,6 +47,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 # https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
 # .. _sae_linfit.dat:
 # https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
 Rcr = 5.2000e+00
 Rca = 3.5000e+00
 EtaR = torch.tensor([1.6000000e+01], device=device)
@@ -58,7 +61,6 @@ aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ
 energy_shifter = torchani.utils.EnergyShifter(None)
 species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
 ###############################################################################
 # Now let's setup datasets. These paths assumes the user run this script under
 # the ``examples`` directory of TorchANI's repository. If you download this
@@ -83,7 +85,8 @@ training, validation = torchani.data.load_ani_dataset(
     dspath, species_to_tensor, batch_size, device=device,
     transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
-print('H,C,N,O self energies: ', energy_shifter.self_energies)
+print('Self atomic energies: ', energy_shifter.self_energies)
 ###############################################################################
 # When iterating the dataset, we will get pairs of input and output
 # ``(species_coordinates, properties)``, where ``species_coordinates`` is the
@@ -107,8 +110,7 @@ print('H,C,N,O self energies: ', energy_shifter.self_energies)
 #
 # The output, i.e. ``properties`` is a dictionary holding each property. This
 # allows us to extend TorchANI in the future to training forces and properties.
-#
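To make the batch layout described above concrete, here is a minimal sketch (not part of the changed file) of how one might inspect a single batch from the loaded ``training`` set; the variable names mirror the training loop further down, and the exact shapes depend on the dataset and batch size:

    # Peek at the first training batch to see the (species_coordinates, properties) layout.
    batch_x, batch_y = next(iter(training))
    for chunk_species, chunk_coordinates in batch_x:
        # chunk_species: integer species indices, padded with negative values for dummy atoms,
        # shape (molecules, atoms); chunk_coordinates: shape (molecules, atoms, 3)
        print(chunk_species.shape, chunk_coordinates.shape)
    print(batch_y['energies'].shape)  # one target energy per molecule in the batch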
 ###############################################################################
 # Now let's define atomic neural networks.
@@ -181,9 +183,9 @@ nn.apply(init_params)
 model = torch.nn.Sequential(aev_computer, nn).to(device)
 ###############################################################################
-# Now let's setup the optimizer. We need to specify different weight decay rate
-# for different parameters. Since PyTorch does not have correct implementation
-# of weight decay right now, we provide the correct implementation at TorchANI.
+# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight decay
+# to update the weights and Stochastic Gradient Descent (SGD) to update the biases.
+# Moreover, we need to specify a different weight decay rate for different layers.
 #
 # .. note::
 #
@@ -193,53 +195,59 @@ model = torch.nn.Sequential(aev_computer, nn).to(device)
 # Also note that the weight decay only applies to weight in the training
 # of ANI models, not bias.
 #
-# .. warning::
-#
-# Currently TorchANI training with weight decay can not reproduce the training
-# result of NeuroChem with the same training setup. If you really want to use
-# weight decay, consider smaller rates and and make sure you do enough validation
-# to check if you get expected result.
-#
 # .. _Decoupled Weight Decay Regularization:
 # https://arxiv.org/abs/1711.05101
-optimizer = torchani.optim.AdamW([
+AdamW = torchani.optim.AdamW([
     # H networks
-    {'params': [H_network[0].weight], 'weight_decay': 0.0001},
-    {'params': [H_network[0].bias]},
+    {'params': [H_network[0].weight]},
     {'params': [H_network[2].weight], 'weight_decay': 0.00001},
-    {'params': [H_network[2].bias]},
     {'params': [H_network[4].weight], 'weight_decay': 0.000001},
+    {'params': [H_network[6].weight]},
+    # C networks
+    {'params': [C_network[0].weight]},
+    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
+    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
+    {'params': [C_network[6].weight]},
+    # N networks
+    {'params': [N_network[0].weight]},
+    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
+    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
+    {'params': [N_network[6].weight]},
+    # O networks
+    {'params': [O_network[0].weight]},
+    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
+    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
+    {'params': [O_network[6].weight]},
+])
+SGD = torch.optim.SGD([
+    # H networks
+    {'params': [H_network[0].bias]},
+    {'params': [H_network[2].bias]},
     {'params': [H_network[4].bias]},
-    {'params': H_network[6].parameters()},
+    {'params': [H_network[6].bias]},
     # C networks
-    {'params': [C_network[0].weight], 'weight_decay': 0.0001},
     {'params': [C_network[0].bias]},
-    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
     {'params': [C_network[2].bias]},
-    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
     {'params': [C_network[4].bias]},
-    {'params': C_network[6].parameters()},
+    {'params': [C_network[6].bias]},
     # N networks
-    {'params': [N_network[0].weight], 'weight_decay': 0.0001},
     {'params': [N_network[0].bias]},
-    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
     {'params': [N_network[2].bias]},
-    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
     {'params': [N_network[4].bias]},
-    {'params': N_network[6].parameters()},
+    {'params': [N_network[6].bias]},
     # O networks
-    {'params': [O_network[0].weight], 'weight_decay': 0.0001},
     {'params': [O_network[0].bias]},
-    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
     {'params': [O_network[2].bias]},
-    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
     {'params': [O_network[4].bias]},
-    {'params': O_network[6].parameters()},
+    {'params': [O_network[6].bias]},
-])
+], lr=1e-3)
 ###############################################################################
 # Setting up a learning rate scheduler to do learning rate decay
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100, threshold=0)
+AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
+SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
 ###############################################################################
 # Train the model by minimizing the MSE loss, until validation RMSE no longer
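For reference, the per-batch loss minimized in the training loop further down (the ``loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()`` line) can be written as, with B the number of molecules in the batch, N_i the number of real atoms in molecule i, and E_i / Ê_i the reference and predicted energies:

    L = \frac{1}{B} \sum_{i=1}^{B} \frac{(\hat{E}_i - E_i)^2}{\sqrt{N_i}}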
@@ -254,9 +262,11 @@ latest_checkpoint = 'latest.pt'
 # Resume training from previously saved checkpoints:
 if os.path.isfile(latest_checkpoint):
     checkpoint = torch.load(latest_checkpoint)
-    model.load_state_dict(checkpoint['nn'])
-    optimizer.load_state_dict(checkpoint['optimizer'])
-    scheduler.load_state_dict(checkpoint['scheduler'])
+    nn.load_state_dict(checkpoint['nn'])
+    AdamW.load_state_dict(checkpoint['AdamW'])
+    SGD.load_state_dict(checkpoint['SGD'])
+    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
+    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
 ###############################################################################
 # During training, we need to validate on validation set and if validation error
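The loop below calls a ``validate()`` helper that is defined earlier in the script and is not touched by this commit, so it does not appear in the diff. Purely as orientation, a minimal sketch of such a helper, assuming it reports the validation RMSE in kcal/mol (the conversion constant here is an assumption, not taken from the file), could look like:

    import math

    HARTREE_TO_KCALMOL = 627.509  # assumed unit conversion, not part of this diff

    def validate():
        # Accumulate squared energy errors over the whole validation set.
        total_squared_error = 0.0
        count = 0
        for batch_x, batch_y in validation:
            true_energies = batch_y['energies']
            predicted_energies = []
            for chunk_species, chunk_coordinates in batch_x:
                _, chunk_energies = model((chunk_species, chunk_coordinates))
                predicted_energies.append(chunk_energies)
            predicted_energies = torch.cat(predicted_energies)
            total_squared_error += ((predicted_energies - true_energies) ** 2).sum().item()
            count += predicted_energies.shape[0]
        return HARTREE_TO_KCALMOL * math.sqrt(total_squared_error / count)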
@@ -297,50 +307,63 @@ tensorboard = torch.utils.tensorboard.SummaryWriter()
 # set to a much larger value
 mse = torch.nn.MSELoss(reduction='none')
-print("training starting from epoch", scheduler.last_epoch + 1)
+print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
 max_epochs = 200
 early_stopping_learning_rate = 1.0E-5
 best_model_checkpoint = 'best.pt'
-for _ in range(scheduler.last_epoch + 1, max_epochs):
+for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
     rmse = validate()
-    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch + 1)
+    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)
-    learning_rate = optimizer.param_groups[0]['lr']
+    learning_rate = AdamW.param_groups[0]['lr']
     if learning_rate < early_stopping_learning_rate:
         break
-    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
-    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch + 1)
-    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
     # checkpoint
-    if scheduler.is_better(rmse, scheduler.best):
+    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
         torch.save(nn.state_dict(), best_model_checkpoint)
-    scheduler.step(rmse)
+    AdamW_scheduler.step(rmse)
+    SGD_scheduler.step(rmse)
+    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
+    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
+    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
-    for i, (batch_x, batch_y) in tqdm.tqdm(enumerate(training), total=len(training)):
+    for i, (batch_x, batch_y) in tqdm.tqdm(
+        enumerate(training),
+        total=len(training),
+        desc="epoch {}".format(AdamW_scheduler.last_epoch)
+    ):
         true_energies = batch_y['energies']
         predicted_energies = []
         num_atoms = []
         for chunk_species, chunk_coordinates in batch_x:
             num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
             _, chunk_energies = model((chunk_species, chunk_coordinates))
             predicted_energies.append(chunk_energies)
         num_atoms = torch.cat(num_atoms)
         predicted_energies = torch.cat(predicted_energies)
         loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
-        optimizer.zero_grad()
+        AdamW.zero_grad()
+        SGD.zero_grad()
         loss.backward()
-        optimizer.step()
+        AdamW.step()
+        SGD.step()
         # write current batch loss to TensorBoard
-        tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(training) + i)
+        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)
     torch.save({
         'nn': nn.state_dict(),
-        'optimizer': optimizer.state_dict(),
-        'scheduler': scheduler.state_dict(),
+        'AdamW': AdamW.state_dict(),
+        'SGD': SGD.state_dict(),
+        'AdamW_scheduler': AdamW_scheduler.state_dict(),
+        'SGD_scheduler': SGD_scheduler.state_dict(),
     }, latest_checkpoint)
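As a closing usage note (not part of the commit): ``best.pt`` stores only ``nn.state_dict()``, so one plausible way to reload the best weights for a quick single-point prediction afterwards, assuming the network and AEV computer defined in this script are still in scope, is:

    # Restore the best weights into the atomic networks and rebuild the full model.
    nn.load_state_dict(torch.load(best_model_checkpoint))
    model = torch.nn.Sequential(aev_computer, nn).to(device)

    # Species indices come from the 'HCNO' mapping defined above; coordinates are illustrative.
    species = species_to_tensor('CHHHH').unsqueeze(0).to(device)
    coordinates = torch.randn(1, 5, 3, device=device)  # replace with a real geometry (Angstrom)
    _, predicted_energy = model((species, coordinates))
    print('Predicted (shifted) energy:', predicted_energy.item())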