# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run a NeuroChem training config file such as
    `inputtrain.ipt`_. See: :ref:`neurochem-training`.
"""

###############################################################################
# To begin with, let's first import the modules and set up the devices we
# will use:
import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up the constants and construct an AEV computer. These numbers
# can be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is passed as the argument to the :class:`EnergyShifter` constructor.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
                     1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
                     3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
                     4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00],
                    device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
                     1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00],
                    device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00],
                    device=device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])

###############################################################################
# Now let's set up the datasets. These paths assume the user runs this script
# from the ``examples`` directory of TorchANI's repository. If you download
# this script, you should manually set the paths of these files on your system
# before the script can run successfully.
#
# Also note that we subtract the self energies of all atoms from the energy of
# each molecule. This keeps the energies within a reasonable range. The second
# argument defines how species, given as a list of strings, are converted to a
# tensor, that is, which of the supported chemical symbols corresponds to
# ``0``, which corresponds to ``1``, and so on.
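###############################################################################
# As a quick illustration of that mapping (this simply calls the converter
# defined above, which accepts an iterable of chemical symbols), the supported
# symbols are numbered in the order they were given, i.e. H -> 0, C -> 1,
# N -> 2, O -> 3:

print(species_to_tensor(['O', 'H', 'H']))  # expected: tensor([3, 0, 0])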
try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

training, validation = torchani.data.load_ani_dataset(
    dspath, species_to_tensor, batch_size, rm_outlier=True, device=device,
    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])

print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating over the dataset, we get pairs of input and output
# ``(species_coordinates, properties)``, where ``species_coordinates`` is the
# input and ``properties`` is the output.
#
# ``species_coordinates`` is a list of species-coordinate pairs, with shapes
# ``(N, Na)`` and ``(N, Na, 3)``. The reason we get a list is that, when the
# dataset is loaded and minibatches are generated, the whole dataset is
# shuffled, so each minibatch contains molecules with a wide range of atom
# counts. Molecules with different numbers of atoms are batched together by
# padding: ghost atoms with species ``X`` are added and treated as normal
# atoms during computation, but atoms with species ``X`` are ignored when
# computing AEVs. To avoid wasting computation on padding atoms, minibatches
# are further split into chunks. Each chunk contains molecules of similar
# size, which minimizes the total number of padding atoms that need to be
# added. The input list ``species_coordinates`` contains the chunks of the
# minibatch we are getting. Batching and chunking happen automatically, so the
# user does not need to worry about constructing chunks, but the user does
# need to compute the energies for each chunk and concatenate them into a
# single tensor.
#
# The output, ``properties``, is a dictionary holding each property. This
# allows us to extend TorchANI in the future to train on forces and other
# properties.

###############################################################################
# Now let's define the atomic neural networks.
H_network = torch.nn.Sequential(
    torch.nn.Linear(384, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(384, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
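###############################################################################
# The input dimension of 384 in the first linear layer of each network is the
# length of the AEV produced by ``aev_computer``: 4 species times 16 ``ShfR``
# values gives 64 radial features, and 10 unordered species pairs times
# 8 ``ShfZ`` values times 4 ``ShfA`` values gives 320 angular features. As a
# quick sanity check (this assumes the ``aev_length`` attribute of
# :class:`torchani.AEVComputer` is available in your TorchANI version):

assert aev_computer.aev_length == 384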
###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear
#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from the normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
# decay to update the weights and Stochastic Gradient Descent (SGD) to update
# the biases. Moreover, we need to specify different weight decay rates for
# different layers.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually
#   not L2 regularization. The confusion between L2 regularization and weight
#   decay is a common mistake in deep learning.
#   See: `Decoupled Weight Decay Regularization`_
#   Also note that the weight decay is only applied to the weights when
#   training ANI models, not to the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101
AdamW = torchani.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Setting up a learning rate scheduler to do learning rate decay
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat the same process, stopping once the learning rate is smaller than
# a threshold.
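#
# Concretely, for a minibatch of :math:`m` molecules, the training loop below
# minimizes
#
# .. math::
#
#     L = \frac{1}{m} \sum_{i=1}^{m} \frac{\left(\hat{E}_i - E_i\right)^2}{\sqrt{N_i}},
#
# where :math:`\hat{E}_i` and :math:`E_i` are the predicted and reference
# energies of molecule :math:`i` and :math:`N_i` is its number of real
# (non-padding) atoms; dividing by :math:`\sqrt{N_i}` keeps large molecules
# from dominating the loss.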
#
# We first read the checkpoint files to restart training. We use
# ``latest.pt`` to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to validate on the validation set, and if the
# validation error is better than the best so far, save the new best model to
# a checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for batch_x, batch_y in validation:
        true_energies = batch_y['energies']
        predicted_energies = []
        for chunk_species, chunk_coordinates in batch_x:
            chunk_energies = model((chunk_species, chunk_coordinates)).energies
            predicted_energies.append(chunk_energies)
        predicted_energies = torch.cat(predicted_energies)
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()
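###############################################################################
# .. note::
#
#   By default :class:`torch.utils.tensorboard.SummaryWriter` writes its event
#   files under ``./runs``, so the quantities logged below can be followed
#   live by launching ``tensorboard --logdir runs`` in a separate terminal.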
###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, (batch_x, batch_y) in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        true_energies = batch_y['energies']
        predicted_energies = []
        num_atoms = []
        for chunk_species, chunk_coordinates in batch_x:
            # count only real atoms; padding atoms have species index -1
            num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
            chunk_energies = model((chunk_species, chunk_coordinates)).energies
            predicted_energies.append(chunk_energies)
        num_atoms = torch.cat(num_atoms)
        predicted_energies = torch.cat(predicted_energies)

        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
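###############################################################################
# Once training has finished, the best parameters can be loaded back and the
# model used for inference. The snippet below is a minimal sketch: it assumes
# ``best.pt`` was written by the loop above, and the water geometry is just a
# placeholder. Note that the model was trained on energies with the self
# atomic energies subtracted, so ``energy_shifter.self_energies`` would have
# to be added back to obtain total energies.

nn.load_state_dict(torch.load(best_model_checkpoint))
species = species_to_tensor(['O', 'H', 'H']).unsqueeze(0).to(device)
coordinates = torch.tensor([[[0.00, 0.00, 0.00],
                             [0.96, 0.00, 0.00],
                             [-0.24, 0.93, 0.00]]], device=device)
with torch.no_grad():
    shifted_energy = model((species, coordinates)).energies
print('Predicted energy (self energies subtracted):', shifted_energy.item())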