# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem's. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run the NeuroChem training configuration file
    `inputtrain.ipt` directly. See: :ref:`neurochem-training`.
"""

###############################################################################
# To begin with, let's first import the modules and set up the device we will
# use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol
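# (for reference, 1 Hartree is about 627.509 kcal/mol)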

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up constants and construct an AEV computer. These numbers can
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is provided as an argument to the :class:`EnergyShifter`
# constructor.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
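# A quick sanity check (assuming this TorchANI version exposes ``aev_length``):
# with 4 species the radial part contributes 4 * 16 = 64 terms and the angular
# part 10 * 32 = 320 terms, so each AEV has 384 entries, which must match the
# input dimension of the atomic networks defined below.
assert aev_computer.aev_length == 384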
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
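# As a quick illustration, the converter maps chemical symbols to the indices
# used throughout this script (H -> 0, C -> 1, N -> 2, O -> 3):
print(species_to_tensor(['C', 'H', 'H', 'H', 'H']))  # tensor([1, 0, 0, 0, 0])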

###############################################################################
# Now let's set up the datasets. These paths assume the user runs this script
# under the ``examples`` directory of TorchANI's repository. If you download
# this script alone, you should manually set the paths to these files on your
# system before the script can run successfully.
#
# Also note that we subtract the self energies of all atoms from the energy
# of each molecule. This brings the energies into a reasonable range.
# ``species_to_indices()`` converts each molecule's species from a list of
# chemical symbols to a tensor of integer indices, following the order given
# to :class:`ChemicalSymbolsToInts` above (H -> 0, C -> 1, N -> 2, O -> 3).

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
size = len(dataset)
training, validation = dataset.split(int(0.8 * size), None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating over the dataset, we get a dict mapping property names
# (such as ``species``, ``coordinates`` and ``energies``) to their values.
#
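# Below is a minimal sketch of inspecting one batch (assuming the keys used
# later in this script):
sample = next(iter(training))
print('Batch keys:', sorted(sample.keys()))
print('species shape:', sample['species'].shape)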
###############################################################################
# Now let's define atomic neural networks.

H_network = torch.nn.Sequential(
    torch.nn.Linear(384, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(384, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)
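
# For reference, count the trainable parameters as a quick sanity check:
print('Total parameters:', sum(p.numel() for p in nn.parameters()))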

###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases of linear
#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_.
#   We initialize the weights similarly, but from the normal distribution,
#   and we initialize the biases to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)
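
# A minimal sketch of a single forward pass through the pipeline. The
# coordinates below are random and serve only to check shapes; the resulting
# energy value is meaningless:
example_species = species_to_tensor(['C', 'H', 'H', 'H', 'H']).unsqueeze(0).to(device)
example_coordinates = torch.randn(1, 5, 3, device=device)
_, example_energies = model((example_species, example_coordinates))
print('Example energies shape:', example_energies.shape)  # torch.Size([1])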

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
# decay to update the weights and Stochastic Gradient Descent (SGD) to update
# the biases. Moreover, we need to specify a different weight decay rate for
# each layer.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually
#   not L2 regularization. Confusing L2 regularization with weight decay is
#   a common mistake in deep learning. See: `Decoupled Weight Decay Regularization`_.
#   Also note that in the training of ANI models the weight decay applies
#   only to the weights, not the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101

AdamW = torchani.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Set up learning rate schedulers to decay the learning rates of both
# optimizers when the validation error plateaus.
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)
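# With ``factor=0.5`` and ``patience=100``, each learning rate is halved
# whenever the validation RMSE has not improved for 100 consecutive epochs
# (``threshold=0`` means that any improvement, however small, counts).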

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat the same process, stopping once the learning rate falls below a
# threshold.
#
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we evaluate on the validation set, and whenever the
# validation error improves on the best so far, we save the new best model to
# a checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for properties in validation:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        _, predicted_energies = model((species, coordinates))
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
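        # Species are padded with -1 for batching; counting the non-negative
        # entries gives the number of real atoms in each molecule.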
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

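        # Dividing each squared error by sqrt(num_atoms) down-weights large
        # molecules so that they do not dominate the loss.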
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
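
###############################################################################
# After training, the best-performing weights saved above can be loaded back
# into the network, e.g. (a sketch, mirroring how ``best.pt`` is written in
# the loop):
#
#   nn.load_state_dict(torch.load(best_model_checkpoint))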