# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem's. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run the NeuroChem training configuration file
    `inputtrain.ipt`_. See: :ref:`neurochem-training`.

.. warning::
    The training setup used in this file is configured to reproduce the original research
    at `Less is more: Sampling chemical space with active learning`_ as closely as possible.
    That research was done on a different platform called NeuroChem, which has many default
    options and technical details that differ from PyTorch's. Some decisions made here
    (such as using NeuroChem's initialization instead of PyTorch's default initialization)
    were made not because they give better results, but solely to reproduce the original
    research. This file should not be interpreted as a suggestion to readers on how
    they should set up their own models.

.. _`Less is more: Sampling chemical space with active learning`:
    https://aip.scitation.org/doi/full/10.1063/1.5023802
"""

###############################################################################
# To begin with, let's first import the modules and set up the device we will use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm
import pickle

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up constants and construct an AEV computer. These numbers can
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is provided as an argument to the constructor of the
# :class:`EnergyShifter` class.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from a file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
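
###############################################################################
# As the note above mentions, these hyperparameters can alternatively be read
# directly from the NeuroChem parameter file. A minimal sketch (not executed
# here, and assuming `rHCNO-5.2R_16-3.5A_a4-8.params`_ has been downloaded
# locally):
#
#   consts = torchani.neurochem.Constants('rHCNO-5.2R_16-3.5A_a4-8.params')
#   aev_computer = torchani.AEVComputer(**consts)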

###############################################################################
# Now let's set up the datasets. These paths assume the user runs this script
# under the ``examples`` directory of TorchANI's repository. If you download
# this script on its own, you should manually set the paths of these files on
# your system before the script can run successfully.
#
# Also note that we need to subtract the self energies of all atoms from the
# energy of each molecule. This brings the energies into a reasonable range.
# The second argument defines how to convert species, given as a list of
# strings, to a tensor: that is, for all supported chemical symbols, which one
# corresponds to ``0``, which one corresponds to ``1``, etc. Here, with
# ``species_order = ['H', 'C', 'N', 'O']``, hydrogen maps to ``0``, carbon to
# ``1``, and so on.

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

pickled_dataset_path = 'dataset.pkl'

# We pickle the dataset after loading to ensure we use the same validation set
# each time we restart training, otherwise we risk mixing the validation and
# training sets on each restart.
if os.path.isfile(pickled_dataset_path):
    print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
    with open(pickled_dataset_path, 'rb') as f:
        dataset = pickle.load(f)
    training = dataset['training'].collate(batch_size).cache()
    validation = dataset['validation'].collate(batch_size).cache()
    energy_shifter.self_energies = dataset['self_energies'].to(device)
else:
    print(f'Processing dataset in {dspath}')
    training, validation = torchani.data.load(dspath)\
                                        .subtract_self_energies(energy_shifter, species_order)\
                                        .species_to_indices(species_order)\
                                        .shuffle()\
                                        .split(0.8, None)
    with open(pickled_dataset_path, 'wb') as f:
        pickle.dump({'training': training,
                     'validation': validation,
                     'self_energies': energy_shifter.self_energies.cpu()}, f)
    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating over the dataset, we will get a dict of name->property mappings.
#
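# A quick way to see this is to peek at one batch. The keys printed below are
# the ones this script uses ('species', 'coordinates', 'energies'); depending
# on the dataset, other properties may be present as well:

example_properties = next(iter(training))
print('Keys in one batch:', list(example_properties.keys()))
print('species shape:', example_properties['species'].shape)
print('coordinates shape:', example_properties['coordinates'].shape)
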
###############################################################################
# Now let's define atomic neural networks.
aev_dim = aev_computer.aev_length

H_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(aev_dim, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)

###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear
#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from a normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
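        # kaiming_normal_ with a=1.0 has gain sqrt(2 / (1 + a ** 2)) = 1, i.e.
        # it draws weights from a normal distribution with std = 1 / sqrt(fan_in)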
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)
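
###############################################################################
# As a quick sanity check (not part of the original NeuroChem setup), we can
# run a single molecule through the pipeline. The methane-like geometry below
# is made up for illustration, and the network is still untrained at this
# point, so the energy value itself is meaningless:
methane_species = torch.tensor([[1, 0, 0, 0, 0]], device=device)  # C, H, H, H, H
methane_coordinates = torch.tensor([[[0.00, 0.00, 0.00],
                                     [0.63, 0.63, 0.63],
                                     [-0.63, -0.63, 0.63],
                                     [-0.63, 0.63, -0.63],
                                     [0.63, -0.63, -0.63]]], device=device)
_, methane_energy = model((methane_species, methane_coordinates))
print('Untrained model energy for methane (Hartree):', methane_energy.item())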

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
# decay to update the weights, and Stochastic Gradient Descent (SGD) to update
# the biases. Moreover, we need to specify different weight decay rates for
# different layers.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually not
#   L2 regularization. Confusing L2 regularization with weight decay is a
#   common mistake in deep learning. See: `Decoupled Weight Decay Regularization`_
#   Also note that in the training of ANI models, weight decay applies only to
#   the weights, not the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101

AdamW = torch.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Setting up a learning rate scheduler to do learning rate decay
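# ``ReduceLROnPlateau`` with ``factor=0.5`` halves the learning rate whenever
# the validation RMSE has not strictly improved for ``patience=100`` epochs
# (``threshold=0`` makes the comparison strict).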
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat the same process, stopping once the learning rate falls below a
# threshold.
#
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
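    # note: if resuming on a different device than the one the checkpoint was
    # saved from, ``torch.load(latest_checkpoint, map_location=device)`` can
    # be used instead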
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to validate on the validation set, and if the
# validation error is better than the best so far, save the new best model to
# a checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    model.train(True)
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'
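
# ``AdamW_scheduler.last_epoch`` advances once per ``step()`` call and is
# restored from the checkpoint above, so training resumes from the epoch where
# it previously stopped.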

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
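        # species tensors are padded with -1 so that molecules of different
        # sizes can share a batch; counting entries >= 0 recovers each
        # molecule's true atom count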
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))
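
        # Each molecule's squared error is scaled by 1 / sqrt(num_atoms), as in
        # the NeuroChem setup this example reproduces, so that molecules of
        # different sizes contribute comparably to the loss.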
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)