Unverified Commit f489b091 authored by Farhad Ramezanghorbani, committed by GitHub

Update force training example (#283)

* Update nnp_training to have identical setup to NC

* fix

* add LR decay scheduler for both optimizers

* fix

* update example files

* fix

* fix
parent 0fbd69e9
@@ -14,6 +14,7 @@ that script to train to force.
# Most parts of this script are the same as :ref:`training-example`, so we will
# omit the comments for those parts. Please refer to :ref:`training-example` for
# more information.
import torch
import torchani
import os
@@ -33,12 +34,7 @@ EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter([
-0.600952980000, # H
-38.08316124000, # C
-54.70775770000, # N
-75.19446356000, # O
])
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
@@ -60,6 +56,8 @@ training, validation = torchani.data.load_ani_dataset(
atomic_properties=['forces'],
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
print('Self atomic energies: ', energy_shifter.self_energies)
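###############################################################################
# Since ``EnergyShifter(None)`` was constructed without self energies, they are
# obtained from the dataset itself by a linear least-squares fit, which is why
# we can print them after loading. Below is a rough, self-contained sketch of
# that idea using made-up molecules and energies; it is only an illustration,
# not TorchANI's actual implementation.
_counts = torch.tensor([[4., 1., 0., 0.],   # CH4: 4 H, 1 C
                        [2., 0., 0., 1.],   # H2O
                        [3., 0., 1., 0.],   # NH3
                        [0., 1., 0., 2.]])  # CO2
_totals = torch.tensor([-40.5, -76.4, -56.5, -188.6])  # toy total energies (Hartree)
_fit = torch.pinverse(_counts) @ _totals    # solve counts @ self_energies ~= totals
print('toy self energies (H, C, N, O):', _fit)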
###############################################################################
# When iterating over the dataset, we will get pairs of input and output
# ``(species_coordinates, properties)``; in this case, ``properties`` would
@@ -76,7 +74,6 @@ print(list(atomic_properties[0].keys()))
# Due to padding, some of the forces might be zero
print(atomic_properties[0]['forces'][0])
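###############################################################################
# Padded atoms are marked with species ``-1`` and carry zero forces, so they
# can be masked out when computing per-atom quantities. A minimal illustrative
# sketch with made-up tensors (not data from the dataset above):
_pad_species = torch.tensor([[0, 1, -1, -1],   # a 2-atom molecule padded to 4 atoms
                             [0, 0, 1, 2]])    # a genuine 4-atom molecule
_pad_forces = torch.randn(2, 4, 3) * (_pad_species >= 0).unsqueeze(-1).float()
print('real atoms per molecule:', (_pad_species >= 0).sum(dim=1))
print('force rows of real atoms only:', _pad_forces[_pad_species >= 0].shape)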
###############################################################################
# The code that defines the networks and optimizers is mostly the same
@@ -121,18 +118,106 @@ O_network = torch.nn.Sequential(
)
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear layers
#   is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from a normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)
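###############################################################################
# As a quick sanity check (our addition, not part of the original example): with
# ``a=1.0`` the gain used by ``kaiming_normal_`` is ``sqrt(2 / (1 + a**2)) = 1``,
# so each weight matrix ends up with a standard deviation of roughly
# ``1 / sqrt(fan_in)``.
with torch.no_grad():
    _w = H_network[0].weight
    print('empirical std:', _w.std().item(),
          'expected ~', (1.0 / _w.shape[1]) ** 0.5)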
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Here we will turn off weight decay
optimizer = torch.optim.Adam(nn.parameters())
# Here we will use Adam with weight decay for the weights and Stochastic Gradient
# Descent (SGD) for the biases.
AdamW = torchani.optim.AdamW([
# H networks
{'params': [H_network[0].weight]},
{'params': [H_network[2].weight], 'weight_decay': 0.00001},
{'params': [H_network[4].weight], 'weight_decay': 0.000001},
{'params': [H_network[6].weight]},
# C networks
{'params': [C_network[0].weight]},
{'params': [C_network[2].weight], 'weight_decay': 0.00001},
{'params': [C_network[4].weight], 'weight_decay': 0.000001},
{'params': [C_network[6].weight]},
# N networks
{'params': [N_network[0].weight]},
{'params': [N_network[2].weight], 'weight_decay': 0.00001},
{'params': [N_network[4].weight], 'weight_decay': 0.000001},
{'params': [N_network[6].weight]},
# O networks
{'params': [O_network[0].weight]},
{'params': [O_network[2].weight], 'weight_decay': 0.00001},
{'params': [O_network[4].weight], 'weight_decay': 0.000001},
{'params': [O_network[6].weight]},
])
SGD = torch.optim.SGD([
# H networks
{'params': [H_network[0].bias]},
{'params': [H_network[2].bias]},
{'params': [H_network[4].bias]},
{'params': [H_network[6].bias]},
# C networks
{'params': [C_network[0].bias]},
{'params': [C_network[2].bias]},
{'params': [C_network[4].bias]},
{'params': [C_network[6].bias]},
# N networks
{'params': [N_network[0].bias]},
{'params': [N_network[2].bias]},
{'params': [N_network[4].bias]},
{'params': [N_network[6].bias]},
# O networks
{'params': [O_network[0].bias]},
{'params': [O_network[2].bias]},
{'params': [O_network[4].bias]},
{'params': [O_network[6].bias]},
], lr=1e-3)
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
###############################################################################
# This part of the code is also the same
latest_checkpoint = 'force-training-latest.pt'
pretrained = os.path.isfile(latest_checkpoint)
###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])
###############################################################################
# During training, we need to validate on the validation set, and if the
# validation error is better than the best so far, save the new best model to a
# checkpoint.
# Helper function to convert energy units from Hartree to kcal/mol
def hartree2kcal(x):
    return 627.509 * x

@@ -154,78 +239,48 @@ def validate():
    return hartree2kcal(math.sqrt(total_mse / count))
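###############################################################################
# The body of ``validate()`` is collapsed in this diff. Purely as an
# illustration (the names and details below are our assumptions, not the
# committed code), such a validation pass typically accumulates the energy MSE
# over the validation set and reports the RMSE in kcal/mol:
def validate_sketch():
    total_mse = 0.0
    count = 0
    for batch_x, batch_y in validation:
        true_energies = batch_y['energies']
        predicted_energies = []
        for chunk_species, chunk_coordinates in batch_x:
            _, chunk_energies = model((chunk_species, chunk_coordinates))
            predicted_energies.append(chunk_energies)
        predicted_energies = torch.cat(predicted_energies)
        total_mse += ((predicted_energies - true_energies) ** 2).sum().item()
        count += predicted_energies.shape[0]
    return hartree2kcal(math.sqrt(total_mse / count))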
pretrain_criterion = 10 # kcal/mol
mse = torch.nn.MSELoss(reduction='none')
###############################################################################
# For simplicity, we don't train to force during pretraining
if not pretrained:
print("pre-training...")
epoch = 0
rmse = math.inf
pretrain_optimizer = torch.optim.Adam(nn.parameters())
while rmse > pretrain_criterion:
for batch_x, batch_y in tqdm.tqdm(training):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
pretrain_optimizer.zero_grad()
loss.backward()
optimizer.step()
rmse = validate()
print('RMSE:', rmse, 'Target RMSE:', pretrain_criterion)
torch.save({
'nn': nn.state_dict(),
'optimizer': optimizer.state_dict(),
}, latest_checkpoint)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100)
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()
checkpoint = torch.load(latest_checkpoint)
nn.load_state_dict(checkpoint['nn'])
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler'])
###############################################################################
# In the training loop, we need to compute the forces, and the loss for the forces
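# Forces are not predicted directly by the network: they are obtained as the
# negative gradient of the predicted energy with respect to the coordinates.
# The lines that do this in the committed example are collapsed in the hunk
# further below, so the toy snippet here only illustrates the autograd pattern
# (``create_graph=True`` lets the force loss itself be backpropagated through);
# the ``_toy_*`` names are ours, not part of the example.
_toy_xyz = torch.zeros(1, 3, 3, requires_grad=True, device=device)
_toy_energy = (_toy_xyz ** 2).sum()  # stand-in for a predicted molecular energy
_toy_forces = -torch.autograd.grad(_toy_energy, _toy_xyz, create_graph=True)[0]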
print("training starting from epoch", scheduler.last_epoch + 1)
mse = torch.nn.MSELoss(reduction='none')
print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 20
early_stopping_learning_rate = 1.0E-5
force_coefficient = 0.1 # controls the importance of energy loss vs force loss
best_model_checkpoint = 'force-training-best.pt'
for _ in range(scheduler.last_epoch + 1, max_epochs):
for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
rmse = validate()
print('RMSE:', rmse, 'at epoch', scheduler.last_epoch)
print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)
learning_rate = optimizer.param_groups[0]['lr']
learning_rate = AdamW.param_groups[0]['lr']
if learning_rate < early_stopping_learning_rate:
break
tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
# checkpoint
if scheduler.is_better(rmse, scheduler.best):
if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
torch.save(nn.state_dict(), best_model_checkpoint)
scheduler.step(rmse)
AdamW_scheduler.step(rmse)
SGD_scheduler.step(rmse)
tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
# Besides being stored in x, species and coordinates are also stored in y.
# So here, for simplicity, we just ignore the x and use y for everything.
for i, (_, batch_y) in tqdm.tqdm(enumerate(training), total=len(training)):
for i, (_, batch_y) in tqdm.tqdm(
enumerate(training),
total=len(training),
desc="epoch {}".format(AdamW_scheduler.last_epoch)
):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
@@ -261,20 +316,23 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
predicted_energies = torch.cat(predicted_energies)
# Now the total loss has two parts: the energy loss and the force loss
energy_loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
energy_loss = 0.5 * (torch.exp(2 * energy_loss) - 1)
energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
force_loss = torch.cat(force_loss).mean()
loss = energy_loss + force_coefficient * force_loss
optimizer.zero_grad()
AdamW.zero_grad()
SGD.zero_grad()
loss.backward()
optimizer.step()
AdamW.step()
SGD.step()
# write current batch loss to TensorBoard
tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(training) + i)
tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)
torch.save({
'nn': nn.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'AdamW': AdamW.state_dict(),
'SGD': SGD.state_dict(),
'AdamW_scheduler': AdamW_scheduler.state_dict(),
'SGD_scheduler': SGD_scheduler.state_dict(),
}, latest_checkpoint)
@@ -8,13 +8,14 @@ Train Your Own Neural Network Potential, Using PyTorch-Ignite
We have seen how to train a neural network potential by manually writing a
training loop in :ref:`training-example`. TorchANI provides tools to work
with PyTorch-Ignite to simplify the writing of training code. This tutorial
shows how to use these tools to train a demo model.
shows how to use these tools to train a demo model. The setup in this demo is
not necessarily identical to NeuroChem.
This tutorial assumes readers have read :ref:`training-example`.
"""
###############################################################################
# To begin with, let's first import the modules we will use:
# To begin with, let's first import the modules and set up the device we will use:
import torch
import ignite
import torchani
@@ -23,6 +24,9 @@ import os
import ignite.contrib.handlers
import torch.utils.tensorboard
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###############################################################################
# Now let's set up the training hyperparameters and dataset.
@@ -45,11 +49,8 @@ max_epochs = 20
# check the training RMSE to see overfitting.
training_rmse_every = 5
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# batch size
batch_size = 1024
batch_size = 2560
# log directory for tensorboard
log = 'runs'
@@ -59,10 +60,9 @@ log = 'runs'
# Instead of manually specifying hyperparameters as in :ref:`training-example`,
# here we will load them from files.
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
sae_file = os.path.join(path, '../torchani/resources/ani-1x_8x/sae_linfit.dat') # noqa: E501
consts = torchani.neurochem.Constants(const_file)
aev_computer = torchani.AEVComputer(**consts)
energy_shifter = torchani.neurochem.load_sae(sae_file)
energy_shifter = torchani.utils.EnergyShifter(None)
###############################################################################