"torchvision/csrc/ops/autograd/roi_pool_kernel.cpp" did not exist on "0125a7dc50ae0f98c66837fef2f13842fb5c1e38"
nnp_training.py 12.9 KB
Newer Older
Gao, Xiang's avatar
Gao, Xiang committed
1
2
# -*- coding: utf-8 -*-
"""
.. _training-example:

Train Your Own Neural Network Potential
=======================================

This example shows how to use TorchANI to train a neural network potential
with a setup identical to NeuroChem's. We will use the same configuration as
specified in `inputtrain.ipt`_.

.. _`inputtrain.ipt`:
    https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt

.. note::
    TorchANI provides tools to run the NeuroChem training config file
    `inputtrain.ipt`_ directly. See: :ref:`neurochem-training`.
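
    A minimal sketch of that workflow (the data paths are placeholders, and
    the exact signatures are best checked against :ref:`neurochem-training`):

    .. code-block:: python

        device = torch.device('cpu')
        trainer = torchani.neurochem.Trainer('inputtrain.ipt', device)
        trainer.load_data('/path/to/training', '/path/to/validation')
        trainer.run()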
"""

###############################################################################
# To begin with, let's first import the modules and set up the device we will
# use:

import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Now let's set up constants and construct an AEV computer. These numbers can
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ were computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset if
# ``None`` is provided as an argument to the :class:`EnergyShifter`
# constructor.
#
# .. note::
#
#   Besides defining these hyperparameters programmatically,
#   :mod:`torchani.neurochem` provides tools to read them from a file.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
#   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat

Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
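
###############################################################################
# The hand-written constants above are what this tutorial uses. As a sketch of
# the file-based alternative mentioned in the note, the same constants could
# also be read from the parameter file shipped with TorchANI (the resource
# path below is an assumption based on the package layout):
#
# .. code-block:: python
#
#     const_file = os.path.join(os.path.dirname(torchani.__file__),
#                               'resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params')
#     consts = torchani.neurochem.Constants(const_file)
#     aev_computer = torchani.AEVComputer(**consts)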

###############################################################################
# Now let's set up datasets. These paths assume the user runs this script from
# the ``examples`` directory of TorchANI's repository. If you download this
# script, you should manually set the paths of these files on your system
# before the script can run successfully.
#
# Also note that we need to subtract the self energies of all atoms from the
# energy of each molecule. This brings the energies into a reasonable range.
# The list of chemical symbols passed to :class:`ChemicalSymbolsToInts` above
# defines how species strings are converted to a tensor: the first symbol
# corresponds to ``0``, the second to ``1``, etc.
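#
# For example, with the ordering used above:
#
# .. code-block:: python
#
#     species_to_tensor(['H', 'C', 'N', 'O'])  # -> tensor([0, 1, 2, 3])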

try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

training, validation = torchani.data.load(dspath)\
                                    .subtract_self_energies(energy_shifter)\
                                    .species_to_indices()\
                                    .shuffle()\
                                    .split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

###############################################################################
# When iterating the dataset, we will get a dict of name->property mapping
#
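# For example, we can peek at the first batch (a sketch; the exact set of keys
# depends on which properties the dataset file provides):
first_batch = next(iter(training))
print('Properties in a batch:', list(first_batch.keys()))
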
###############################################################################
# Now let's define atomic neural networks.
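# The input dimension of each network (384) is the AEV length for the
# constants above, i.e. ``aev_computer.aev_length``.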

H_network = torch.nn.Sequential(
    torch.nn.Linear(384, 160),
    torch.nn.CELU(0.1),
    torch.nn.Linear(160, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

C_network = torch.nn.Sequential(
    torch.nn.Linear(384, 144),
    torch.nn.CELU(0.1),
    torch.nn.Linear(144, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

N_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

O_network = torch.nn.Sequential(
    torch.nn.Linear(384, 128),
    torch.nn.CELU(0.1),
    torch.nn.Linear(128, 112),
    torch.nn.CELU(0.1),
    torch.nn.Linear(112, 96),
    torch.nn.CELU(0.1),
    torch.nn.Linear(96, 1)
)

nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)

###############################################################################
# Initialize the weights and biases.
#
# .. note::
#   PyTorch's default initialization for the weights and biases in linear
#   layers is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
#   We initialize the weights similarly, but from the normal distribution.
#   The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear


def init_params(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)


nn.apply(init_params)

###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)
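
###############################################################################
# As a quick sanity check, let's run the pipeline on a single made-up methane
# molecule (random coordinates, purely illustrative). Note that the returned
# energy has the self energies subtracted, because the shifter was applied to
# the dataset rather than to the model:
methane_species = species_to_tensor(['C', 'H', 'H', 'H', 'H']).unsqueeze(0).to(device)
methane_coordinates = torch.randn(1, 5, 3, device=device)
_, methane_energy = model((methane_species, methane_coordinates))
print('Sample shifted energy:', methane_energy.item())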

###############################################################################
# Now let's set up the optimizers. NeuroChem uses Adam with decoupled weight
# decay to update the weights and Stochastic Gradient Descent (SGD) to update
# the biases. Moreover, we need to specify different weight decay rates for
# different layers.
#
# .. note::
#
#   The weight decay in `inputtrain.ipt`_ is named "l2", but it is actually not
#   L2 regularization. The confusion between L2 and weight decay is a common
#   mistake in deep learning.  See: `Decoupled Weight Decay Regularization`_
#   Also note that in the training of ANI models the weight decay is applied
#   only to the weights, not the biases.
#
# .. _Decoupled Weight Decay Regularization:
#   https://arxiv.org/abs/1711.05101

AdamW = torchani.optim.AdamW([
    # H networks
    {'params': [H_network[0].weight]},
    {'params': [H_network[2].weight], 'weight_decay': 0.00001},
    {'params': [H_network[4].weight], 'weight_decay': 0.000001},
    {'params': [H_network[6].weight]},
    # C networks
    {'params': [C_network[0].weight]},
    {'params': [C_network[2].weight], 'weight_decay': 0.00001},
    {'params': [C_network[4].weight], 'weight_decay': 0.000001},
    {'params': [C_network[6].weight]},
    # N networks
    {'params': [N_network[0].weight]},
    {'params': [N_network[2].weight], 'weight_decay': 0.00001},
    {'params': [N_network[4].weight], 'weight_decay': 0.000001},
    {'params': [N_network[6].weight]},
    # O networks
    {'params': [O_network[0].weight]},
    {'params': [O_network[2].weight], 'weight_decay': 0.00001},
    {'params': [O_network[4].weight], 'weight_decay': 0.000001},
    {'params': [O_network[6].weight]},
])

SGD = torch.optim.SGD([
    # H networks
    {'params': [H_network[0].bias]},
    {'params': [H_network[2].bias]},
    {'params': [H_network[4].bias]},
    {'params': [H_network[6].bias]},
    # C networks
    {'params': [C_network[0].bias]},
    {'params': [C_network[2].bias]},
    {'params': [C_network[4].bias]},
    {'params': [C_network[6].bias]},
    # N networks
    {'params': [N_network[0].bias]},
    {'params': [N_network[2].bias]},
    {'params': [N_network[4].bias]},
    {'params': [N_network[6].bias]},
    # O networks
    {'params': [O_network[0].bias]},
    {'params': [O_network[2].bias]},
    {'params': [O_network[4].bias]},
    {'params': [O_network[6].bias]},
], lr=1e-3)

###############################################################################
# Set up learning rate schedulers to decay the learning rate.
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, factor=0.5, patience=100, threshold=0)

###############################################################################
# Train the model by minimizing the MSE loss until the validation RMSE no
# longer improves for a certain number of steps, then decay the learning rate
# and repeat the same process, stopping when the learning rate drops below a
# threshold.
#
# We first read the checkpoint file to restart training. We use ``latest.pt``
# to store the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from previously saved checkpoints:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to evaluate on the validation set, and if the
# validation error improves on the best so far, save the new best model to a
# checkpoint.


def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    for properties in validation:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        _, predicted_energies = model((species, coordinates))
        total_mse += mse_sum(predicted_energies, true_energies).item()
        count += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(total_mse / count))


###############################################################################
# We will also use TensorBoard to visualize our training process.
tensorboard = torch.utils.tensorboard.SummaryWriter()
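# The writer logs to the default ``runs`` directory; once training has started
# you can inspect the curves with ``tensorboard --logdir runs``.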

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, this should
# be set to a much larger value.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        # species are padded with -1, so count only the real atoms per molecule
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        # scale each molecule's squared error by 1/sqrt(num_atoms) so that
        # larger molecules do not dominate the loss
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)