Commit e1c78efe authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by Gao, Xiang

Updated nnp_training tutorial to be as close to NeuroChem as possible (#245)

parent dc8930ee
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights in linear layers is
#   Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from the normal distribution,
#   and initialize the biases to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)
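###############################################################################
# With ``a=1.0``, the Kaiming gain ``sqrt(2 / (1 + a^2))`` reduces to 1, so
# each weight is effectively drawn from a normal distribution with standard
# deviation ``1 / sqrt(fan_in)``.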
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device)
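###############################################################################
# A minimal sketch of evaluating this pipeline directly, which is the same
# call the training loop below makes (the CH4-like input is an illustrative
# assumption; species indices follow the H, C, N, O order used above, and the
# coordinates are random placeholders)::
#
#   example_species = torch.tensor([[1, 0, 0, 0, 0]], device=device)
#   example_coordinates = torch.randn(1, 5, 3, device=device)
#   _, example_energies = model((example_species, example_coordinates))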
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
])
###############################################################################
# The way ANI trains a neural network potential looks like this:
#
# Phase 1: Pretrain the model by minimizing the MSE loss.
#
# Phase 2: Train the model by minimizing the exponential loss until the
# validation RMSE no longer improves for a certain number of steps, then decay
# the learning rate and repeat the same process, stopping once the learning
# rate falls below a threshold.
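#
# The exponential loss in phase 2 is the transformed MSE
# ``0.5 * (exp(2 * mse) - 1)`` that appears near the end of the training loop
# below; it is close to the plain MSE for small errors but penalizes large
# errors much more strongly.
#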
# Set up a learning rate scheduler to perform the learning rate decay:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100, threshold=0)
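###############################################################################
# With ``factor=0.5`` the scheduler halves the learning rate after
# ``patience`` epochs without improvement, and ``threshold=0`` means only a
# strict decrease of the validation RMSE counts as an improvement.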
###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat the same process, stopping once the learning rate is smaller
# than a threshold.
#
# We first read the checkpoint files to find where we are. We use `latest.pt`
# to store the current training state. If `latest.pt` does not exist, this
# means the pretraining has not finished yet.
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store the current training state.
latest_checkpoint = 'latest.pt'
pretrained = os.path.isfile(latest_checkpoint)
###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
###############################################################################
# During training, we need to validate on the validation set, and if the validation error
@@ -259,61 +286,18 @@ def validate():
    return hartree2kcal(math.sqrt(total_mse / count))
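###############################################################################
# The body of ``validate`` accumulates ``total_mse`` and ``count`` over the
# validation set; a minimal sketch consistent with its return statement (the
# ``validation`` iterable and its batch layout are assumed to mirror
# ``training``) could look like::
#
#   mse_sum = torch.nn.MSELoss(reduction='sum')
#   total_mse = 0.0
#   count = 0
#   for batch_x, batch_y in validation:
#       true_energies = batch_y['energies']
#       predicted_energies = []
#       for chunk_species, chunk_coordinates in batch_x:
#           _, chunk_energies = model((chunk_species, chunk_coordinates))
#           predicted_energies.append(chunk_energies)
#       predicted_energies = torch.cat(predicted_energies)
#       total_mse += mse_sum(predicted_energies, true_energies).item()
#       count += predicted_energies.shape[0]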
###############################################################################
# If the model is not pretrained yet, we need to run the pretraining.
pretrain_criterion = 10  # kcal/mol
mse = torch.nn.MSELoss(reduction='none')
if not pretrained:
    print("pre-training...")
    epoch = 0
    rmse = math.inf
    pretrain_optimizer = torch.optim.Adam(nn.parameters())
    while rmse > pretrain_criterion:
        for batch_x, batch_y in tqdm.tqdm(training):
            true_energies = batch_y['energies']
            predicted_energies = []
            num_atoms = []
            for chunk_species, chunk_coordinates in batch_x:
                num_atoms.append((chunk_species >= 0).sum(dim=1))
                _, chunk_energies = model((chunk_species, chunk_coordinates))
                predicted_energies.append(chunk_energies)
            num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
            predicted_energies = torch.cat(predicted_energies)
            # per-molecule squared error, normalized by the number of atoms
            loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()
        rmse = validate()
        print('RMSE:', rmse, 'Target RMSE:', pretrain_criterion)
    torch.save({
        'nn': nn.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, latest_checkpoint)
###############################################################################
# For phase 2, we need a learning rate scheduler to perform the learning rate
# decay:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100)
###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()
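###############################################################################
# The logged scalars can be monitored while training by running
# ``tensorboard --logdir runs`` (``runs`` is the default ``SummaryWriter``
# log directory) and opening the reported URL in a browser.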
###############################################################################
# Resume training from previously saved checkpoints:
checkpoint = torch.load(latest_checkpoint)
nn.load_state_dict(checkpoint['nn'])
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler'])
###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small
# value, just to make this demo terminate quickly. For serious training, this
# should be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')
print("training starting from epoch", scheduler.last_epoch + 1)
max_epochs = 200
early_stopping_learning_rate = 1.0E-5
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
for _ in range(scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch)
    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch + 1)
    learning_rate = optimizer.param_groups[0]['lr']
    if learning_rate < early_stopping_learning_rate:
        break
    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch + 1)
    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
    # checkpoint
    if scheduler.is_better(rmse, scheduler.best):
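        # (The elided body of this branch would save the improved model; a
        # plausible line, using the ``best_model_checkpoint`` path defined
        # above, is ``torch.save(nn.state_dict(), best_model_checkpoint)``.)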
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
        num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
        predicted_energies = torch.cat(predicted_energies)
        loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
        loss = 0.5 * (torch.exp(2 * loss) - 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()