# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run the NeuroChem training config file
    `inputtrain.ipt`. See: :ref:`neurochem-training`.
"""

###############################################################################
# To begin with, let's first import the modules and set up the device we will use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm

# helper function to convert energy units from Hartree to kcal/mol
from torchani.units import hartree2kcalmol
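
# For a quick sanity check of the conversion factor (1 Hartree is about
# 627.5 kcal/mol):
print(hartree2kcalmol(1.0))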

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up constants and construct an AEV computer. These numbers can
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is provided as an argument to the :class:`EnergyShifter` class.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from a file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
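#
# For reference, each radial AEV component sums a shifted Gaussian over
# neighbor distances, damped by a cosine cutoff (a sketch of the functional
# form from the original ANI paper; the angular terms follow a similar but
# more involved pattern):
#
# .. math::
#
#   G^{R}_{m} = \sum_{j \neq i} e^{-\eta \left(R_{ij} - R_s\right)^2} f_C(R_{ij}),
#   \qquad
#   f_C(R) =
#   \begin{cases}
#     0.5 \cos\left(\pi R / R_c\right) + 0.5, & R \leq R_c \\
#     0, & R > R_c
#   \end{cases}
#
# ``EtaR`` and ``ShfR`` below supply :math:`\eta` and the grid of shifts
# :math:`R_s`; ``Rcr`` is the radial cutoff :math:`R_c`.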

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
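
###############################################################################
# As a quick illustration of what the AEV computer produces, here is a minimal
# sketch with a hypothetical 5-atom molecule at random coordinates (species
# indices follow ``species_order``, so 0 is H and 1 is C):
example_species = torch.tensor([[1, 0, 0, 0, 0]], device=device)
example_coordinates = torch.randn(1, 5, 3, device=device)
_, example_aevs = aev_computer((example_species, example_coordinates))
print(example_aevs.shape)  # (molecules, atoms, aev_computer.aev_length)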

###############################################################################
# Now let's set up the datasets. These paths assume the user runs this script
# under the ``examples`` directory of TorchANI's repository. If you download
# this script, you should manually set the path of these files in your system
# before this script can run successfully.
#
# Also note that we need to subtract the self energies of all atoms from the
# energy of each molecule. This brings the energies into a reasonable range.
# The argument to ``species_to_indices`` defines how species given as a list
# of strings are converted to a tensor of indices: the first supported
# chemical symbol corresponds to ``0``, the second to ``1``, and so on.

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

training, validation = torchani.data.load(dspath)\
                                    .subtract_self_energies(energy_shifter)\
                                    .species_to_indices(species_order)\
                                    .shuffle()\
                                    .split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating over the dataset, we get a dict mapping property names to
# tensors, e.g. ``species``, ``coordinates``, and ``energies``.
#
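# A minimal sketch of inspecting one batch (the exact key set depends on the
# properties stored in the dataset):
example_batch = next(iter(training))
print(list(example_batch.keys()))
print(example_batch['species'].shape)  # (molecules, atoms), padded with -1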
###############################################################################
# Now let's define atomic neural networks.
aev_dim = aev_computer.aev_length

H_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)

###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear
#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from the normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
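        # a=1.0 gives a Kaiming gain of sqrt(2 / (1 + a**2)) = 1, i.e. the
        # weights are drawn from a normal distribution with std 1/sqrt(fan_in)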
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)
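
###############################################################################
# As a quick smoke test of the untrained pipeline, we can evaluate it on a
# hypothetical methane-like geometry (a sketch for illustration only; the
# output is meaningless until the network is trained):
methane_species = torch.tensor([[1, 0, 0, 0, 0]], device=device)  # C, H, H, H, H
methane_coordinates = torch.tensor([[[0.00, 0.00, 0.00],
                                     [0.63, 0.63, 0.63],
                                     [-0.63, -0.63, 0.63],
                                     [0.63, -0.63, -0.63],
                                     [-0.63, 0.63, -0.63]]], device=device)
_, methane_energy = model((methane_species, methane_coordinates))
print(methane_energy.item())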

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
# decay to update the weights and Stochastic Gradient Descent (SGD) to update
# the biases. Moreover, we need to specify a different weight decay rate for
# each layer.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually not
#   L2 regularization. The confusion between L2 regularization and weight decay
#   is a common mistake in deep learning. See: `Decoupled Weight Decay Regularization`_.
#   Also note that in the training of ANI models, weight decay is applied only
#   to the weights, not the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101
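#
# Schematically, the decoupled update applies the decay directly to the
# weights instead of adding an L2 term to the gradient. A sketch following
# that paper, with learning rate :math:`\alpha`, weight decay :math:`\lambda`,
# and Adam moment estimates :math:`\hat{m}_t`, :math:`\hat{v}_t`:
#
# .. math::
#
#   \theta_t \leftarrow \theta_{t-1}
#     - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
#     - \alpha \lambda \theta_{t-1}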

AdamW = torchani.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Setting up learning rate schedulers to do learning rate decay
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat, stopping once the learning rate falls below a threshold.
#
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to validate on the validation set, and if the
# validation error is better than the best so far, save the new best model to
# a checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for properties in validation:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        _, predicted_energies = model((species, coordinates))
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
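#
# The loss used below scales each molecule's squared energy error by the
# square root of its atom count, a sketch of the expression in the loop body
# for a batch of :math:`B` molecules:
#
# .. math::
#
#   L = \frac{1}{B} \sum_{i=1}^{B}
#       \frac{\left(\hat{E}_i - E_i\right)^2}{\sqrt{N_i}}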
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
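        # padding atoms have species index -1, so count only real atoms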
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)