# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem's. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run the NeuroChem training config file
    `inputtrain.ipt`. See: :ref:`neurochem-training`.

.. warning::
    The training setup used in this file is configured to reproduce the original research
    at `Less is more: Sampling chemical space with active learning`_ as closely as possible.
    That research was done on a different platform called NeuroChem, whose default
    options and technical details differ from PyTorch's in many ways. Some decisions made here
    (such as using NeuroChem's initialization instead of PyTorch's default initialization)
    were made not because they give better results, but solely to reproduce the original
    research. This file should not be interpreted as a suggestion to readers on how
    they should set up their models.

.. _`Less is more: Sampling chemical space with active learning`:
    https://aip.scitation.org/doi/full/10.1063/1.5023802
"""

###############################################################################
# To begin with, let's import the modules and set up the device we will use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up the constants and construct an AEV computer. These numbers
# can be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is provided as an argument to the :class:`EnergyShifter` constructor.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from a file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
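
###############################################################################
# .. note::
#
#   Instead of hard-coding these values, they could also be read from the
#   ``.params`` file mentioned above with :class:`torchani.neurochem.Constants`.
#   A minimal sketch (not executed here), assuming the ANI-1x resource file is
#   available locally::
#
#     consts = torchani.neurochem.Constants('rHCNO-5.2R_16-3.5A_a4-8.params')
#     aev_computer = torchani.AEVComputer(**consts)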

###############################################################################
# Now let's set up the datasets. These paths assume that you run this script
# under the ``examples`` directory of TorchANI's repository. If you download
# this script, you should manually set the paths of these files on your system
# before the script can run successfully.
#
# Also note that we need to subtract the self energies of all atoms from the
# energy of each molecule. This brings the energies into a reasonable range.
# The second argument defines how to convert species, given as a list of
# strings, into a tensor: that is, which of the supported chemical symbols
# corresponds to ``0``, which corresponds to ``1``, and so on.

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

training, validation = torchani.data.load(dspath)\
                                    .subtract_self_energies(energy_shifter, species_order)\
                                    .species_to_indices(species_order)\
                                    .shuffle()\
                                    .split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating over the dataset, we will get a dict mapping property names to values.
#
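# For example (a small sketch, assuming the pipeline above produced at least
# one batch), we can peek at the first batch to see which properties it holds:
example_properties = next(iter(training))
print('Available properties:', list(example_properties.keys()))
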
###############################################################################
# Now let's define atomic neural networks.
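# The first linear layer of each atomic network takes an AEV as input, so its
# input dimension is queried from the AEV computer below.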
aev_dim = aev_computer.aev_length

H_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
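
###############################################################################
# :class:`torchani.ANIModel` holds one network per chemical species: it feeds
# each atom's AEV to the network for that atom's species and sums the
# resulting atomic energies into a molecular energy.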

###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear layers
#   is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from the normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight decay
# to update the weights and Stochastic Gradient Descent (SGD) to update the biases.
# Moreover, we need to specify different weight decay rates for different layers.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually not
#   L2 regularization. The confusion between L2 regularization and weight decay
#   is a common mistake in deep learning. See: `Decoupled Weight Decay Regularization`_
#   Also note that when training ANI models, weight decay is applied only to
#   the weights, not the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101

AdamW = torch.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Setting up learning rate schedulers to do learning rate decay
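# With ``factor=0.5`` and ``patience=100``, ``ReduceLROnPlateau`` halves the
# learning rate once the tracked validation RMSE has stopped improving for 100
# epochs; ``threshold=0`` means that any improvement, however small, counts.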
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no longer
# improves for a certain number of steps, then decay the learning rate and repeat
# the same process, stopping when the learning rate becomes smaller than a threshold.
#
# We first read the checkpoint file to restart training. We use `latest.pt`
# to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to validate on the validation set, and if the
# validation error is better than the best seen so far, save the new best
# model to a checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for properties in validation:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        _, predicted_energies = model((species, coordinates))
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
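        # dummy atoms used for padding have species index -1, so count only real atoms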
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

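        # scale each molecule's squared error by 1/sqrt(number of atoms) before averaging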
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)