Commit e1c78efe authored by Farhad Ramezanghorbani, committed by Gao, Xiang

Updated nnp_training tutorial to be as close to NeuroChem as possible (#245)

parent dc8930ee
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
 nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
 print(nn)
+
+###############################################################################
+# Initialize the weights and biases.
+#
+# .. note::
+#   PyTorch's default initialization for the weights and biases in linear
+#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
+#   We initialize the weights similarly, but from a normal distribution.
+#   The biases are initialized to zero.
+#
+#   .. _TORCH.NN.MODULES.LINEAR:
+#       https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
+
+
+def init_params(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
+        torch.nn.init.zeros_(m.bias)
+
+
+nn.apply(init_params)
+
 ###############################################################################
 # Let's now create a pipeline of AEV Computer --> Neural Networks.
 model = torch.nn.Sequential(aev_computer, nn).to(device)
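A quick aside on the `a=1.0` argument above: `kaiming_normal_` defaults to the `leaky_relu` gain sqrt(2 / (1 + a**2)), so passing `a=1.0` makes the gain 1 and the weights are drawn from N(0, 1 / fan_in). A minimal, standalone sketch that checks this numerically (the layer size is only illustrative):

import math

import torch

# With a=1.0 the leaky_relu gain sqrt(2 / (1 + a**2)) equals 1, so
# kaiming_normal_ draws weights from N(0, 1 / fan_in).
layer = torch.nn.Linear(384, 160)               # illustrative sizes
torch.nn.init.kaiming_normal_(layer.weight, a=1.0)
torch.nn.init.zeros_(layer.bias)

fan_in = layer.weight.shape[1]                  # inputs per neuron
expected_std = 1.0 / math.sqrt(fan_in)          # gain / sqrt(fan_in), gain == 1
print(layer.weight.std().item(), expected_std)  # the two should be close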
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
 ])
 ###############################################################################
-# The way ANI trains a neural network potential looks like this:
-#
-# Phase 1: Pretrain the model by minimizing MSE loss
-#
-# Phase 2: Train the model by minimizing the exponential loss, until validation
-# RMSE no longer improves for a certain number of steps, decay the learning rate
-# and repeat the same process, stopping once the learning rate is smaller than a
-# certain number.
-#
-# We first read the checkpoint files to find where we are. We use `latest.pt`
-# to store the current training state. If `latest.pt` does not exist, this
-# means the pretraining has not been finished yet.
+# Set up a learning rate scheduler to do learning rate decay
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100, threshold=0)
+
+###############################################################################
+# Train the model by minimizing the MSE loss until the validation RMSE no
+# longer improves for a certain number of steps, then decay the learning rate
+# and repeat the same process, stopping once the learning rate is smaller than
+# a threshold.
+#
+# We first read the checkpoint files to restart training. We use `latest.pt`
+# to store the current training state.
 latest_checkpoint = 'latest.pt'
-pretrained = os.path.isfile(latest_checkpoint)
+
+###############################################################################
+# Resume training from previously saved checkpoints:
+if os.path.isfile(latest_checkpoint):
+    checkpoint = torch.load(latest_checkpoint)
+    model.load_state_dict(checkpoint['nn'])
+    optimizer.load_state_dict(checkpoint['optimizer'])
+    scheduler.load_state_dict(checkpoint['scheduler'])
 ###############################################################################
 # During training, we need to validate on validation set and if validation error
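The `patience=100, threshold=0` combination above means the learning rate is halved only after the monitored validation RMSE has failed to beat its best value for more than 100 consecutive scheduler steps. A standalone toy sketch of that behaviour, using a tiny patience and made-up metric values so the decay is visible quickly:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-3)
# Same scheduler type as in the hunk above, but with a small patience.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3, threshold=0)

for step, fake_rmse in enumerate([5.0, 4.0, 4.0, 4.0, 4.0, 4.0]):
    sched.step(fake_rmse)                    # pass the validation metric; lower is better
    print(step, opt.param_groups[0]['lr'])
# The LR stays at 1e-3 until 4.0 has failed to improve for more than 3
# consecutive steps, then drops to 5e-4. scheduler.last_epoch simply counts
# step() calls, which is why the training loop later uses it as an epoch index.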
@@ -259,61 +286,18 @@ def validate():
     return hartree2kcal(math.sqrt(total_mse / count))
-###############################################################################
-# If the model is not pretrained yet, we need to run the pretrain.
-pretrain_criterion = 10  # kcal/mol
-mse = torch.nn.MSELoss(reduction='none')
-if not pretrained:
-    print("pre-training...")
-    epoch = 0
-    rmse = math.inf
-    pretrain_optimizer = torch.optim.Adam(nn.parameters())
-    while rmse > pretrain_criterion:
-        for batch_x, batch_y in tqdm.tqdm(training):
-            true_energies = batch_y['energies']
-            predicted_energies = []
-            num_atoms = []
-            for chunk_species, chunk_coordinates in batch_x:
-                num_atoms.append((chunk_species >= 0).sum(dim=1))
-                _, chunk_energies = model((chunk_species, chunk_coordinates))
-                predicted_energies.append(chunk_energies)
-            num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
-            predicted_energies = torch.cat(predicted_energies)
-            loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
-            pretrain_optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        rmse = validate()
-        print('RMSE:', rmse, 'Target RMSE:', pretrain_criterion)
-    torch.save({
-        'nn': nn.state_dict(),
-        'optimizer': optimizer.state_dict(),
-    }, latest_checkpoint)
-
-###############################################################################
-# For phase 2, we need a learning rate scheduler to do learning rate decay
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100)
 ###############################################################################
 # We will also use TensorBoard to visualize our training process
 tensorboard = torch.utils.tensorboard.SummaryWriter()
-###############################################################################
-# Resume training from previously saved checkpoints:
-checkpoint = torch.load(latest_checkpoint)
-nn.load_state_dict(checkpoint['nn'])
-optimizer.load_state_dict(checkpoint['optimizer'])
-if 'scheduler' in checkpoint:
-    scheduler.load_state_dict(checkpoint['scheduler'])
 ###############################################################################
 # Finally, we come to the training loop.
 #
 # In this tutorial, we are setting the maximum epoch to a very small number,
 # only to make this demo terminate fast. For serious training, this should be
 # set to a much larger value
+mse = torch.nn.MSELoss(reduction='none')
 print("training starting from epoch", scheduler.last_epoch + 1)
 max_epochs = 200
 early_stopping_learning_rate = 1.0E-5
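`SummaryWriter()` with no arguments writes its event files under `./runs/` by default, and the logs can be viewed with `tensorboard --logdir runs`. A minimal, standalone sketch of the logging pattern used in the training loop below, with made-up values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()                          # logs to ./runs/<date>_<host>/
for epoch, fake_rmse in enumerate([3.2, 2.9, 2.7]):
    writer.add_scalar('validation_rmse', fake_rmse, epoch)  # tag, value, step
writer.close()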
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
 for _ in range(scheduler.last_epoch + 1, max_epochs):
     rmse = validate()
-    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch)
+    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch + 1)
     learning_rate = optimizer.param_groups[0]['lr']
     if learning_rate < early_stopping_learning_rate:
         break
-    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
-    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
-    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
+    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
+    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch + 1)
+    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
     # checkpoint
     if scheduler.is_better(rmse, scheduler.best):
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
         num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
         predicted_energies = torch.cat(predicted_energies)
         loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
-        loss = 0.5 * (torch.exp(2 * loss) - 1)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
...
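The loss kept in the last hunk is a per-atom-weighted MSE: each molecule's squared energy error is divided by its atom count before averaging, so large molecules do not dominate small ones. (The removed line had additionally wrapped it in 0.5 * (exp(2 * loss) - 1), which is approximately equal to loss when the loss is small.) A standalone illustration with made-up numbers:

import torch

mse = torch.nn.MSELoss(reduction='none')       # keep one squared error per molecule
predicted = torch.tensor([-40.51, -76.48])     # hypothetical energies, Hartree
target = torch.tensor([-40.50, -76.40])
num_atoms = torch.tensor([5.0, 50.0])          # hypothetical molecule sizes

per_molecule = mse(predicted, target)          # (pred - target) ** 2, per molecule
loss = (per_molecule / num_atoms).mean()       # weight each error by molecule size
print(loss.item())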