Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
e1c78efe
Commit
e1c78efe
authored
Jun 06, 2019
by
Farhad Ramezanghorbani
Committed by
Gao, Xiang
Jun 06, 2019
Browse files
Updated nnp_training tutorial to be as close to NeuroChem as possible (#245)
parent
dc8930ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
61 deletions
+44
-61
examples/nnp_training.py
examples/nnp_training.py
+44
-61
No files found.
examples/nnp_training.py
View file @
e1c78efe
...
...
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
nn
=
torchani
.
ANIModel
([
H_network
,
C_network
,
N_network
,
O_network
])
print
(
nn
)
###############################################################################
# Initialize the weights and biases.
#
# .. note::
# PyTorch default initialization for the weights and biases in linear layers
# is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
# We initialize the weights similarly but from the normal distribution.
# The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
def init_params(m):
    """Initialize a single module in place; meant for ``Module.apply``.

    Only ``torch.nn.Linear`` modules are touched: their weights are drawn
    with ``kaiming_normal_`` and their biases are set to zero. Every other
    module type is left unchanged.

    Args:
        m: the module currently visited by ``Module.apply``.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    # ``a`` is the negative-slope argument forwarded to kaiming_normal_.
    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
    # A Linear layer built with ``bias=False`` has ``bias`` set to None;
    # skip it instead of failing inside ``zeros_``.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
nn
.
apply
(
init_params
)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nn
).
to
(
device
)
...
...
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
])
###############################################################################
#
The way ANI trains a neural network potential looks like this:
#
# Phase 1: Pretrain the model by minimizing MSE loss
#
#
Phase 2:
Train the model by minimizing the
exponential
loss, until validation
#
RMSE no longer
improves
for
a certain number of steps, decay the learning rate and repeat
# the same process, and stop once the learning rate is smaller than a
certain number
.
#
Setting up a learning rate scheduler to do learning rate decay
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
factor
=
0.5
,
patience
=
100
,
threshold
=
0
)
#
##############################################################################
# Train the model by minimizing the
MSE
loss, until validation
RMSE no longer
# improves
during
a certain
number of
steps, decay the learning rate and repeat
# the same process, and stop once the learning rate is smaller than a
threshold
.
#
# We first read the checkpoint files to find where we are. We use `latest.pt`
# to store current training state. If `latest.pt` does not exist, this
# means the pretraining has not been finished yet.
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store current training state.
latest_checkpoint
=
'latest.pt'
pretrained
=
os
.
path
.
isfile
(
latest_checkpoint
)
###############################################################################
# Resume training from previously saved checkpoints:
if
os
.
path
.
isfile
(
latest_checkpoint
):
checkpoint
=
torch
.
load
(
latest_checkpoint
)
model
.
load_state_dict
(
checkpoint
[
'nn'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
###############################################################################
# During training, we need to validate on validation set and if validation error
...
...
@@ -259,61 +286,18 @@ def validate():
return
hartree2kcal
(
math
.
sqrt
(
total_mse
/
count
))
###############################################################################
# If the model is not pretrained yet, we need to run the pretraining.
pretrain_criterion
=
10
# kcal/mol
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
if
not
pretrained
:
print
(
"pre-training..."
)
epoch
=
0
rmse
=
math
.
inf
pretrain_optimizer
=
torch
.
optim
.
Adam
(
nn
.
parameters
())
while
rmse
>
pretrain_criterion
:
for
batch_x
,
batch_y
in
tqdm
.
tqdm
(
training
):
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
num_atoms
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
num_atoms
.
append
((
chunk_species
>=
0
).
sum
(
dim
=
1
))
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
predicted_energies
.
append
(
chunk_energies
)
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
pretrain_optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
rmse
=
validate
()
print
(
'RMSE:'
,
rmse
,
'Target RMSE:'
,
pretrain_criterion
)
torch
.
save
({
'nn'
:
nn
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
},
latest_checkpoint
)
###############################################################################
# For phase 2, we need a learning rate scheduler to do learning rate decay
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
factor
=
0.5
,
patience
=
100
)
###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard
=
torch
.
utils
.
tensorboard
.
SummaryWriter
()
###############################################################################
# Resume training from previously saved checkpoints:
checkpoint
=
torch
.
load
(
latest_checkpoint
)
nn
.
load_state_dict
(
checkpoint
[
'nn'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
if
'scheduler'
in
checkpoint
:
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we are setting the maximum epoch to a very small number,
# only to make this demo terminate fast. For serious training, this should be
# set to a much larger value
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
print
(
"training starting from epoch"
,
scheduler
.
last_epoch
+
1
)
max_epochs
=
200
early_stopping_learning_rate
=
1.0E-5
...
...
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
for
_
in
range
(
scheduler
.
last_epoch
+
1
,
max_epochs
):
rmse
=
validate
()
print
(
'RMSE:'
,
rmse
,
'at epoch'
,
scheduler
.
last_epoch
)
print
(
'RMSE:'
,
rmse
,
'at epoch'
,
scheduler
.
last_epoch
+
1
)
learning_rate
=
optimizer
.
param_groups
[
0
][
'lr'
]
if
learning_rate
<
early_stopping_learning_rate
:
break
tensorboard
.
add_scalar
(
'validation_rmse'
,
rmse
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'best_validation_rmse'
,
scheduler
.
best
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'learning_rate'
,
learning_rate
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'validation_rmse'
,
rmse
,
scheduler
.
last_epoch
+
1
)
tensorboard
.
add_scalar
(
'best_validation_rmse'
,
scheduler
.
best
,
scheduler
.
last_epoch
+
1
)
tensorboard
.
add_scalar
(
'learning_rate'
,
learning_rate
,
scheduler
.
last_epoch
+
1
)
# checkpoint
if
scheduler
.
is_better
(
rmse
,
scheduler
.
best
):
...
...
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
loss
=
0.5
*
(
torch
.
exp
(
2
*
loss
)
-
1
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment