Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
e1c78efe
Commit
e1c78efe
authored
Jun 06, 2019
by
Farhad Ramezanghorbani
Committed by
Gao, Xiang
Jun 06, 2019
Browse files
Updated nnp_training tutorial to be as close to NeuroChem as possible (#245)
parent
dc8930ee
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
61 deletions
+44
-61
examples/nnp_training.py
examples/nnp_training.py
+44
-61
No files found.
examples/nnp_training.py
View file @
e1c78efe
...
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
...
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
nn
=
torchani
.
ANIModel
([
H_network
,
C_network
,
N_network
,
O_network
])
nn
=
torchani
.
ANIModel
([
H_network
,
C_network
,
N_network
,
O_network
])
print
(
nn
)
print
(
nn
)
###############################################################################
# Initialize the weights and biases.
#
# .. note::
# PyTorch's default initialization for the weights and biases of linear layers
# is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
# We initialize the weights similarly, but from a normal distribution.
# The biases are initialized to zero.
#
# .. _TORCH.NN.MODULES.LINEAR:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
def init_params(m):
    """Initialize one submodule; meant to be passed to ``Module.apply``.

    Only ``torch.nn.Linear`` layers are touched; any other module is left
    unchanged. Weights are drawn in place from a Kaiming normal
    distribution and biases, when present, are set to zero in place.

    Arguments:
        m (torch.nn.Module): the submodule to initialize.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    # With the default ``leaky_relu`` nonlinearity, ``a`` is the negative
    # slope; a=1.0 gives a gain of 1.
    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
    # A layer built with ``bias=False`` has ``m.bias is None``, and
    # ``zeros_`` cannot be applied to None.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
nn
.
apply
(
init_params
)
###############################################################################
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nn
).
to
(
device
)
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nn
).
to
(
device
)
...
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
...
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
])
])
###############################################################################
###############################################################################
#
The way ANI trains a neural network potential looks like this:
#
Setting up a learning rate scheduler to do learning rate decay
#
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
factor
=
0.5
,
patience
=
100
,
threshold
=
0
)
# Phase 1: Pretrain the model by minimizing MSE loss
#
#
##############################################################################
#
Phase 2:
Train the model by minimizing the
exponential
loss, until validation
# Train the model by minimizing the
MSE
loss, until validation
RMSE no longer
#
RMSE no longer
improves
for
a certain steps, decay the learning rate and repeat
# improves
during
a certain
number of
steps, decay the learning rate and repeat
# the same process, stop until the learning rate is smaller than a
certain number
.
# the same process, and stop once the learning rate is smaller than a
threshold
.
#
#
# We first read the checkpoint files to find where we are. We use `latest.pt`
# We first read the checkpoint files to restart training. We use `latest.pt`
# to store current training state. If `latest.pt` does not exist, this
# to store current training state.
# this means the pretraining has not been finished yet.
latest_checkpoint
=
'latest.pt'
latest_checkpoint
=
'latest.pt'
pretrained
=
os
.
path
.
isfile
(
latest_checkpoint
)
###############################################################################
# Resume training from previously saved checkpoints:
if
os
.
path
.
isfile
(
latest_checkpoint
):
checkpoint
=
torch
.
load
(
latest_checkpoint
)
model
.
load_state_dict
(
checkpoint
[
'nn'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
###############################################################################
###############################################################################
# During training, we need to validate on validation set and if validation error
# During training, we need to validate on validation set and if validation error
...
@@ -259,61 +286,18 @@ def validate():
...
@@ -259,61 +286,18 @@ def validate():
return
hartree2kcal
(
math
.
sqrt
(
total_mse
/
count
))
return
hartree2kcal
(
math
.
sqrt
(
total_mse
/
count
))
###############################################################################
# If the model is not pretrained yet, we need to run the pretrain.
pretrain_criterion
=
10
# kcal/mol
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
if
not
pretrained
:
print
(
"pre-training..."
)
epoch
=
0
rmse
=
math
.
inf
pretrain_optimizer
=
torch
.
optim
.
Adam
(
nn
.
parameters
())
while
rmse
>
pretrain_criterion
:
for
batch_x
,
batch_y
in
tqdm
.
tqdm
(
training
):
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
num_atoms
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
num_atoms
.
append
((
chunk_species
>=
0
).
sum
(
dim
=
1
))
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
predicted_energies
.
append
(
chunk_energies
)
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
pretrain_optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
rmse
=
validate
()
print
(
'RMSE:'
,
rmse
,
'Target RMSE:'
,
pretrain_criterion
)
torch
.
save
({
'nn'
:
nn
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
},
latest_checkpoint
)
###############################################################################
# For phase 2, we need a learning rate scheduler to do learning rate decay
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
factor
=
0.5
,
patience
=
100
)
###############################################################################
###############################################################################
# We will also use TensorBoard to visualize our training process
# We will also use TensorBoard to visualize our training process
tensorboard
=
torch
.
utils
.
tensorboard
.
SummaryWriter
()
tensorboard
=
torch
.
utils
.
tensorboard
.
SummaryWriter
()
###############################################################################
# Resume training from previously saved checkpoints:
checkpoint
=
torch
.
load
(
latest_checkpoint
)
nn
.
load_state_dict
(
checkpoint
[
'nn'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
if
'scheduler'
in
checkpoint
:
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
###############################################################################
###############################################################################
# Finally, we come to the training loop.
# Finally, we come to the training loop.
#
#
# In this tutorial, we are setting the maximum epoch to a very small number,
# In this tutorial, we are setting the maximum epoch to a very small number,
# only to make this demo terminate fast. For serious training, this should be
# only to make this demo terminate fast. For serious training, this should be
# set to a much larger value
# set to a much larger value
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
print
(
"training starting from epoch"
,
scheduler
.
last_epoch
+
1
)
print
(
"training starting from epoch"
,
scheduler
.
last_epoch
+
1
)
max_epochs
=
200
max_epochs
=
200
early_stopping_learning_rate
=
1.0E-5
early_stopping_learning_rate
=
1.0E-5
...
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
...
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
for
_
in
range
(
scheduler
.
last_epoch
+
1
,
max_epochs
):
for
_
in
range
(
scheduler
.
last_epoch
+
1
,
max_epochs
):
rmse
=
validate
()
rmse
=
validate
()
print
(
'RMSE:'
,
rmse
,
'at epoch'
,
scheduler
.
last_epoch
)
print
(
'RMSE:'
,
rmse
,
'at epoch'
,
scheduler
.
last_epoch
+
1
)
learning_rate
=
optimizer
.
param_groups
[
0
][
'lr'
]
learning_rate
=
optimizer
.
param_groups
[
0
][
'lr'
]
if
learning_rate
<
early_stopping_learning_rate
:
if
learning_rate
<
early_stopping_learning_rate
:
break
break
tensorboard
.
add_scalar
(
'validation_rmse'
,
rmse
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'validation_rmse'
,
rmse
,
scheduler
.
last_epoch
+
1
)
tensorboard
.
add_scalar
(
'best_validation_rmse'
,
scheduler
.
best
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'best_validation_rmse'
,
scheduler
.
best
,
scheduler
.
last_epoch
+
1
)
tensorboard
.
add_scalar
(
'learning_rate'
,
learning_rate
,
scheduler
.
last_epoch
)
tensorboard
.
add_scalar
(
'learning_rate'
,
learning_rate
,
scheduler
.
last_epoch
+
1
)
# checkpoint
# checkpoint
if
scheduler
.
is_better
(
rmse
,
scheduler
.
best
):
if
scheduler
.
is_better
(
rmse
,
scheduler
.
best
):
...
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
...
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
loss
=
0.5
*
(
torch
.
exp
(
2
*
loss
)
-
1
)
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
loss
.
backward
()
loss
.
backward
()
optimizer
.
step
()
optimizer
.
step
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment